diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index a0a6e8ef70c8..eeeeddd00e3a 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -39,7 +39,7 @@ - :class:`pyspark.sql.Window` For working with window functions. """ -from pyspark.sql.types import Row, VariantVal +from pyspark.sql.types import Geography, Geometry, Row, VariantVal from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration, UDTFRegistration from pyspark.sql.session import SparkSession from pyspark.sql.column import Column @@ -69,6 +69,8 @@ "DataFrameNaFunctions", "DataFrameStatFunctions", "VariantVal", + "Geography", + "Geometry", "Window", "WindowSpec", "DataFrameReader", diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 327e3941d938..d8a45daa77e8 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -50,6 +50,10 @@ UserDefinedType, VariantType, VariantVal, + GeometryType, + Geometry, + GeographyType, + Geography, _create_row, ) from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError @@ -202,6 +206,28 @@ def to_arrow_type( pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}), ] arrow_type = pa.struct(fields) + elif type(dt) == GeometryType: + fields = [ + pa.field("srid", pa.int32(), nullable=False), + pa.field( + "wkb", + pa.binary(), + nullable=False, + metadata={b"geometry": b"true", b"srid": str(dt.srid)}, + ), + ] + arrow_type = pa.struct(fields) + elif type(dt) == GeographyType: + fields = [ + pa.field("srid", pa.int32(), nullable=False), + pa.field( + "wkb", + pa.binary(), + nullable=False, + metadata={b"geography": b"true", b"srid": str(dt.srid)}, + ), + ] + arrow_type = pa.struct(fields) else: raise PySparkTypeError( errorClass="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION", @@ -272,6 +298,38 @@ def is_variant(at: "pa.DataType") -> bool: ) and any(field.name == "value" for 
field in at) +def is_geometry(at: "pa.DataType") -> bool: + """Return True if the struct has a "wkb" field tagged geometry=true and an "srid" field.""" + import pyarrow.types as types + + assert types.is_struct(at) + + return any( + ( + field.name == "wkb" + and field.metadata is not None + and field.metadata.get(b"geometry") == b"true" + ) + for field in at + ) and any(field.name == "srid" for field in at) + + +def is_geography(at: "pa.DataType") -> bool: + """Return True if the struct has a "wkb" field tagged geography=true and an "srid" field.""" + import pyarrow.types as types + + assert types.is_struct(at) + + return any( + ( + field.name == "wkb" + and field.metadata is not None + and field.metadata.get(b"geography") == b"true" + ) + for field in at + ) and any(field.name == "srid" for field in at) + + def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType: """Convert pyarrow type to Spark data type.""" import pyarrow.types as types @@ -337,6 +395,18 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da elif types.is_struct(at): if is_variant(at): return VariantType() + elif is_geometry(at): + srid = int(at.field("wkb").metadata.get(b"srid")) + if srid == GeometryType.MIXED_SRID: + return GeometryType("ANY") + else: + return GeometryType(srid) + elif is_geography(at): + srid = int(at.field("wkb").metadata.get(b"srid")) + if srid == GeographyType.MIXED_SRID: + return GeographyType("ANY") + else: + return GeographyType(srid) return StructType( [ StructField( @@ -1098,6 +1168,40 @@ def convert_variant(value: Any) -> Any: return convert_variant + elif isinstance(dt, GeographyType): + + def convert_geography(value: Any) -> Any: + if value is None: + return None + elif ( + isinstance(value, dict) + and all(key in value for key in ["wkb", "srid"]) + and isinstance(value["wkb"], bytes) + and isinstance(value["srid"], int) + ): + return Geography.fromWKB(value["wkb"], value["srid"]) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY") + + return
convert_geography + + elif isinstance(dt, GeometryType): + + def convert_geometry(value: Any) -> Any: + if value is None: + return None + elif ( + isinstance(value, dict) + and all(key in value for key in ["wkb", "srid"]) + and isinstance(value["wkb"], bytes) + and isinstance(value["srid"], int) + ): + return Geometry.fromWKB(value["wkb"], value["srid"]) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOMETRY") + + return convert_geometry + else: return None @@ -1360,6 +1464,22 @@ def convert_variant(variant: Any) -> Any: return convert_variant + elif isinstance(dt, GeographyType): + + def convert_geography(value: Any) -> Any: + assert isinstance(value, Geography) + return {"srid": value.srid, "wkb": value.wkb} + + return convert_geography + + elif isinstance(dt, GeometryType): + + def convert_geometry(value: Any) -> Any: + assert isinstance(value, Geometry) + return {"srid": value.srid, "wkb": value.wkb} + + return convert_geometry + return None conv = _converter(data_type) diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py index 6d06611def6a..a39e92233bc0 100644 --- a/python/pyspark/sql/tests/connect/test_parity_types.py +++ b/python/pyspark/sql/tests/connect/test_parity_types.py @@ -34,6 +34,10 @@ def test_apply_schema_to_dict_and_rows(self): def test_apply_schema_to_row(self): super().test_apply_schema_to_row() + @unittest.skip("Spark Connect does not support RDD but the tests depend on them.") + def test_geospatial_create_dataframe_rdd(self): + super().test_geospatial_create_dataframe_rdd() + @unittest.skip("Spark Connect does not support RDD but the tests depend on them.") def test_create_dataframe_schema_mismatch(self): super().test_create_dataframe_schema_mismatch() diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 6979095acca8..4ff2ab3e5cd7 100644 --- a/python/pyspark/sql/tests/test_types.py +++ 
b/python/pyspark/sql/tests/test_types.py @@ -29,6 +29,8 @@ from pyspark.sql import functions as F from pyspark.errors import ( AnalysisException, + IllegalArgumentException, + SparkRuntimeException, ParseException, PySparkTypeError, PySparkValueError, @@ -51,6 +53,8 @@ MapType, StringType, CharType, + Geography, + Geometry, VarcharType, StructType, StructField, @@ -1365,6 +1369,7 @@ def test_parse_datatype_json_string(self): NullType(), GeographyType(4326), GeographyType("ANY"), + GeometryType(0), GeometryType(4326), GeometryType("ANY"), VariantType(), @@ -2447,6 +2452,323 @@ def test_variant_type(self): with self.assertRaises(PySparkValueError, msg="Rows cannot be of type VariantVal"): self.spark.createDataFrame([VariantVal.parseJson("2")], "v variant") + def test_geospatial_encoding(self): + df = self.spark.createDataFrame( + [ + ( + bytes.fromhex("0101000000000000000000F03F0000000000000040"), + 4326, + ) + ], + ["wkb", "srid"], + ) + row = df.select( + F.st_geomfromwkb(df.wkb).alias("geom"), + F.st_geogfromwkb(df.wkb).alias("geog"), + ).collect()[0] + + self.assertEqual(type(row["geom"]), Geometry) + self.assertEqual(type(row["geog"]), Geography) + self.assertEqual( + row["geom"].getBytes(), bytes.fromhex("0101000000000000000000F03F0000000000000040") + ) + self.assertEqual(row["geom"].getSrid(), 0) + self.assertEqual( + row["geog"].getBytes(), bytes.fromhex("0101000000000000000000F03F0000000000000040") + ) + self.assertEqual(row["geog"].getSrid(), 4326) + + def test_geospatial_create_dataframe_rdd(self): + schema = StructType( + [ + StructField("id", IntegerType(), True), + StructField("geom", GeometryType(0), True), + StructField("geom4326", GeometryType(4326), True), + StructField("geog", GeographyType(4326), True), + ] + ) + geospatial_data = [ + ( + 1, + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), + Geography.fromWKB( + 
bytes.fromhex("010100000000000000000031400000000000001c40"), 4326 + ), + ), + ( + 2, + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), + Geography.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 4326 + ), + ), + ] + rdd_data = self.sc.parallelize(geospatial_data) + df = self.spark.createDataFrame(rdd_data, schema) + rows = df.select( + F.st_asbinary(df.geom).alias("geom_wkb"), + F.st_srid(df.geom).alias("geom_srid"), + F.st_asbinary(df.geom4326).alias("geom4326_wkb"), + F.st_srid(df.geom4326).alias("geom4326_srid"), + F.st_asbinary(df.geog).alias("geog_wkb"), + F.st_srid(df.geog).alias("geog_srid"), + ).collect() + + point0_wkb = bytes.fromhex("010100000000000000000031400000000000001c40") + point1_wkb = bytes.fromhex("010100000000000000000014400000000000001440") + self.assertEqual(rows[0]["geom_wkb"], point0_wkb) + self.assertEqual(rows[0]["geom4326_wkb"], point0_wkb) + self.assertEqual(rows[0]["geog_wkb"], point0_wkb) + self.assertEqual(rows[1]["geom_wkb"], point1_wkb) + self.assertEqual(rows[1]["geom4326_wkb"], point1_wkb) + self.assertEqual(rows[1]["geog_wkb"], point1_wkb) + self.assertEqual(rows[0]["geom_srid"], 0) + self.assertEqual(rows[0]["geom4326_srid"], 4326) + self.assertEqual(rows[0]["geog_srid"], 4326) + self.assertEqual(rows[1]["geom_srid"], 0) + self.assertEqual(rows[1]["geom4326_srid"], 4326) + self.assertEqual(rows[1]["geog_srid"], 4326) + schema_df = self.spark.createDataFrame(rdd_data).select( + F.col("_1").alias("id"), + F.col("_2").alias("geom"), + F.col("_3").alias("geom4326"), + F.col("_4").alias("geog"), + ) + self.assertEqual(df.collect(), schema_df.collect()) + + def test_geospatial_create_dataframe(self): + # Positive Test: Creating DataFrame from a list of tuples with explicit schema + geospatial_data = [ + ( + 1, + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 
0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), + Geography.fromWKB( + bytes.fromhex("010100000000000000000031400000000000001c40"), 4326 + ), + ), + ( + 2, + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), + Geography.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 4326 + ), + ), + ] + named_geospatial_data = [ + { + "id": 1, + "geom": Geometry.fromWKB( + bytes.fromhex("010100000000000000000031400000000000001c40"), 0 + ), + "geom4326": Geometry.fromWKB( + bytes.fromhex("010100000000000000000031400000000000001c40"), 4326 + ), + "geog": Geography.fromWKB( + bytes.fromhex("010100000000000000000031400000000000001c40"), 4326 + ), + }, + { + "id": 2, + "geom": Geometry.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 0 + ), + "geom4326": Geometry.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 4326 + ), + "geog": Geography.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 4326 + ), + }, + ] + GeospatialRow = Row("id", "geom", "geom4326", "geog") + spark_row_data = [ + GeospatialRow( + 1, + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), + Geography.fromWKB( + bytes.fromhex("010100000000000000000031400000000000001c40"), 4326 + ), + ), + GeospatialRow( + 2, + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), + Geography.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 4326 + ), + ), + ] + schema = StructType( + [ + StructField("id", IntegerType(), True), + StructField("geom", GeometryType(0), True), + StructField("geom4326", GeometryType(4326), True), + 
StructField("geog", GeographyType(4326), True), + ] + ) + # Negative Test: Schema mismatch + mismatched_schema = StructType( + [ + StructField("id", IntegerType(), True), # Matches the data + StructField("geom", GeometryType(4326), True), # Mismatch: the data uses SRID 0 + StructField("geom4326", GeometryType(4326), True), # Matches the data + StructField("geog", GeographyType(4326), True), # Matches the data + ] + ) + + # Explicitly validate a single test case; + # the rest will be compared with this one. + df = self.spark.createDataFrame(geospatial_data, schema) + rows = df.select( + F.st_asbinary(df.geom).alias("geom_wkb"), + F.st_srid(df.geom).alias("geom_srid"), + F.st_asbinary(df.geom4326).alias("geom4326_wkb"), + F.st_srid(df.geom4326).alias("geom4326_srid"), + F.st_asbinary(df.geog).alias("geog_wkb"), + F.st_srid(df.geog).alias("geog_srid"), + ).collect() + + point0_wkb = bytes.fromhex("010100000000000000000031400000000000001c40") + point1_wkb = bytes.fromhex("010100000000000000000014400000000000001440") + self.assertEqual(rows[0]["geom_wkb"], point0_wkb) + self.assertEqual(rows[0]["geom4326_wkb"], point0_wkb) + self.assertEqual(rows[0]["geog_wkb"], point0_wkb) + self.assertEqual(rows[1]["geom_wkb"], point1_wkb) + self.assertEqual(rows[1]["geom4326_wkb"], point1_wkb) + self.assertEqual(rows[1]["geog_wkb"], point1_wkb) + self.assertEqual(rows[0]["geom_srid"], 0) + self.assertEqual(rows[0]["geom4326_srid"], 4326) + self.assertEqual(rows[0]["geog_srid"], 4326) + self.assertEqual(rows[1]["geom_srid"], 0) + self.assertEqual(rows[1]["geom4326_srid"], 4326) + self.assertEqual(rows[1]["geog_srid"], 4326) + + # Just the data set without parameters.
+ self.assertEqual( + self.spark.createDataFrame(named_geospatial_data) + .select("id", "geom", "geom4326", "geog") + .collect(), + df.collect(), + ) + self.assertEqual(self.spark.createDataFrame(geospatial_data).collect(), df.collect()) + self.assertEqual(self.spark.createDataFrame(spark_row_data).collect(), df.collect()) + + # Define DataFrame creation methods + datasets = [named_geospatial_data, geospatial_data, spark_row_data] + schemas = [ + schema, + "id INT, geom GEOMETRY(0), geom4326 GEOMETRY(4326), geog GEOGRAPHY(4326)", + ["id", "geom", "geom4326", "geog"], + ] + + for dataset_to_check, schema_to_check in zip(datasets, schemas): + df_to_check = self.spark.createDataFrame(dataset_to_check, schema_to_check).select( + "id", "geom", "geom4326", "geog" + ) + self.assertEqual(df_to_check.collect(), df.collect(), "DataFrame creation with schema") + + # Negative Test: Schema mismatch + for dataset_to_check in datasets: + with self.assertRaises(SparkRuntimeException) as pe: + self.spark.createDataFrame(dataset_to_check, mismatched_schema).collect() + + self.check_error( + exception=pe.exception, + errorClass="GEO_ENCODER_SRID_MISMATCH_ERROR", + messageParameters={"type": "GEOMETRY", "typeSrid": "4326", "valueSrid": "0"}, + ) + + def test_geospatial_schema_inferrence(self): + # Mixed data with different SRIDs + wkb = bytes.fromhex("010100000000000000000031400000000000001c40") + geometry_dataset = [ + (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 4326)), + (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 0)), + (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 4326)), + (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 0)), + ] + # Create DataFrame with mixed data types + df = self.spark.createDataFrame(geometry_dataset, ["geom0", "geom4326", "geom_any"]) + expected_schema = StructType( + [ + StructField("geom0", GeometryType(0), True), + 
StructField("geom4326", GeometryType(4326), True), + StructField("geom_any", GeometryType("ANY"), True), + ] + ) + self.assertEqual(df.schema, expected_schema) + + rows = df.select( + F.st_asbinary("geom0").alias("geom0_wkb"), + F.st_srid("geom0").alias("geom0_srid"), + F.st_asbinary("geom4326").alias("geom4326_wkb"), + F.st_srid("geom4326").alias("geom4326_srid"), + F.st_asbinary("geom_any").alias("geom_any_wkb"), + F.st_srid("geom_any").alias("geom_any_srid"), + ).collect() + + point_wkb = bytes.fromhex("010100000000000000000031400000000000001c40") + self.assertEqual(rows[0]["geom0_wkb"], point_wkb) + self.assertEqual(rows[1]["geom0_wkb"], point_wkb) + self.assertEqual(rows[0]["geom4326_wkb"], point_wkb) + self.assertEqual(rows[1]["geom4326_wkb"], point_wkb) + self.assertEqual(rows[0]["geom_any_wkb"], point_wkb) + self.assertEqual(rows[1]["geom_any_wkb"], point_wkb) + self.assertEqual(rows[0]["geom0_srid"], 0) + self.assertEqual(rows[1]["geom0_srid"], 0) + self.assertEqual(rows[0]["geom4326_srid"], 4326) + self.assertEqual(rows[1]["geom4326_srid"], 4326) + self.assertEqual(rows[0]["geom_any_srid"], 4326) + self.assertEqual(rows[1]["geom_any_srid"], 0) + + def test_geospatial_mixed_check_srid_validity(self): + geometry_mixed_invalid_data = [ + (1, Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0)), + (2, Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 1)), + ] + + with self.assertRaises(IllegalArgumentException) as pe: + self.spark.createDataFrame(geometry_mixed_invalid_data).collect() + self.check_error( + exception=pe.exception, + errorClass="ST_INVALID_SRID_VALUE", + messageParameters={"srid": "1"}, + ) + + with self.assertRaises(IllegalArgumentException) as pe: + self.spark.createDataFrame( + geometry_mixed_invalid_data, "id INT, geom GEOMETRY(ANY)" + ).collect() + self.check_error( + exception=pe.exception, + errorClass="ST_INVALID_SRID_VALUE", + messageParameters={"srid": "1"}, + ) + + def 
test_geospatial_result_encoding(self): + point_wkb = "010100000000000000000031400000000000001c40" + point_bytes = bytes.fromhex(point_wkb) + df = self.spark.sql( + f""" + SELECT ST_GeomFromWKB(X'{point_wkb}') AS geom, + ST_GeogFromWKB(X'{point_wkb}') AS geog""" + ) + GeospatialRow = Row("geom", "geog") + self.assertEqual( + df.collect(), + [ + GeospatialRow( + Geometry.fromWKB(point_bytes, 0), + Geography.fromWKB(point_bytes, 4326), + ) + ], + ) + def test_to_ddl(self): schema = StructType().add("a", NullType()).add("b", BooleanType()).add("c", BinaryType()) self.assertEqual(schema.toDDL(), "a VOID,b BOOLEAN,c BINARY") diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 8aae39880072..95307ea3859c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -2517,6 +2517,8 @@ def _assert_valid_collation_provider(provider: str) -> None: # Mapping Python types to Spark SQL DataType _type_mappings = { type(None): NullType, + Geometry: GeometryType, + Geography: GeographyType, bool: BooleanType, int: LongType, float: DoubleType, @@ -2648,6 +2650,12 @@ def _infer_type( return obj.__UDT__ dataType = _type_mappings.get(type(obj)) + if dataType is GeographyType: + assert isinstance(obj, Geography) + return GeographyType(obj.getSrid()) + if dataType is GeometryType: + assert isinstance(obj, Geometry) + return GeometryType(obj.getSrid()) if dataType is DecimalType: # the precision and scale of `obj` may be different from row to row. 
return DecimalType(38, 18) @@ -2915,6 +2923,10 @@ def new_name(n: str) -> str: return a elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType): return b + elif isinstance(a, GeometryType) and isinstance(b, GeometryType) and a.srid != b.srid: + return GeometryType("ANY") + elif isinstance(a, GeographyType) and isinstance(b, GeographyType) and a.srid != b.srid: + return GeographyType("ANY") elif isinstance(a, AtomicType) and isinstance(b, StringType): return b elif isinstance(a, StringType) and isinstance(b, AtomicType): @@ -3068,6 +3080,8 @@ def convert_struct(obj: Any) -> Optional[Tuple]: ArrayType: (list, tuple, array), MapType: (dict,), StructType: (tuple, list, dict), + GeometryType: (Geometry,), + GeographyType: (Geography,), VariantType: ( bool, int, @@ -3419,6 +3433,24 @@ def verify_variant(obj: Any) -> None: verify_value = verify_variant + elif isinstance(dataType, GeometryType): + + def verify_geometry(obj: Any) -> None: + assert_acceptable_types(obj) + verify_acceptable_types(obj) + assert isinstance(obj, Geometry) + + verify_value = verify_geometry + + elif isinstance(dataType, GeographyType): + + def verify_geography(obj: Any) -> None: + assert_acceptable_types(obj) + verify_acceptable_types(obj) + assert isinstance(obj, Geography) + + verify_value = verify_geography + else: def verify_default(obj: Any) -> None: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 212cc5db124c..33622ca7349a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -29,9 +29,9 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import 
org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData, STUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{UTF8String, VariantVal} +import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String, VariantVal} object EvaluatePython { @@ -43,7 +43,7 @@ object EvaluatePython { def needConversionInPython(dt: DataType): Boolean = dt match { case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType - | _: TimeType => true + | _: TimeType | _: GeometryType | _: GeographyType => true case _: StructType => true case _: UserDefinedType[_] => true case ArrayType(elementType, _) => needConversionInPython(elementType) @@ -92,6 +92,10 @@ object EvaluatePython { case (s: UTF8String, _: StringType) => s.toString + case (g: GeometryVal, gt: GeometryType) => STUtils.deserializeGeom(g, gt) + + case (g: GeographyVal, gt: GeographyType) => STUtils.deserializeGeog(g, gt) + case (bytes: Array[Byte], BinaryType) => if (binaryAsBytes) { new BytesWrapper(bytes) @@ -228,6 +232,23 @@ object EvaluatePython { ) } + case g: GeographyType => (obj: Any) => nullSafeConvert(obj) { + case s: java.util.HashMap[_, _] => + val geographySrid = s.get("srid").asInstanceOf[Int] + g.assertSridAllowedForType(geographySrid) + STUtils.stGeogFromWKB( + s.get("wkb").asInstanceOf[Array[Byte]]) + } + + case g: GeometryType => (obj: Any) => nullSafeConvert(obj) { + case s: java.util.HashMap[_, _] => + val geometrySrid = s.get("srid").asInstanceOf[Int] + g.assertSridAllowedForType(geometrySrid) + STUtils.stGeomFromWKB( + s.get("wkb").asInstanceOf[Array[Byte]], + geometrySrid) + } + case other => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty) }