Skip to content

Commit d6f1e22

Browse files
uros-dbcloud-fan
authored andcommitted
[SPARK-54169][GEO][SQL] Introduce Geography and Geometry types in Arrow writer
### What changes were proposed in this pull request? Add Arrow serialization/deserialization support for `Geography` and `Geometry` types. ### Why are the changes needed? Supporting geospatial types for clients (Spark Connect / Thrift Server / etc.) which consume result sets in Arrow format. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added unit tests: - `ArrowUtilsSuite` - `ArrowWriterSuite` - `ArrowEncoderSuite` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #52863 from uros-db/geo-arrow-serde. Authored-by: Uros Bojanic <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 2ec7439 commit d6f1e22

File tree

9 files changed

+478
-17
lines changed

9 files changed

+478
-17
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.util
19+
20+
/* The MaybeNull class is a utility that introduces controlled nullability into a sequence
21+
* of invocations. It is designed to return a ~null~ value at a specified interval while returning
22+
* the provided value for all other invocations.
23+
*/
24+
case class MaybeNull(interval: Int) {
25+
assert(interval > 1)
26+
private var invocations = 0
27+
def apply[T](value: T): T = {
28+
val result = if (invocations % interval == 0) {
29+
null.asInstanceOf[T]
30+
} else {
31+
value
32+
}
33+
invocations += 1
34+
result
35+
}
36+
}

sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,43 @@ private[sql] object ArrowUtils {
143143
largeVarTypes)).asJava)
144144
case udt: UserDefinedType[_] =>
145145
toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
146+
case g: GeometryType =>
147+
val fieldType =
148+
new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
149+
150+
// WKB field is tagged with additional metadata so we can identify that the arrow
151+
// struct actually represents a geometry schema.
152+
val wkbFieldType = new FieldType(
153+
false,
154+
toArrowType(BinaryType, timeZoneId, largeVarTypes),
155+
null,
156+
Map("geometry" -> "true", "srid" -> g.srid.toString).asJava)
157+
158+
new Field(
159+
name,
160+
fieldType,
161+
Seq(
162+
toArrowField("srid", IntegerType, false, timeZoneId, largeVarTypes),
163+
new Field("wkb", wkbFieldType, Seq.empty[Field].asJava)).asJava)
164+
165+
case g: GeographyType =>
166+
val fieldType =
167+
new FieldType(nullable, ArrowType.Struct.INSTANCE, null, null)
168+
169+
// WKB field is tagged with additional metadata so we can identify that the arrow
170+
// struct actually represents a geography schema.
171+
val wkbFieldType = new FieldType(
172+
false,
173+
toArrowType(BinaryType, timeZoneId, largeVarTypes),
174+
null,
175+
Map("geography" -> "true", "srid" -> g.srid.toString).asJava)
176+
177+
new Field(
178+
name,
179+
fieldType,
180+
Seq(
181+
toArrowField("srid", IntegerType, false, timeZoneId, largeVarTypes),
182+
new Field("wkb", wkbFieldType, Seq.empty[Field].asJava)).asJava)
146183
case _: VariantType =>
147184
val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
148185
// The metadata field is tagged with additional metadata so we can identify that the arrow
@@ -175,6 +212,26 @@ private[sql] object ArrowUtils {
175212
}
176213
}
177214

