Skip to content
Closed
4 changes: 3 additions & 1 deletion python/pyspark/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
- :class:`pyspark.sql.Window`
For working with window functions.
"""
from pyspark.sql.types import Row, VariantVal
from pyspark.sql.types import Geography, Geometry, Row, VariantVal
from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration, UDTFRegistration
from pyspark.sql.session import SparkSession
from pyspark.sql.column import Column
Expand Down Expand Up @@ -69,6 +69,8 @@
"DataFrameNaFunctions",
"DataFrameStatFunctions",
"VariantVal",
"Geography",
"Geometry",
"Window",
"WindowSpec",
"DataFrameReader",
Expand Down
60 changes: 60 additions & 0 deletions python/pyspark/sql/pandas/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@
UserDefinedType,
VariantType,
VariantVal,
GeometryType,
Geometry,
GeographyType,
Geography,
_create_row,
)
from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
Expand Down Expand Up @@ -202,6 +206,18 @@ def to_arrow_type(
pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}),
]
arrow_type = pa.struct(fields)
elif type(dt) == GeometryType:
fields = [
pa.field("srid", pa.int32(), nullable=False),
pa.field("wkb", pa.binary(), nullable=False, metadata={b"geometry": b"true", b"srid": str(dt.srid)}),
]
arrow_type = pa.struct(fields)
elif type(dt) == GeographyType:
fields = [
pa.field("srid", pa.int32(), nullable=False),
pa.field("wkb", pa.binary(), nullable=False, metadata={b"geography": b"true", b"srid": str(dt.srid)}),
]
arrow_type = pa.struct(fields)
else:
raise PySparkTypeError(
errorClass="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
Expand Down Expand Up @@ -272,6 +288,38 @@ def is_variant(at: "pa.DataType") -> bool:
) and any(field.name == "value" for field in at)


def is_geometry(at: "pa.DataType") -> bool:
    """Check if a PyArrow struct data type represents a geometry.

    A struct is recognised as a geometry when it has a field named ``srid``
    and a field named ``wkb`` whose field-level metadata contains the entry
    ``b"geometry": b"true"``.

    Parameters
    ----------
    at : :class:`pyarrow.DataType`
        The Arrow type to inspect. It must be a struct type; this is asserted.

    Returns
    -------
    bool
        True if the struct carries the geometry marker, False otherwise.
    """
    import pyarrow.types as types

    assert types.is_struct(at)

    return any(
        (
            field.name == "wkb"
            # ``Field.metadata`` is None when no metadata was attached, so an
            # ordinary struct with a plain "wkb" field must not raise here.
            and (field.metadata or {}).get(b"geometry") == b"true"
        )
        for field in at
    ) and any(field.name == "srid" for field in at)


def is_geography(at: "pa.DataType") -> bool:
    """Check if a PyArrow struct data type represents a geography.

    A struct is recognised as a geography when it has a field named ``srid``
    and a field named ``wkb`` whose field-level metadata contains the entry
    ``b"geography": b"true"``.

    Parameters
    ----------
    at : :class:`pyarrow.DataType`
        The Arrow type to inspect. It must be a struct type; this is asserted.

    Returns
    -------
    bool
        True if the struct carries the geography marker, False otherwise.
    """
    import pyarrow.types as types

    assert types.is_struct(at)

    return any(
        (
            field.name == "wkb"
            # ``Field.metadata`` is None when no metadata was attached, so an
            # ordinary struct with a plain "wkb" field must not raise here.
            and (field.metadata or {}).get(b"geography") == b"true"
        )
        for field in at
    ) and any(field.name == "srid" for field in at)


def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
"""Convert pyarrow type to Spark data type."""
import pyarrow.types as types
Expand Down Expand Up @@ -337,6 +385,18 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
elif types.is_struct(at):
if is_variant(at):
return VariantType()
elif is_geometry(at):
srid = int(at.field("wkb").metadata.get(b"srid"))
if srid == GeometryType.MIXED_SRID:
return GeometryType("ANY")
else:
return GeometryType(srid)
elif is_geography(at):
srid = int(at.field("wkb").metadata.get(b"srid"))
if srid == GeographyType.MIXED_SRID:
return GeographyType("ANY")
else:
return GeographyType(srid)
return StructType(
[
StructField(
Expand Down
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def test_apply_schema_to_dict_and_rows(self):
def test_apply_schema_to_row(self):
super().test_apply_schema_to_row()

# Disabled in this parity suite: per the skip reason, the test depends on RDDs,
# which Spark Connect does not support. If enabled, the method would only
# delegate to the parent class's test of the same name.
@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_geospatial_create_dataframe_rdd(self):
super().test_geospatial_create_dataframe_rdd()

# Disabled in this parity suite: per the skip reason, the test depends on RDDs,
# which Spark Connect does not support. If enabled, the method would only
# delegate to the parent class's test of the same name.
@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_create_dataframe_schema_mismatch(self):
super().test_create_dataframe_schema_mismatch()
Expand Down
Loading