diff --git a/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala b/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala new file mode 100644 index 0000000000000..44bdffeacfde6 --- /dev/null +++ b/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/* The MaybeNull class is a utility that introduces controlled nullability into a sequence + * of invocations. It is designed to return a null value at a specified interval while returning + * the provided value for all other invocations. + */ +case class MaybeNull(interval: Int) { + assert(interval > 1) + private var invocations = 0 + def apply[T](value: T): T = { + val result = if (invocations % interval == 0) { + null.asInstanceOf[T] + } else { + value + } + invocations += 1 + result + } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 6caabf20f8f6b..23d8a0bbb65b5 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -143,6 +143,43 @@ private[sql] object ArrowUtils { largeVarTypes)).asJava) case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes) + case g: GeometryType => + val fieldType = + new FieldType(nullable, ArrowType.Struct.INSTANCE, null) + + // WKB field is tagged with additional metadata so we can identify that the arrow + // struct actually represents a geometry schema. + val wkbFieldType = new FieldType( + false, + toArrowType(BinaryType, timeZoneId, largeVarTypes), + null, + Map("geometry" -> "true", "srid" -> g.srid.toString).asJava) + + new Field( + name, + fieldType, + Seq( + toArrowField("srid", IntegerType, false, timeZoneId, largeVarTypes), + new Field("wkb", wkbFieldType, Seq.empty[Field].asJava)).asJava) + + case g: GeographyType => + val fieldType = + new FieldType(nullable, ArrowType.Struct.INSTANCE, null, null) + + // WKB field is tagged with additional metadata so we can identify that the arrow + // struct actually represents a geography schema.
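+      // Like the geometry case above, the value is encoded as an Arrow struct with two
+      // children: an "srid" int field and a "wkb" binary field carrying the WKB payload.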
+ val wkbFieldType = new FieldType( + false, + toArrowType(BinaryType, timeZoneId, largeVarTypes), + null, + Map("geography" -> "true", "srid" -> g.srid.toString).asJava) + + new Field( + name, + fieldType, + Seq( + toArrowField("srid", IntegerType, false, timeZoneId, largeVarTypes), + new Field("wkb", wkbFieldType, Seq.empty[Field].asJava)).asJava) case _: VariantType => val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) // The metadata field is tagged with additional metadata so we can identify that the arrow @@ -175,6 +212,26 @@ private[sql] object ArrowUtils { } } + def isGeometryField(field: Field): Boolean = { + assert(field.getType.isInstanceOf[ArrowType.Struct]) + field.getChildren.asScala + .map(_.getName) + .asJava + .containsAll(Seq("wkb", "srid").asJava) && field.getChildren.asScala.exists { child => + child.getName == "wkb" && child.getMetadata.getOrDefault("geometry", "false") == "true" + } + } + + def isGeographyField(field: Field): Boolean = { + assert(field.getType.isInstanceOf[ArrowType.Struct]) + field.getChildren.asScala + .map(_.getName) + .asJava + .containsAll(Seq("wkb", "srid").asJava) && field.getChildren.asScala.exists { child => + child.getName == "wkb" && child.getMetadata.getOrDefault("geography", "false") == "true" + } + } + def fromArrowField(field: Field): DataType = { field.getType match { case _: ArrowType.Map => @@ -188,6 +245,26 @@ private[sql] object ArrowUtils { ArrayType(elementType, containsNull = elementField.isNullable) case ArrowType.Struct.INSTANCE if isVariantField(field) => VariantType + case ArrowType.Struct.INSTANCE if isGeometryField(field) => + // We expect that type metadata is associated with wkb field. + val metadataField = + field.getChildren.asScala.filter { child => child.getName == "wkb" }.head + val srid = metadataField.getMetadata.get("srid").toInt + if (srid == GeometryType.MIXED_SRID) { + GeometryType("ANY") + } else { + GeometryType(srid) + } + case ArrowType.Struct.INSTANCE if isGeographyField(field) => + // We expect that type metadata is associated with wkb field. + val metadataField = + field.getChildren.asScala.filter { child => child.getName == "wkb" }.head + val srid = metadataField.getMetadata.get("srid").toInt + if (srid == GeographyType.MIXED_SRID) { + GeographyType("ANY") + } else { + GeographyType(srid) + } case ArrowType.Struct.INSTANCE => val fields = field.getChildren().asScala.map { child => val dt = fromArrowField(child) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java index 3cf4b84ac0330..0a9942c4cf557 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java @@ -111,6 +111,10 @@ public static GeometryVal stGeomFromWKB(byte[] wkb) { return toPhysVal(Geometry.fromWkb(wkb)); } + public static GeometryVal stGeomFromWKB(byte[] wkb, int srid) { + return toPhysVal(Geometry.fromWkb(wkb, srid)); + } + // ST_SetSrid public static GeographyVal stSetSrid(GeographyVal geo, int srid) { // We only allow setting the SRID to geographic values. 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 66116d7c952fd..019bc258579a8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -25,9 +25,12 @@ import org.apache.spark.SparkUnsupportedOperationException; import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.sql.catalyst.util.STUtils; import org.apache.spark.sql.util.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.GeographyVal; +import org.apache.spark.unsafe.types.GeometryVal; import org.apache.spark.unsafe.types.UTF8String; /** @@ -146,6 +149,30 @@ public ColumnarMap getMap(int rowId) { super(type); } + @Override + public GeographyVal getGeography(int rowId) { + if (isNullAt(rowId)) return null; + + GeographyType gt = (GeographyType) this.type; + int srid = getChild(0).getInt(rowId); + byte[] bytes = getChild(1).getBinary(rowId); + gt.assertSridAllowedForType(srid); + // TODO(GEO-602): Geog still does not support different SRIDs, once it does, + // we need to update this. + return (bytes == null) ? null : STUtils.stGeogFromWKB(bytes); + } + + @Override + public GeometryVal getGeometry(int rowId) { + if (isNullAt(rowId)) return null; + + GeometryType gt = (GeometryType) this.type; + int srid = getChild(0).getInt(rowId); + byte[] bytes = getChild(1).getBinary(rowId); + gt.assertSridAllowedForType(srid); + return (bytes == null) ? null : STUtils.stGeomFromWKB(bytes, srid); + } + public ArrowColumnVector(ValueVector vector) { this(ArrowUtils.fromArrowField(vector.getField())); initAccessor(vector); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 275fecebdafb8..8d68e74dbf874 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -24,6 +24,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.STUtils import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils @@ -92,6 +93,16 @@ object ArrowWriter { createFieldWriter(vector.getChildByOrdinal(ordinal)) } new StructWriter(vector, children.toArray) + case (dt: GeometryType, vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new GeometryWriter(dt, vector, children.toArray) + case (dt: GeographyType, vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new GeographyWriter(dt, vector, children.toArray) case (dt, _) => throw ExecutionErrors.unsupportedDataTypeError(dt) } @@ -446,6 +457,42 @@ private[arrow] class StructWriter( } } +private[arrow] class GeographyWriter( + dt: GeographyType, + valueVector: StructVector, + children: Array[ArrowFieldWriter]) extends StructWriter(valueVector, children) { + + override def setValue(input: SpecializedGetters, 
ordinal: Int): Unit = { + valueVector.setIndexDefined(count) + + val geom = STUtils.deserializeGeog(input.getGeography(ordinal), dt) + val bytes = geom.getBytes + val srid = geom.getSrid + + val row = InternalRow(srid, bytes) + children(0).write(row, 0) + children(1).write(row, 1) + } +} + +private[arrow] class GeometryWriter( + dt: GeometryType, + valueVector: StructVector, + children: Array[ArrowFieldWriter]) extends StructWriter(valueVector, children) { + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setIndexDefined(count) + + val geom = STUtils.deserializeGeom(input.getGeometry(ordinal), dt) + val bytes = geom.getBytes + val srid = geom.getSrid + + val row = InternalRow(srid, bytes) + children(0).write(row, 0) + children(1).write(row, 1) + } +} + private[arrow] class MapWriter( val valueVector: MapVector, val structVector: StructVector, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala index 7124c94b390d0..8011e69e724c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala @@ -49,6 +49,10 @@ class ArrowUtilsSuite extends SparkFunSuite { roundtrip(BinaryType) roundtrip(DecimalType.SYSTEM_DEFAULT) roundtrip(DateType) + roundtrip(GeometryType("ANY")) + roundtrip(GeometryType(4326)) + roundtrip(GeographyType("ANY")) + roundtrip(GeographyType(4326)) roundtrip(YearMonthIntervalType()) roundtrip(DayTimeIntervalType()) checkError( diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index b29d73be359b5..bc840df5c3fac 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -46,7 +46,7 @@ import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum import org.apache.spark.sql.connect.test.ConnectFunSuite import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType} import org.apache.spark.unsafe.types.VariantVal -import org.apache.spark.util.SparkStringUtils +import org.apache.spark.util.{MaybeNull, SparkStringUtils} /** * Tests for encoding external data to and from arrow. 
@@ -218,20 +218,6 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } } - private case class MaybeNull(interval: Int) { - assert(interval > 1) - private var invocations = 0 - def apply[T](value: T): T = { - val result = if (invocations % interval == 0) { - null.asInstanceOf[T] - } else { - value - } - invocations += 1 - result - } - } - private def javaBigDecimal(i: Int): java.math.BigDecimal = { javaBigDecimal(i, DecimalType.DEFAULT_SCALE) } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 6b9d08f0dde71..a5b5c399d4fc6 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -49,6 +49,13 @@ <artifactId>spark-sketch_${scala.binary.version}</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-common-utils_${scala.binary.version}</artifactId> + <version>${project.version}</version> + <classifier>tests</classifier> + <scope>test</scope> + </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_${scala.binary.version}</artifactId> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 99d245529e96d..2c0c0494bbacf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -17,16 +17,23 @@ package org.apache.spark.sql.execution.arrow +import scala.jdk.CollectionConverters._ + import org.apache.arrow.vector.VectorSchemaRoot import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.YearUDT import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder} import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.{Geography => InternalGeography, Geometry => InternalGeometry} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, GeographyVal, GeometryVal, UTF8String} +import org.apache.spark.util.MaybeNull class ArrowWriterSuite extends SparkFunSuite { @@ -52,8 +59,16 @@ class ArrowWriterSuite extends SparkFunSuite { } writer.finish() + val dataModified = data.map { datum => + dt match { + case _: GeometryType => datum.asInstanceOf[GeometryVal].getBytes + case _: GeographyType => datum.asInstanceOf[GeographyVal].getBytes + case _ => datum + } + } + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - data.zipWithIndex.foreach { + dataModified.zipWithIndex.foreach { case (null, rowId) => assert(reader.isNullAt(rowId)) case (datum, rowId) => val value = datatype match { @@ -74,12 +89,31 @@ class ArrowWriterSuite extends SparkFunSuite { case _: YearMonthIntervalType => reader.getInt(rowId) case _: DayTimeIntervalType => reader.getLong(rowId) case CalendarIntervalType => reader.getInterval(rowId) + case _: GeometryType => reader.getGeometry(rowId).getBytes + case _: GeographyType => reader.getGeography(rowId).getBytes } assert(value === datum) } writer.root.close() } + + val wkbs = Seq("010100000000000000000031400000000000001c40", + "010100000000000000000034400000000000003540") + .map { x => + x.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray + } + + val geographies = wkbs.map(x => InternalGeography.fromWkb(x, 4326).getValue) + val geometries = wkbs.map(x => InternalGeometry.fromWkb(x, 0).getValue) + val mixedGeometries =
wkbs.zip(Seq(0, 4326)).map { + case (g, srid) => InternalGeometry.fromWkb(g, srid).getValue + } + + check(GeometryType(0), geometries) + check(GeographyType(4326), geographies) + check(GeometryType("ANY"), mixedGeometries) + check(GeographyType("ANY"), geographies) check(BooleanType, Seq(true, null, false)) check(ByteType, Seq(1.toByte, 2.toByte, null, 4.toByte)) check(ShortType, Seq(1.toShort, 2.toShort, null, 4.toShort)) @@ -110,6 +144,245 @@ class ArrowWriterSuite extends SparkFunSuite { check(new YearUDT, Seq(2020, 2021, null, 2022)) } + test("nested geographies") { + def check( + dt: StructType, + data: Seq[InternalRow]): Unit = { + val writer = ArrowWriter.create(dt.asInstanceOf[StructType], "UTC") + + // Write data to arrow. + data.toSeq.foreach { datum => + writer.write(datum) + } + writer.finish() + + // Create arrow vector readers. + val vectors = writer.root.getFieldVectors.asScala + .map { new ArrowColumnVector(_) }.toArray.asInstanceOf[Array[ColumnVector]] + + val batch = new ColumnarBatch(vectors, writer.root.getRowCount.toInt) + + data.zipWithIndex.foreach { case (datum, i) => + // Read data from arrow. + val internalRow = batch.getRow(i) + + // All nullable results first must check whether the value is null. + if (datum.getStruct(0, 4) == null || internalRow.getStruct(0, 4) == null) { + assert(datum.getStruct(0, 4) == null && internalRow.getStruct(0, 4) == null) + } else { + val expectedStruct = datum.getStruct(0, 4) + val actualStruct = internalRow.getStruct(0, 4) + assert(expectedStruct.getInt(0) === actualStruct.getInt(0)) + assert(expectedStruct.getInt(2) === actualStruct.getInt(2)) + + if (expectedStruct.getGeography(1) == null || + actualStruct.getGeography(1) == null) { + assert(expectedStruct.getGeography(1) == null && actualStruct.getGeography(1) == null) + } else { + assert(expectedStruct.getGeography(1).getBytes === + actualStruct.getGeography(1).getBytes) + } + if (expectedStruct.getGeography(3) == null || + actualStruct.getGeography(3) == null) { + assert(expectedStruct.getGeography(3) == null && actualStruct.getGeography(3) == null) + } else { + assert(expectedStruct.getGeography(3).getBytes === + actualStruct.getGeography(3).getBytes) + } + + if (datum.getArray(1) == null || + internalRow.getArray(1) == null) { + assert(internalRow.getArray(1) == null && datum.getArray(1) == null) + } else { + internalRow.getArray(1).toSeq[GeographyVal](GeographyType(4326)) + .zip(datum.getArray(1).toSeq[GeographyVal](GeographyType(4326))).foreach { + case (actual, expected) => + assert(actual.getBytes === expected.getBytes) + } + } + + if (datum.getMap(2) == null || + internalRow.getMap(2) == null) { + assert(internalRow.getMap(2) == null && datum.getMap(2) == null) + } else { + assert(internalRow.getMap(2).keyArray().toSeq(StringType) === + datum.getMap(2).keyArray().toSeq(StringType)) + internalRow.getMap(2).valueArray().toSeq[GeographyVal](GeographyType("ANY")) + .zip(datum.getMap(2).valueArray().toSeq[GeographyVal](GeographyType("ANY"))).foreach { + case (actual, expected) => + assert((actual == null && expected == null) || + actual.getBytes === expected.getBytes) + } + } + } + } + + writer.root.close() + } + + val point1 = "010100000000000000000031400000000000001C40" + .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray + val point2 = "010100000000000000000035400000000000001E40" + .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray + + val schema = new StructType() + .add( + "s", + new StructType() + .add("i1", "int") + .add("g0", "geography(4326)") + .add("i2", "int") + .add("g1", "geography(4326)")) + .add("a", "array<geography(4326)>") + .add("m", "map<string, geography(any)>") + + val maybeNull5 = MaybeNull(5) + val maybeNull7 = MaybeNull(7) + val maybeNull11 = MaybeNull(11) + val maybeNull13 = MaybeNull(13) + val maybeNull17 = MaybeNull(17) + + val nestedGeographySerializer = ExpressionEncoder(toRowEncoder(schema)).createSerializer() + val data = Iterator + .tabulate(100)(i => + nestedGeographySerializer.apply( + (Row( + maybeNull5( + Row( + i, + maybeNull7(org.apache.spark.sql.types.Geography.fromWKB(point1)), + i + 1, + maybeNull11(org.apache.spark.sql.types.Geography.fromWKB(point2, 4326)))), + maybeNull7((0 until 10).map(j => + org.apache.spark.sql.types.Geography.fromWKB(point2, 4326))), + maybeNull13( + Map((i.toString, maybeNull17( + org.apache.spark.sql.types.Geography.fromWKB(point1, 4326))))))))) + .map(_.copy()).toSeq + + check(schema, data) + } + + test("nested geometries") { + def check( + dt: StructType, + data: Seq[InternalRow]): Unit = { + val writer = ArrowWriter.create(dt.asInstanceOf[StructType], "UTC") + + // Write data to arrow. + data.toSeq.foreach { datum => + writer.write(datum) + } + writer.finish() + + // Create arrow vector readers. + val vectors = writer.root.getFieldVectors.asScala + .map { new ArrowColumnVector(_) }.toArray.asInstanceOf[Array[ColumnVector]] + + val batch = new ColumnarBatch(vectors, writer.root.getRowCount.toInt) + data.zipWithIndex.foreach { case (datum, i) => + // Read data from arrow. + val internalRow = batch.getRow(i) + + // All nullable results first must check whether the value is null. + if (datum.getStruct(0, 4) == null || internalRow.getStruct(0, 4) == null) { + assert(datum.getStruct(0, 4) == null && internalRow.getStruct(0, 4) == null) + } else { + val expectedStruct = datum.getStruct(0, 4) + val actualStruct = internalRow.getStruct(0, 4) + assert(expectedStruct.getInt(0) === actualStruct.getInt(0)) + assert(expectedStruct.getInt(2) === actualStruct.getInt(2)) + + if (expectedStruct.getGeometry(1) == null || + actualStruct.getGeometry(1) == null) { + assert(expectedStruct.getGeometry(1) == null && actualStruct.getGeometry(1) == null) + } else { + assert(expectedStruct.getGeometry(1).getBytes === + actualStruct.getGeometry(1).getBytes) + } + if (expectedStruct.getGeometry(3) == null || + actualStruct.getGeometry(3) == null) { + assert(expectedStruct.getGeometry(3) == null && actualStruct.getGeometry(3) == null) + } else { + assert(expectedStruct.getGeometry(3).getBytes === + actualStruct.getGeometry(3).getBytes) + } + + if (datum.getArray(1) == null || + internalRow.getArray(1) == null) { + assert(internalRow.getArray(1) == null && datum.getArray(1) == null) + } else { + internalRow.getArray(1).toSeq[GeometryVal](GeometryType(0)) + .zip(datum.getArray(1).toSeq[GeometryVal](GeometryType(0))).foreach { + case (actual, expected) => + assert(actual.getBytes === expected.getBytes) + } + } + + if (datum.getMap(2) == null || + internalRow.getMap(2) == null) { + assert(internalRow.getMap(2) == null && datum.getMap(2) == null) + } else { + assert(internalRow.getMap(2).keyArray().toSeq(StringType) === + datum.getMap(2).keyArray().toSeq(StringType)) + internalRow.getMap(2).valueArray().toSeq[GeometryVal](GeometryType("ANY")) + .zip(datum.getMap(2).valueArray().toSeq[GeometryVal](GeometryType("ANY"))).foreach { + case (actual, expected) => + assert((actual == null && expected == null) || + actual.getBytes === expected.getBytes) + } + } + } + } + + writer.root.close() + } + + val point1 = "010100000000000000000031400000000000001C40" +
.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray + val point2 = "010100000000000000000035400000000000001E40" + .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray + + val schema = new StructType() + .add( + "s", + new StructType() + .add("i1", "int") + .add("g0", "geometry(0)") + .add("i2", "int") + .add("g4326", "geometry(4326)")) + .add("a", "array<geometry(0)>") + .add("m", "map<string, geometry(any)>") + + val maybeNull5 = MaybeNull(5) + val maybeNull7 = MaybeNull(7) + val maybeNull11 = MaybeNull(11) + val maybeNull13 = MaybeNull(13) + val maybeNull17 = MaybeNull(17) + + val nestedGeometrySerializer = ExpressionEncoder(toRowEncoder(schema)).createSerializer() + val data = Iterator + .tabulate(100) { i => + val mixedSrid = if (i % 2 == 0) 0 else 4326 + + nestedGeometrySerializer.apply( + (Row( + maybeNull5( + Row( + i, + maybeNull7(org.apache.spark.sql.types.Geometry.fromWKB(point1, 0)), + i + 1, + maybeNull11(org.apache.spark.sql.types.Geometry.fromWKB(point2, 4326)))), + maybeNull7((0 until 10).map(j => + org.apache.spark.sql.types.Geometry.fromWKB(point2, 0))), + maybeNull13( + Map((i.toString, maybeNull17( + org.apache.spark.sql.types.Geometry.fromWKB(point1, mixedSrid)))))))) + }.map(_.copy()).toSeq + + check(schema, data) + } + test("get multiple") { def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { val datatype = dt match {