215+
def isGeometryField(field: Field): Boolean = {
216+
assert(field.getType.isInstanceOf[ArrowType.Struct])
217+
field.getChildren.asScala
218+
.map(_.getName)
219+
.asJava
220+
.containsAll(Seq("wkb", "srid").asJava) && field.getChildren.asScala.exists { child =>
221+
child.getName == "wkb" && child.getMetadata.getOrDefault("geometry", "false") == "true"
222+
}
223+
}
224+
225+
def isGeographyField(field: Field): Boolean = {
226+
assert(field.getType.isInstanceOf[ArrowType.Struct])
227+
field.getChildren.asScala
228+
.map(_.getName)
229+
.asJava
230+
.containsAll(Seq("wkb", "srid").asJava) && field.getChildren.asScala.exists { child =>
231+
child.getName == "wkb" && child.getMetadata.getOrDefault("geography", "false") == "true"
232+
}
233+
}
234+
178235
def fromArrowField(field: Field): DataType = {
179236
field.getType match {
180237
case _: ArrowType.Map =>
@@ -188,6 +245,26 @@ private[sql] object ArrowUtils {
188245
ArrayType(elementType, containsNull = elementField.isNullable)
189246
case ArrowType.Struct.INSTANCE if isVariantField(field) =>
190247
VariantType
248+
case ArrowType.Struct.INSTANCE if isGeometryField(field) =>
249+
// We expect that type metadata is associated with wkb field.
250+
val metadataField =
251+
field.getChildren.asScala.filter { child => child.getName == "wkb" }.head
252+
val srid = metadataField.getMetadata.get("srid").toInt
253+
if (srid == GeometryType.MIXED_SRID) {
254+
GeometryType("ANY")
255+
} else {
256+
GeometryType(srid)
257+
}
258+
case ArrowType.Struct.INSTANCE if isGeographyField(field) =>
259+
// We expect that type metadata is associated with wkb field.
260+
val metadataField =
261+
field.getChildren.asScala.filter { child => child.getName == "wkb" }.head
262+
val srid = metadataField.getMetadata.get("srid").toInt
263+
if (srid == GeographyType.MIXED_SRID) {
264+
GeographyType("ANY")
265+
} else {
266+
GeographyType(srid)
267+
}
191268
case ArrowType.Struct.INSTANCE =>
192269
val fields = field.getChildren().asScala.map { child =>
193270
val dt = fromArrowField(child)

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ public static GeometryVal stGeomFromWKB(byte[] wkb) {
111111
return toPhysVal(Geometry.fromWkb(wkb));
112112
}
113113

114+
public static GeometryVal stGeomFromWKB(byte[] wkb, int srid) {
115+
return toPhysVal(Geometry.fromWkb(wkb, srid));
116+
}
117+
114118
// ST_SetSrid
115119
public static GeographyVal stSetSrid(GeographyVal geo, int srid) {
116120
// We only allow setting the SRID to geographic values.

sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@
2525

2626
import org.apache.spark.SparkUnsupportedOperationException;
2727
import org.apache.spark.annotation.DeveloperApi;
28+
import org.apache.spark.sql.catalyst.util.STUtils;
2829
import org.apache.spark.sql.util.ArrowUtils;
2930
import org.apache.spark.sql.types.*;
3031
import org.apache.spark.unsafe.types.CalendarInterval;
32+
import org.apache.spark.unsafe.types.GeographyVal;
33+
import org.apache.spark.unsafe.types.GeometryVal;
3134
import org.apache.spark.unsafe.types.UTF8String;
3235

3336
/**
@@ -146,6 +149,30 @@ public ColumnarMap getMap(int rowId) {
146149
super(type);
147150
}
148151

152+
@Override
153+
public GeographyVal getGeography(int rowId) {
154+
if (isNullAt(rowId)) return null;
155+
156+
GeographyType gt = (GeographyType) this.type;
157+
int srid = getChild(0).getInt(rowId);
158+
byte[] bytes = getChild(1).getBinary(rowId);
159+
gt.assertSridAllowedForType(srid);
160+
// TODO(GEO-602): Geog still does not support different SRIDs, once it does,
161+
// we need to update this.
162+
return (bytes == null) ? null : STUtils.stGeogFromWKB(bytes);
163+
}
164+
165+
@Override
166+
public GeometryVal getGeometry(int rowId) {
167+
if (isNullAt(rowId)) return null;
168+
169+
GeometryType gt = (GeometryType) this.type;
170+
int srid = getChild(0).getInt(rowId);
171+
byte[] bytes = getChild(1).getBinary(rowId);
172+
gt.assertSridAllowedForType(srid);
173+
return (bytes == null) ? null : STUtils.stGeomFromWKB(bytes, srid);
174+
}
175+
149176
public ArrowColumnVector(ValueVector vector) {
150177
this(ArrowUtils.fromArrowField(vector.getField()));
151178
initAccessor(vector);

sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.arrow.vector.complex._
2424

2525
import org.apache.spark.sql.catalyst.InternalRow
2626
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
27+
import org.apache.spark.sql.catalyst.util.STUtils
2728
import org.apache.spark.sql.errors.ExecutionErrors
2829
import org.apache.spark.sql.types._
2930
import org.apache.spark.sql.util.ArrowUtils
@@ -92,6 +93,16 @@ object ArrowWriter {
9293
createFieldWriter(vector.getChildByOrdinal(ordinal))
9394
}
9495
new StructWriter(vector, children.toArray)
96+
case (dt: GeometryType, vector: StructVector) =>
97+
val children = (0 until vector.size()).map { ordinal =>
98+
createFieldWriter(vector.getChildByOrdinal(ordinal))
99+
}
100+
new GeometryWriter(dt, vector, children.toArray)
101+
case (dt: GeographyType, vector: StructVector) =>
102+
val children = (0 until vector.size()).map { ordinal =>
103+
createFieldWriter(vector.getChildByOrdinal(ordinal))
104+
}
105+
new GeographyWriter(dt, vector, children.toArray)
95106
case (dt, _) =>
96107
throw ExecutionErrors.unsupportedDataTypeError(dt)
97108
}
@@ -446,6 +457,42 @@ private[arrow] class StructWriter(
446457
}
447458
}
448459

460+
private[arrow] class GeographyWriter(
461+
dt: GeographyType,
462+
valueVector: StructVector,
463+
children: Array[ArrowFieldWriter]) extends StructWriter(valueVector, children) {
464+
465+
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
466+
valueVector.setIndexDefined(count)
467+
468+
val geom = STUtils.deserializeGeog(input.getGeography(ordinal), dt)
469+
val bytes = geom.getBytes
470+
val srid = geom.getSrid
471+
472+
val row = InternalRow(srid, bytes)
473+
children(0).write(row, 0)
474+
children(1).write(row, 1)
475+
}
476+
}
477+
478+
private[arrow] class GeometryWriter(
479+
dt: GeometryType,
480+
valueVector: StructVector,
481+
children: Array[ArrowFieldWriter]) extends StructWriter(valueVector, children) {
482+
483+
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
484+
valueVector.setIndexDefined(count)
485+
486+
val geom = STUtils.deserializeGeom(input.getGeometry(ordinal), dt)
487+
val bytes = geom.getBytes
488+
val srid = geom.getSrid
489+
490+
val row = InternalRow(srid, bytes)
491+
children(0).write(row, 0)
492+
children(1).write(row, 1)
493+
}
494+
}
495+
449496
private[arrow] class MapWriter(
450497
val valueVector: MapVector,
451498
val structVector: StructVector,

sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ class ArrowUtilsSuite extends SparkFunSuite {
4949
roundtrip(BinaryType)
5050
roundtrip(DecimalType.SYSTEM_DEFAULT)
5151
roundtrip(DateType)
52+
roundtrip(GeometryType("ANY"))
53+
roundtrip(GeometryType(4326))
54+
roundtrip(GeographyType("ANY"))
55+
roundtrip(GeographyType(4326))
5256
roundtrip(YearMonthIntervalType())
5357
roundtrip(DayTimeIntervalType())
5458
checkError(

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
4646
import org.apache.spark.sql.connect.test.ConnectFunSuite
4747
import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
4848
import org.apache.spark.unsafe.types.VariantVal
49-
import org.apache.spark.util.SparkStringUtils
49+
import org.apache.spark.util.{MaybeNull, SparkStringUtils}
5050

5151
/**
5252
* Tests for encoding external data to and from arrow.
@@ -218,20 +218,6 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
218218
}
219219
}
220220

221-
private case class MaybeNull(interval: Int) {
222-
assert(interval > 1)
223-
private var invocations = 0
224-
def apply[T](value: T): T = {
225-
val result = if (invocations % interval == 0) {
226-
null.asInstanceOf[T]
227-
} else {
228-
value
229-
}
230-
invocations += 1
231-
result
232-
}
233-
}
234-
235221
private def javaBigDecimal(i: Int): java.math.BigDecimal = {
236222
javaBigDecimal(i, DecimalType.DEFAULT_SCALE)
237223
}

sql/core/pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@
4949
<artifactId>spark-sketch_${scala.binary.version}</artifactId>
5050
<version>${project.version}</version>
5151
</dependency>
52+
<dependency>
53+
<groupId>org.apache.spark</groupId>
54+
<artifactId>spark-common-utils_${scala.binary.version}</artifactId>
55+
<version>${project.version}</version>
56+
<classifier>tests</classifier>
57+
<scope>test</scope>
58+
</dependency>
5259
<dependency>
5360
<groupId>org.apache.spark</groupId>
5461
<artifactId>spark-core_${scala.binary.version}</artifactId>

0 commit comments

Comments
 (0)