diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala index 581ffffff2f9..dd8f0398aba9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.types.ops +import java.time.LocalTime + +import org.apache.arrow.vector.types.TimeUnit +import org.apache.arrow.vector.types.pojo.ArrowType + import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.LocalTimeEncoder -import org.apache.spark.sql.catalyst.util.{FractionTimeFormatter, TimeFormatter} +import org.apache.spark.sql.catalyst.util.{FractionTimeFormatter, SparkDateTimeUtils, TimeFormatter} import org.apache.spark.sql.types.{DataType, TimeType} /** @@ -28,6 +33,10 @@ import org.apache.spark.sql.types.{DataType, TimeType} * This class implements all TypeApiOps methods for the TIME data type: * - String formatting: uses FractionTimeFormatter for consistent output * - Row encoding: uses LocalTimeEncoder for java.time.LocalTime + * - Arrow conversion (ArrowUtils) + * - Python interop (EvaluatePython) + * - Hive formatting (HiveResult) + * - Thrift type mapping (SparkExecuteStatementOperation) * * RELATIONSHIP TO TimeTypeOps: TimeTypeOps (in catalyst package) extends this class to inherit * client-side operations while adding server-side operations (physical type, literals, etc.). @@ -56,4 +65,26 @@ class TimeTypeApiOps(val t: TimeType) extends TypeApiOps { // ==================== Row Encoding ==================== override def getEncoder: AgnosticEncoder[_] = LocalTimeEncoder + + // ==================== Optional Operations ==================== + + override def toArrowType(timeZoneId: String): Option[ArrowType] = { + Some(new ArrowType.Time(TimeUnit.NANOSECOND, 8 * 8)) + } + + override def needConversionInPython: Option[Boolean] = Some(true) + + override def makeFromJava: Option[Any => Any] = Some((obj: Any) => + nullSafeConvert(obj) { + case c: Long => c + // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs + case c: Int => c.toLong + }) + + override def formatExternal(value: Any): Option[String] = { + val nanos = SparkDateTimeUtils.localTimeToNanos(value.asInstanceOf[LocalTime]) + Some(timeFormatter.format(nanos)) + } + + override def thriftTypeName: Option[String] = Some("STRING_TYPE") } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TypeApiOps.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TypeApiOps.scala index f16e8fbc3b55..fff5b8b6a022 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TypeApiOps.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TypeApiOps.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types.ops +import org.apache.arrow.vector.types.pojo.ArrowType + import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{DataType, TimeType} @@ -26,14 +28,9 @@ import org.apache.spark.unsafe.types.UTF8String * Client-side (spark-api) type operations for the Types Framework. * * This trait consolidates all client-side operations that a data type must implement to be usable - * in the Spark SQL API layer. 
All methods are mandatory because a type cannot function correctly - * without string formatting (needed for CAST to STRING, EXPLAIN, SHOW) or encoding (needed for - * Dataset[T] operations). - * - * This single-interface design was chosen over separate FormatTypeOps/EncodeTypeOps traits to - * make it clear what a new type must implement - there is one mandatory interface, and it - * contains everything required. Optional capabilities (e.g., proto, Arrow, JDBC) are defined as - * separate traits that can be mixed in incrementally. + * in the Spark SQL API layer. Mandatory methods (format, toSQLValue, getEncoder) must be + * implemented by every type. Optional methods (Arrow, Python, Hive, Thrift) return Option and + * default to None - types implement them as they expand their integration coverage. * * RELATIONSHIP TO TypeOps: * - TypeOps (catalyst): Server-side operations - physical types, literals, conversions @@ -93,6 +90,49 @@ trait TypeApiOps extends Serializable { * AgnosticEncoder instance (e.g., LocalTimeEncoder for TimeType) */ def getEncoder: AgnosticEncoder[_] + + // ==================== Utilities ==================== + + /** + * Null-safe conversion helper. Returns null for null input, applies the partial function for + * non-null input, and returns null for unmatched values. + */ + protected def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = { + if (input == null) { + null + } else { + f.applyOrElse(input, (_: Any) => null) + } + } + + // ==================== Arrow Conversion (optional) ==================== + + /** Converts this DataType to its Arrow representation. Returns None if not supported. */ + def toArrowType(timeZoneId: String): Option[ArrowType] = None + + // ==================== Python Interop (optional) ==================== + + /** Returns true if values of this type need conversion when passed to/from Python. */ + def needConversionInPython: Option[Boolean] = None + + /** Creates a converter function for Python/Py4J interop. */ + def makeFromJava: Option[Any => Any] = None + + // ==================== Hive Formatting (optional) ==================== + + /** + * Formats an external-type value for Hive output. Most types override this simple version. + * Types that need different formatting when nested should override the 2-param overload. + */ + def formatExternal(value: Any): Option[String] = None + + /** Formats with nesting context. Default delegates to the simple version. */ + def formatExternal(value: Any, nested: Boolean): Option[String] = formatExternal(value) + + // ==================== Thrift Mapping (optional) ==================== + + /** Returns the Thrift TTypeId name for this type (e.g., "STRING_TYPE"). */ + def thriftTypeName: Option[String] = None } /** @@ -123,4 +163,18 @@ object TypeApiOps { case _ => None } } + + /** + * Reverse lookup: converts an Arrow type to a Spark DataType. 
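+   *
+   * Illustrative call-site sketch (mirroring how ArrowUtils.fromArrowType in this patch consults
+   * the framework first and only then falls back to its built-in mapping):
+   * {{{
+   *   TypeApiOps.fromArrowType(dt).getOrElse(fromArrowTypeDefault(dt))
+   * }}}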
+ */ + def fromArrowType(at: ArrowType): Option[DataType] = { + import org.apache.arrow.vector.types.TimeUnit + if (!SqlApiConf.get.typesFrameworkEnabled) return None + at match { + case t: ArrowType.Time if t.getUnit == TimeUnit.NANOSECOND && t.getBitWidth == 8 * 8 => + Some(TimeType(TimeType.MICROS_PRECISION)) + // Add new framework types here + case _ => None + } + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 92b52d4ae634..1c1024fc0152 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -29,6 +29,7 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.SparkException import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.TypeApiOps import org.apache.spark.util.ArrayImplicits._ private[sql] object ArrowUtils { @@ -39,6 +40,14 @@ private[sql] object ArrowUtils { /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = + TypeApiOps(dt) + .flatMap(_.toArrowType(timeZoneId)) + .getOrElse(toArrowTypeDefault(dt, timeZoneId, largeVarTypes)) + + private def toArrowTypeDefault( + dt: DataType, + timeZoneId: String, + largeVarTypes: Boolean): ArrowType = dt match { case BooleanType => ArrowType.Bool.INSTANCE case ByteType => new ArrowType.Int(8, true) @@ -67,7 +76,10 @@ private[sql] object ArrowUtils { throw ExecutionErrors.unsupportedDataTypeError(dt) } - def fromArrowType(dt: ArrowType): DataType = dt match { + def fromArrowType(dt: ArrowType): DataType = + TypeApiOps.fromArrowType(dt).getOrElse(fromArrowTypeDefault(dt)) + + private def fromArrowTypeDefault(dt: ArrowType): DataType = dt match { case ArrowType.Bool.INSTANCE => BooleanType case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 080794643fa0..8bd162afd56d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, Bo import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{dataTypeForClass, externalDataTypeFor, isNativeEncoder} import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption} +import org.apache.spark.sql.catalyst.types.ops.TypeOps import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils, STUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -289,6 +290,16 @@ object DeserializerBuildHelper { * @param isTopLevel true if we are creating a deserializer 
for the top level value. */ private def createDeserializer( + enc: AgnosticEncoder[_], + path: Expression, + walkedTypePath: WalkedTypePath, + isTopLevel: Boolean = false): Expression = + // Framework dispatch runs before encoder-type checks in the default path. Safe because + // framework types use dedicated leaf encoders, never migration shims or native primitives. + TypeOps(enc.dataType).flatMap(_.createDeserializer(path, walkedTypePath, isTopLevel)) + .getOrElse(createDeserializerDefault(enc, path, walkedTypePath, isTopLevel)) + + private def createDeserializerDefault( enc: AgnosticEncoder[_], path: Expression, walkedTypePath: WalkedTypePath, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index b8b2406a5813..37a4efc65739 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, Bo import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor} import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData} import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.catalyst.types.ops.TypeOps import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateTimeUtils, GenericArrayData, IntervalUtils, STUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -332,7 +333,15 @@ object SerializerBuildHelper { * representation. The mapping between the external and internal representations is described * by encoder `enc`. */ - private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match { + private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = + // Framework dispatch runs before encoder-type checks (AgnosticExpressionPathEncoder, + // isNativeEncoder) in the default path. This is safe because framework types use dedicated + // leaf encoders (e.g., LocalTimeEncoder), never migration shims or native primitives. 
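+    // For example, with the types framework enabled, TypeOps(TimeType()) resolves to
+    // TimeTypeOps, whose createSerializer builds a StaticInvoke of DateTimeUtils.localTimeToNanos.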
+ TypeOps(enc.dataType).flatMap(_.createSerializer(input)) + .getOrElse(createSerializerDefault(enc, input)) + + private def createSerializerDefault( + enc: AgnosticEncoder[_], input: Expression): Expression = enc match { case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input) case _ if isNativeEncoder(enc) => input case BoxedBooleanEncoder => createSerializerForBoolean(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala index 74198c956edc..3c0eda1b1c05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala @@ -19,11 +19,15 @@ package org.apache.spark.sql.catalyst.types.ops import java.time.LocalTime +import org.apache.arrow.vector.{TimeNanoVector, ValueVector} + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Literal, MutableLong, MutableValue} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, MutableLong, MutableValue} +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.TimeType +import org.apache.spark.sql.execution.arrow.{ArrowFieldWriter, TimeWriter} +import org.apache.spark.sql.types.{ObjectType, TimeType} import org.apache.spark.sql.types.ops.TimeTypeApiOps /** @@ -37,6 +41,8 @@ import org.apache.spark.sql.types.ops.TimeTypeApiOps * It also inherits client-side operations from TimeTypeApiOps: * - String formatting (FractionTimeFormatter) * - Row encoding (LocalTimeEncoder) + * - Serializer/deserializer expression building (SerializerBuildHelper, DeserializerBuildHelper) + * - Arrow field writer creation (ArrowWriter) * * INTERNAL REPRESENTATION: * - Values stored as Long nanoseconds since midnight @@ -48,7 +54,8 @@ import org.apache.spark.sql.types.ops.TimeTypeApiOps * The TimeType with precision information * @since 4.2.0 */ -case class TimeTypeOps(override val t: TimeType) extends TimeTypeApiOps(t) with TypeOps { +case class TimeTypeOps(override val t: TimeType) + extends TimeTypeApiOps(t) with TypeOps { // ==================== Physical Type Representation ==================== @@ -81,4 +88,28 @@ case class TimeTypeOps(override val t: TimeType) extends TimeTypeApiOps(t) with override def toScalaImpl(row: InternalRow, column: Int): Any = { DateTimeUtils.nanosToLocalTime(row.getLong(column)) } + + // ==================== Optional Operations ==================== + + override def createSerializer(input: Expression): Option[Expression] = { + Some(StaticInvoke( + DateTimeUtils.getClass, + t, + "localTimeToNanos", + input :: Nil, + returnNullable = false)) + } + + override def createDeserializer(path: Expression): Option[Expression] = { + Some(StaticInvoke( + DateTimeUtils.getClass, + ObjectType(classOf[java.time.LocalTime]), + "nanosToLocalTime", + path :: Nil, + returnNullable = false)) + } + + override def createArrowFieldWriter(vector: ValueVector): Option[ArrowFieldWriter] = { + Some(new TimeWriter(vector.asInstanceOf[TimeNanoVector])) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TypeOps.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TypeOps.scala index 
628dfe941407..7240f0533aa3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TypeOps.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TypeOps.scala @@ -19,9 +19,13 @@ package org.apache.spark.sql.catalyst.types.ops import javax.annotation.Nullable +import org.apache.arrow.vector.ValueVector + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Literal, MutableValue} +import org.apache.spark.sql.catalyst.WalkedTypePath +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, MutableValue} import org.apache.spark.sql.catalyst.types.PhysicalDataType +import org.apache.spark.sql.execution.arrow.ArrowFieldWriter import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, TimeType} @@ -29,14 +33,9 @@ import org.apache.spark.sql.types.{DataType, TimeType} * Server-side (catalyst) type operations for the Types Framework. * * This trait consolidates all server-side operations that a data type must implement to function in - * the Spark SQL engine. All methods are mandatory because without any of them the type would fail - * at runtime - physical type mapping is needed for storage, literals for the optimizer, and - * external type conversion for user-facing operations like collect() and UDFs. - * - * This single-interface design was chosen over separate PhyTypeOps/LiteralTypeOps/ExternalTypeOps - * traits to make it clear what a new type must implement. There is one mandatory interface with - * everything required. Optional capabilities (e.g., proto serialization, client integration) are - * defined as separate traits that can be mixed in incrementally as a type's support expands. + * the Spark SQL engine. Mandatory methods (physical type, literals, external conversion) must be + * implemented by every type. Optional methods (serialization, Arrow writer) return Option and + * default to None - types implement them as they expand their integration coverage. * * USAGE - integration points use TypeOps(dt) which returns Option[TypeOps]: * {{{ @@ -181,6 +180,28 @@ trait TypeOps extends Serializable { final def toScala(row: InternalRow, column: Int): Any = { if (row.isNullAt(column)) null else toScalaImpl(row, column) } + + // ==================== Serialization (optional) ==================== + + /** Creates a serializer expression (external -> internal). */ + def createSerializer(input: Expression): Option[Expression] = None + + /** + * Creates a deserializer expression (internal -> external). + * Most types override this simple version. + */ + def createDeserializer(path: Expression): Option[Expression] = None + + /** Creates a deserializer with full context. Default delegates to the simple version. */ + def createDeserializer( + path: Expression, + walkedTypePath: WalkedTypePath, + isTopLevel: Boolean): Option[Expression] = createDeserializer(path) + + // ==================== Arrow Writer (optional) ==================== + + /** Creates an ArrowFieldWriter for this type. 
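+   * A sketch of the intended dispatch, as wired into ArrowWriter.createFieldWriter by this
+   * change (framework lookup first, then the existing pattern match):
+   * {{{
+   *   TypeOps(dt).flatMap(_.createArrowFieldWriter(vector))
+   *     .getOrElse(createFieldWriterDefault(dt, vector))
+   * }}}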
*/ + def createArrowFieldWriter(vector: ValueVector): Option[ArrowFieldWriter] = None } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index b5269da035f3..b59732f4820e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -24,6 +24,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.types.ops.TypeOps import org.apache.spark.sql.catalyst.util.STUtils import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ @@ -52,7 +53,14 @@ object ArrowWriter { private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { val field = vector.getField() - (ArrowUtils.fromArrowField(field), vector) match { + val dt = ArrowUtils.fromArrowField(field) + TypeOps(dt).flatMap(_.createArrowFieldWriter(vector)) + .getOrElse(createFieldWriterDefault(dt, vector)) + } + + private def createFieldWriterDefault( + dt: DataType, vector: ValueVector): ArrowFieldWriter = { + (dt, vector) match { case (BooleanType, vector: BitVector) => new BooleanWriter(vector) case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) case (ShortType, vector: SmallIntVector) => new ShortWriter(vector) @@ -146,7 +154,7 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { } } -private[arrow] abstract class ArrowFieldWriter { +private[sql] abstract class ArrowFieldWriter { def valueVector: ValueVector @@ -371,7 +379,7 @@ private[arrow] class TimestampNTZWriter( } } -private[arrow] class TimeWriter( +private[sql] class TimeWriter( val valueVector: TimeNanoVector) extends ArrowFieldWriter { override def setNull(): Unit = { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index cc853f3c8a8c..f2786c61d1b5 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.connect.common.types.ops.ConnectTypeOps import org.apache.spark.sql.errors.{CompilationErrors, ExecutionErrors} import org.apache.spark.sql.types.Decimal import org.apache.spark.sql.util.{CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator} @@ -89,6 +90,15 @@ object ArrowDeserializers { } private[arrow] def deserializerFor( + encoder: AgnosticEncoder[_], + data: AnyRef, + timeZoneId: String): Deserializer[Any] = + ConnectTypeOps + .forEncoder(encoder) + .map(_.createArrowDeserializer(encoder, data, timeZoneId)) + .getOrElse(deserializerForDefault(encoder, data, timeZoneId)) + + private def deserializerForDefault( encoder: AgnosticEncoder[_], data: AnyRef, timeZoneId: String): Deserializer[Any] = { diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index d547c81afe5a..786d6a1d3bbb 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} +import org.apache.spark.sql.connect.common.types.ops.ConnectTypeOps import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types.Decimal import org.apache.spark.sql.util.{ArrowUtils, CloseableIterator} @@ -239,7 +240,13 @@ object ArrowSerializer { } // TODO throw better errors on class cast exceptions. - private[arrow] def serializerFor[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = { + private[arrow] def serializerFor[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = + ConnectTypeOps + .forEncoder(encoder) + .map(_.createArrowSerializer(v)) + .getOrElse(serializerForDefault(encoder, v)) + + private def serializerForDefault[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = { (encoder, v) match { case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) => new FieldSerializer[Boolean, BitVector](v) { @@ -562,7 +569,8 @@ object ArrowSerializer { def write(index: Int, value: Any): Unit } - private abstract class FieldSerializer[E, V <: FieldVector](val vector: V) extends Serializer { + private[connect] abstract class FieldSerializer[E, V <: FieldVector](val vector: V) + extends Serializer { def set(index: Int, value: E): Unit override def write(index: Int, raw: Any): Unit = { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala index ea57e0e1c77f..54311cecc162 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkIntervalUtils, Ti import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._ +import org.apache.spark.sql.connect.common.types.ops.ConnectTypeOps import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, Decimal, UpCastRule, YearMonthIntervalType} import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.SparkStringUtils @@ -37,7 +38,7 @@ import org.apache.spark.util.SparkStringUtils * the read methods. If upcasting is allowed for the given vector, then all allowed read methods * must be implemented. 
*/ -private[arrow] abstract class ArrowVectorReader { +private[connect] abstract class ArrowVectorReader { def isNull(i: Int): Boolean def getBoolean(i: Int): Boolean = unsupported() def getByte(i: Int): Byte = unsupported() @@ -66,6 +67,15 @@ private[arrow] abstract class ArrowVectorReader { object ArrowVectorReader { def apply( + targetDataType: DataType, + vector: FieldVector, + timeZoneId: String): ArrowVectorReader = + ConnectTypeOps + .forDataType(targetDataType) + .map(_.createArrowVectorReader(vector)) + .getOrElse(applyDefault(targetDataType, vector, timeZoneId)) + + private def applyDefault( targetDataType: DataType, vector: FieldVector, timeZoneId: String): ArrowVectorReader = { @@ -279,7 +289,7 @@ private[arrow] class TimeStampMicroVectorReader(v: TimeStampMicroVector, timeZon override def getString(i: Int): String = formatter.format(utcMicros(i)) } -private[arrow] class TimeVectorReader(v: TimeNanoVector) +private[connect] class TimeVectorReader(v: TimeNanoVector) extends TypedArrowVectorReader[TimeNanoVector](v) { private lazy val formatter = TimeFormatter.getFractionFormatter() private def nanos(i: Int): Long = vector.get(i) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala new file mode 100644 index 000000000000..ca3d940d3e50 --- /dev/null +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.arrow.types.ops + +import java.time.LocalTime + +import org.apache.arrow.vector.{FieldVector, TimeNanoVector} + +import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.LocalTimeEncoder +import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils +import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer, ArrowVectorReader, TimeVectorReader} +import org.apache.spark.sql.connect.common.types.ops.ConnectTypeOps +import org.apache.spark.sql.types.{DataType, TimeType} + +/** + * Combined Connect operations for TimeType. + * + * Implements ConnectTypeOps for TimeType, providing both proto DataType/Literal conversions and + * Arrow serialization/deserialization. + * + * Lives under the arrow.types.ops sub-package to access arrow-private types (TimeVectorReader, + * ArrowSerializer.Serializer, ArrowDeserializers.LeafFieldDeserializer) while keeping ops + * implementations separate from core arrow infrastructure. 
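+ *
+ * For orientation, a sketch of how this class is reached on the Arrow client path (taken from
+ * ArrowSerializer.serializerFor as modified in this patch):
+ * {{{
+ *   ConnectTypeOps.forEncoder(encoder)
+ *     .map(_.createArrowSerializer(v))
+ *     .getOrElse(serializerForDefault(encoder, v))
+ * }}}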
+ * + * @param t + * The TimeType with precision information + * @since 4.2.0 + */ +private[connect] class TimeTypeConnectOps(val t: TimeType) extends ConnectTypeOps { + + override def dataType: DataType = t + + override def encoder: AgnosticEncoder[_] = LocalTimeEncoder + + // ==================== Proto Conversions ==================== + + override def toCatalystTypeFromProto(t: proto.DataType): DataType = { + val time = t.getTime + if (time.hasPrecision) TimeType(time.getPrecision) else TimeType() + } + + override def toConnectProtoType: proto.DataType = { + proto.DataType + .newBuilder() + .setTime(proto.DataType.Time.newBuilder().setPrecision(t.precision).build()) + .build() + } + + override def toLiteralProto( + value: Any, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder = { + val v = value.asInstanceOf[LocalTime] + builder.setTime( + builder.getTimeBuilder + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + .setPrecision(TimeType.DEFAULT_PRECISION)) + } + + override def toLiteralProtoWithType( + value: Any, + dt: DataType, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder = { + val v = value.asInstanceOf[LocalTime] + val timeType = dt.asInstanceOf[TimeType] + builder.setTime( + builder.getTimeBuilder + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + .setPrecision(timeType.precision)) + } + + override def getScalaConverter: proto.Expression.Literal => Any = { v => + SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) + } + + override def getProtoDataTypeFromLiteral(literal: proto.Expression.Literal): proto.DataType = { + val timeBuilder = proto.DataType.Time.newBuilder() + if (literal.getTime.hasPrecision) { + timeBuilder.setPrecision(literal.getTime.getPrecision) + } + proto.DataType.newBuilder().setTime(timeBuilder.build()).build() + } + + // ==================== Arrow Serialization ==================== + + override def createArrowSerializer(vector: AnyRef): ArrowSerializer.Serializer = { + val v = vector.asInstanceOf[TimeNanoVector] + new ArrowSerializer.FieldSerializer[LocalTime, TimeNanoVector](v) { + override def set(index: Int, value: LocalTime): Unit = + v.setSafe(index, SparkDateTimeUtils.localTimeToNanos(value)) + } + } + + override def createArrowDeserializer( + enc: AgnosticEncoder[_], + data: AnyRef, + timeZoneId: String): ArrowDeserializers.Deserializer[Any] = { + val v = data.asInstanceOf[FieldVector] + new ArrowDeserializers.LeafFieldDeserializer[LocalTime](enc, v, timeZoneId) { + override def value(i: Int): LocalTime = reader.getLocalTime(i) + } + } + + override def createArrowVectorReader(vector: FieldVector): ArrowVectorReader = { + new TimeVectorReader(vector.asInstanceOf[TimeNanoVector]) + } +} diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index ceccf780f586..11b1b394b1ed 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.connect.common.types.ops.ConnectTypeOps import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SparkClassUtils @@ 
-29,7 +30,10 @@ import org.apache.spark.util.SparkClassUtils * Helper class for conversions between [[DataType]] and [[proto.DataType]]. */ object DataTypeProtoConverter { - def toCatalystType(t: proto.DataType): DataType = { + def toCatalystType(t: proto.DataType): DataType = + ConnectTypeOps.toCatalystType(t).getOrElse(toCatalystTypeDefault(t)) + + private def toCatalystTypeDefault(t: proto.DataType): DataType = { t.getKindCase match { case proto.DataType.KindCase.NULL => NullType @@ -174,7 +178,12 @@ object DataTypeProtoConverter { toConnectProtoTypeInternal(t, bytesToBinary) } - private def toConnectProtoTypeInternal(t: DataType, bytesToBinary: Boolean): proto.DataType = { + private def toConnectProtoTypeInternal(t: DataType, bytesToBinary: Boolean): proto.DataType = + ConnectTypeOps(t) + .map(_.toConnectProtoType) + .getOrElse(toConnectProtoTypeDefault(t, bytesToBinary)) + + private def toConnectProtoTypeDefault(t: DataType, bytesToBinary: Boolean): proto.DataType = { t match { case NullType => ProtoDataTypes.NullType diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 026f5441c6ca..29623c7cbdb9 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} import org.apache.spark.sql.connect.common.DataTypeProtoConverter._ +import org.apache.spark.sql.connect.common.types.ops.ConnectTypeOps import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -57,7 +58,15 @@ object LiteralValueProtoConverter { literal: Any, options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { val builder = proto.Expression.Literal.newBuilder() + ConnectTypeOps.toLiteralProtoForValue(literal, builder).getOrElse { + toLiteralProtoBuilderDefault(literal, builder, options) + } + } + private def toLiteralProtoBuilderDefault( + literal: Any, + builder: proto.Expression.Literal.Builder, + options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { def decimalBuilder(precision: Int, scale: Int, value: String) = { builder.getDecimalBuilder.setPrecision(precision).setScale(scale).setValue(value) } @@ -127,6 +136,16 @@ object LiteralValueProtoConverter { dataType: DataType, options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { val builder = proto.Expression.Literal.newBuilder() + ConnectTypeOps(dataType) + .map(_.toLiteralProtoWithType(literal, dataType, builder)) + .getOrElse(toLiteralProtoWithTypeDefault(literal, dataType, builder, options)) + } + + private def toLiteralProtoWithTypeDefault( + literal: Any, + dataType: DataType, + builder: proto.Expression.Literal.Builder, + options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { def arrayBuilder(scalaValue: Any, elementType: DataType) = { val ab = builder.getArrayBuilder @@ -384,7 +403,8 @@ object LiteralValueProtoConverter { getScalaConverter(getProtoDataType(literal))(literal) } - private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { + private def getScalaConverterDefault( + dataType: 
proto.DataType): proto.Expression.Literal => Any = { val converter: proto.Expression.Literal => Any = dataType.getKindCase match { case proto.DataType.KindCase.NULL => v => @@ -428,6 +448,14 @@ object LiteralValueProtoConverter { "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", Map("typeInfo" -> dataType.getKindCase.toString)) } + converter + } + + private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { + val converter: proto.Expression.Literal => Any = + ConnectTypeOps.getScalaConverterForKind(dataType.getKindCase).getOrElse { + getScalaConverterDefault(dataType) + } v => if (v.hasNull) null else converter(v) } @@ -500,102 +528,9 @@ object LiteralValueProtoConverter { if (literal.getLiteralTypeCase == proto.Expression.Literal.LiteralTypeCase.NULL) { literal.getNull } else { - val builder = proto.DataType.newBuilder() - literal.getLiteralTypeCase match { - case proto.Expression.Literal.LiteralTypeCase.BINARY => - builder.setBinary(proto.DataType.Binary.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => - builder.setBoolean(proto.DataType.Boolean.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.BYTE => - builder.setByte(proto.DataType.Byte.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.SHORT => - builder.setShort(proto.DataType.Short.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.INTEGER => - builder.setInteger(proto.DataType.Integer.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.LONG => - builder.setLong(proto.DataType.Long.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.FLOAT => - builder.setFloat(proto.DataType.Float.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DOUBLE => - builder.setDouble(proto.DataType.Double.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DECIMAL => - val decimal = Decimal.apply(literal.getDecimal.getValue) - var precision = decimal.precision - if (literal.getDecimal.hasPrecision) { - precision = math.max(precision, literal.getDecimal.getPrecision) - } - var scale = decimal.scale - if (literal.getDecimal.hasScale) { - scale = math.max(scale, literal.getDecimal.getScale) - } - builder.setDecimal( - proto.DataType.Decimal - .newBuilder() - .setPrecision(math.max(precision, scale)) - .setScale(scale) - .build()) - case proto.Expression.Literal.LiteralTypeCase.STRING => - builder.setString(proto.DataType.String.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DATE => - builder.setDate(proto.DataType.Date.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP => - builder.setTimestamp(proto.DataType.Timestamp.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ => - builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => - builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => - builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => - builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIME => - val timeBuilder = proto.DataType.Time.newBuilder() - if (literal.getTime.hasPrecision) { - 
timeBuilder.setPrecision(literal.getTime.getPrecision) - } - builder.setTime(timeBuilder.build()) - case proto.Expression.Literal.LiteralTypeCase.ARRAY => - if (literal.getArray.hasElementType) { - builder.setArray( - proto.DataType.Array - .newBuilder() - .setElementType(literal.getArray.getElementType) - .setContainsNull(true) - .build()) - } else { - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.ARRAY_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case proto.Expression.Literal.LiteralTypeCase.MAP => - if (literal.getMap.hasKeyType && literal.getMap.hasValueType) { - builder.setMap( - proto.DataType.Map - .newBuilder() - .setKeyType(literal.getMap.getKeyType) - .setValueType(literal.getMap.getValueType) - .setValueContainsNull(true) - .build()) - } else { - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.MAP_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case proto.Expression.Literal.LiteralTypeCase.STRUCT => - if (literal.getStruct.hasStructType) { - builder.setStruct(literal.getStruct.getStructType.getStruct) - } else { - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.STRUCT_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case _ => - val literalCase = literal.getLiteralTypeCase - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", - Map("typeInfo" -> s"${literalCase.name}(${literalCase.getNumber})")) - } - builder.build() + ConnectTypeOps + .getProtoDataTypeFromLiteral(literal) + .getOrElse(getProtoDataTypeDefault(literal)) } } @@ -610,6 +545,103 @@ object LiteralValueProtoConverter { dataType } + private def getProtoDataTypeDefault(literal: proto.Expression.Literal): proto.DataType = { + val builder = proto.DataType.newBuilder() + literal.getLiteralTypeCase match { + case proto.Expression.Literal.LiteralTypeCase.BINARY => + builder.setBinary(proto.DataType.Binary.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => + builder.setBoolean(proto.DataType.Boolean.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.BYTE => + builder.setByte(proto.DataType.Byte.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.SHORT => + builder.setShort(proto.DataType.Short.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.INTEGER => + builder.setInteger(proto.DataType.Integer.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.LONG => + builder.setLong(proto.DataType.Long.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.FLOAT => + builder.setFloat(proto.DataType.Float.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DOUBLE => + builder.setDouble(proto.DataType.Double.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DECIMAL => + val decimal = Decimal.apply(literal.getDecimal.getValue) + var precision = decimal.precision + if (literal.getDecimal.hasPrecision) { + precision = math.max(precision, literal.getDecimal.getPrecision) + } + var scale = decimal.scale + if (literal.getDecimal.hasScale) { + scale = math.max(scale, literal.getDecimal.getScale) + } + builder.setDecimal( + proto.DataType.Decimal + .newBuilder() + .setPrecision(math.max(precision, scale)) + .setScale(scale) + .build()) + case proto.Expression.Literal.LiteralTypeCase.STRING => + builder.setString(proto.DataType.String.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DATE => + builder.setDate(proto.DataType.Date.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP => + 
builder.setTimestamp(proto.DataType.Timestamp.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ => + builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => + builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => + builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => + builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIME => + val timeBuilder = proto.DataType.Time.newBuilder() + if (literal.getTime.hasPrecision) { + timeBuilder.setPrecision(literal.getTime.getPrecision) + } + builder.setTime(timeBuilder.build()) + case proto.Expression.Literal.LiteralTypeCase.ARRAY => + if (literal.getArray.hasElementType) { + builder.setArray( + proto.DataType.Array + .newBuilder() + .setElementType(literal.getArray.getElementType) + .setContainsNull(true) + .build()) + } else { + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.ARRAY_LITERAL_MISSING_DATA_TYPE", + Map.empty) + } + case proto.Expression.Literal.LiteralTypeCase.MAP => + if (literal.getMap.hasKeyType && literal.getMap.hasValueType) { + builder.setMap( + proto.DataType.Map + .newBuilder() + .setKeyType(literal.getMap.getKeyType) + .setValueType(literal.getMap.getValueType) + .setValueContainsNull(true) + .build()) + } else { + throw InvalidPlanInput("CONNECT_INVALID_PLAN.MAP_LITERAL_MISSING_DATA_TYPE", Map.empty) + } + case proto.Expression.Literal.LiteralTypeCase.STRUCT => + if (literal.getStruct.hasStructType) { + builder.setStruct(literal.getStruct.getStructType.getStruct) + } else { + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.STRUCT_LITERAL_MISSING_DATA_TYPE", + Map.empty) + } + case _ => + val literalCase = literal.getLiteralTypeCase + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", + Map("typeInfo" -> s"${literalCase.name}(${literalCase.getNumber})")) + } + builder.build() + } + private def toScalaArrayInternal( literal: proto.Expression.Literal, arrayType: proto.DataType.Array): Array[_] = { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectTypeOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectTypeOps.scala new file mode 100644 index 000000000000..3390bf1d0659 --- /dev/null +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectTypeOps.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.common.types.ops + +import org.apache.arrow.vector.FieldVector + +import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.LocalTimeEncoder +import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer, ArrowVectorReader} +import org.apache.spark.sql.connect.client.arrow.types.ops.TimeTypeConnectOps +import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.types.{DataType, TimeType} + +/** + * Optional type operations for Spark Connect infrastructure. + * + * Consolidates both proto conversions (DataTypeProtoConverter, LiteralValueProtoConverter) and + * Arrow serialization/deserialization (ArrowSerializer, ArrowDeserializer, ArrowVectorReader) for + * framework-managed types. + * + * @since 4.2.0 + */ +trait ConnectTypeOps extends Serializable { + + def dataType: DataType + + def encoder: AgnosticEncoder[_] + + // ==================== Proto Conversions ==================== + + /** Converts this DataType to its Connect proto representation. */ + def toConnectProtoType: proto.DataType + + /** Converts a value to a proto literal builder (generic, no DataType context). */ + def toLiteralProto( + value: Any, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder + + /** Converts a value to a proto literal builder (with DataType context). */ + def toLiteralProtoWithType( + value: Any, + dt: DataType, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder + + /** + * Returns a converter from proto literal to Scala value. The returned converter assumes + * non-null input - null handling is done by the caller (LiteralValueProtoConverter wraps with + * `if (v.hasNull) null`). + */ + def getScalaConverter: proto.Expression.Literal => Any + + /** Returns a proto DataType inferred from a proto literal (for type inference). */ + def getProtoDataTypeFromLiteral(literal: proto.Expression.Literal): proto.DataType + + /** Converts a proto DataType to a Spark DataType (reverse of toConnectProtoType). */ + def toCatalystTypeFromProto(t: proto.DataType): DataType + + // ==================== Arrow Serialization ==================== + + /** Creates an Arrow serializer for writing values to a vector. */ + def createArrowSerializer(vector: AnyRef): ArrowSerializer.Serializer + + /** Creates an Arrow deserializer for reading values from a vector. */ + def createArrowDeserializer( + enc: AgnosticEncoder[_], + data: AnyRef, + timeZoneId: String): ArrowDeserializers.Deserializer[Any] + + /** Creates an ArrowVectorReader for this type's vector. */ + def createArrowVectorReader(vector: FieldVector): ArrowVectorReader +} + +/** + * Factory object for ConnectTypeOps lookup. + * + * Provides separate factory methods for proto (server-side, feature-flag-gated) and Arrow + * (client-side, no flag) dispatch. + */ +object ConnectTypeOps { + + // ==================== Proto Dispatch (server-side, flag-gated) ==================== + + /** DataType-keyed dispatch for proto conversions. Checks feature flag. */ + def apply(dt: DataType): Option[ConnectTypeOps] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + dt match { + case tt: TimeType => Some(new TimeTypeConnectOps(tt)) + // Add new framework types here + case _ => None + } + } + + /** Reverse lookup by value class for the generic literal builder. Checks feature flag. 
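+   * Illustrative caller sketch (this is how LiteralValueProtoConverter.toLiteralProtoBuilder
+   * invokes it in this patch, falling back to the existing builder logic):
+   * {{{
+   *   ConnectTypeOps.toLiteralProtoForValue(literal, builder).getOrElse {
+   *     toLiteralProtoBuilderDefault(literal, builder, options)
+   *   }
+   * }}}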
*/ + def toLiteralProtoForValue( + value: Any, + builder: proto.Expression.Literal.Builder): Option[proto.Expression.Literal.Builder] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + value match { + case v: java.time.LocalTime => + Some(new TimeTypeConnectOps(TimeType()).toLiteralProto(v, builder)) + // Add new framework value types here + case _ => None + } + } + + /** + * Shared KindCase -> ConnectTypeOps lookup. All reverse lookups by proto enum case dispatch + * through this single registration point. Checks feature flag. + */ + private def opsForKindCase(kindCase: proto.DataType.KindCase): Option[ConnectTypeOps] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + kindCase match { + case proto.DataType.KindCase.TIME => Some(new TimeTypeConnectOps(TimeType())) + // Add new framework proto kinds here - single registration for all KindCase lookups + case _ => None + } + } + + /** Reverse lookup: converts a proto DataType to a Spark DataType. Checks feature flag. */ + def toCatalystType(t: proto.DataType): Option[DataType] = + opsForKindCase(t.getKindCase).map(_.toCatalystTypeFromProto(t)) + + /** Reverse lookup: returns a Scala converter for a proto literal KindCase. */ + def getScalaConverterForKind( + kindCase: proto.DataType.KindCase): Option[proto.Expression.Literal => Any] = + opsForKindCase(kindCase).map(_.getScalaConverter) + + /** Reverse lookup: returns the proto DataType inferred from a proto literal. */ + def getProtoDataTypeFromLiteral(literal: proto.Expression.Literal): Option[proto.DataType] = + opsForKindCase(literalCaseToKindCase(literal.getLiteralTypeCase)) + .map(_.getProtoDataTypeFromLiteral(literal)) + + /** Maps LiteralTypeCase to KindCase (1:1 for framework types). */ + private def literalCaseToKindCase( + litCase: proto.Expression.Literal.LiteralTypeCase): proto.DataType.KindCase = + litCase match { + case proto.Expression.Literal.LiteralTypeCase.TIME => proto.DataType.KindCase.TIME + // Add new framework literal-to-kind mappings here + case _ => proto.DataType.KindCase.KIND_NOT_SET + } + + // ==================== Arrow Dispatch (client-side, NO flag check) ==================== + + /** Encoder-keyed dispatch for Arrow serialization. No feature flag check. */ + def forEncoder(enc: AgnosticEncoder[_]): Option[ConnectTypeOps] = + enc match { + case LocalTimeEncoder => Some(new TimeTypeConnectOps(TimeType())) + // Add new framework encoders here + case _ => None + } + + /** DataType-keyed dispatch for ArrowVectorReader. No feature flag check. 
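+   * Sketch of the caller in ArrowVectorReader.apply as changed by this patch:
+   * {{{
+   *   ConnectTypeOps.forDataType(targetDataType)
+   *     .map(_.createArrowVectorReader(vector))
+   *     .getOrElse(applyDefault(targetDataType, vector, timeZoneId))
+   * }}}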
*/ + def forDataType(dt: DataType): Option[ConnectTypeOps] = + dt match { + case tt: TimeType => Some(new TimeTypeConnectOps(tt)) + // Add new framework types here + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 8a666bbb9dad..927227325fbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTab import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.BinaryOutputStyle import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.TypeApiOps import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import org.apache.spark.util.ArrayImplicits._ @@ -112,6 +113,17 @@ object HiveResult extends SQLConfHelper { formatters: TimeFormatters, binaryFormatter: BinaryFormatter): String = a match { case (null, _) => if (nested) "null" else "NULL" + case (value, dt) => + TypeApiOps(dt).flatMap(_.formatExternal(value, nested)).getOrElse { + toHiveStringDefault(a, nested, formatters, binaryFormatter) + } + } + + private def toHiveStringDefault( + a: (Any, DataType), + nested: Boolean, + formatters: TimeFormatters, + binaryFormatter: BinaryFormatter): String = a match { case (b, BooleanType) => b.toString case (d: Date, DateType) => formatters.date.format(d) case (ld: LocalDate, DateType) => formatters.date.format(ld) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 33622ca7349a..6a9b4978e27b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData, STUtils} import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.TypeApiOps import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String, VariantVal} object EvaluatePython { @@ -41,7 +42,11 @@ object EvaluatePython { */ private[python] class BytesWrapper(val data: Array[Byte]) - def needConversionInPython(dt: DataType): Boolean = dt match { + def needConversionInPython(dt: DataType): Boolean = + TypeApiOps(dt).flatMap(_.needConversionInPython) + .getOrElse(needConversionInPythonDefault(dt)) + + private def needConversionInPythonDefault(dt: DataType): Boolean = dt match { case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType | _: TimeType | _: GeometryType | _: GeographyType => true case _: StructType => true @@ -111,7 +116,10 @@ object EvaluatePython { * Make a converter that converts `obj` to the type specified by the data type, or returns * null if the type of obj is unexpected. Because Python doesn't enforce the type. 
*/ - def makeFromJava(dataType: DataType): Any => Any = dataType match { + def makeFromJava(dataType: DataType): Any => Any = + TypeApiOps(dataType).flatMap(_.makeFromJava).getOrElse(makeFromJavaDefault(dataType)) + + private def makeFromJavaDefault(dataType: DataType): Any => Any = dataType match { case BooleanType => (obj: Any) => nullSafeConvert(obj) { case b: Boolean => b } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index cd2bdefcc306..46302b316b75 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND import org.apache.spark.sql.internal.{SQLConf, VariableSubstitution} import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.TypeApiOps import org.apache.spark.util.{Utils => SparkUtils} private[hive] class SparkExecuteStatementOperation( @@ -326,7 +327,11 @@ private[hive] class SparkExecuteStatementOperation( object SparkExecuteStatementOperation { - def toTTypeId(typ: DataType): TTypeId = typ match { + def toTTypeId(typ: DataType): TTypeId = + TypeApiOps(typ).flatMap(_.thriftTypeName).map(TTypeId.valueOf) + .getOrElse(toTTypeIdDefault(typ)) + + private def toTTypeIdDefault(typ: DataType): TTypeId = typ match { case NullType => TTypeId.NULL_TYPE case BooleanType => TTypeId.BOOLEAN_TYPE case ByteType => TTypeId.TINYINT_TYPE