Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.types.ops

import org.apache.arrow.vector.types.pojo.ArrowType

import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{DataType, TimeType}

/**
* Optional client-side type operations for the Types Framework.
*
* This trait extends TypeApiOps with operations needed by client-facing infrastructure: Arrow
* conversion (ArrowUtils), Python interop (EvaluatePython), Hive formatting (HiveResult), and
* Thrift type mapping (SparkExecuteStatementOperation).
*
* Lives in sql/api so it's visible from sql/core and sql/hive-thriftserver.
*
* USAGE - integration points use ClientTypeOps(dt) which returns Option[ClientTypeOps]:
* {{{
* // Forward lookup (most files):
* ClientTypeOps(dt).map(_.toArrowType(timeZoneId)).getOrElse { ... }
*
* // Reverse lookup (ArrowUtils.fromArrowType):
* ClientTypeOps.fromArrowType(at).getOrElse { ... }
* }}}
*
* @see
* TimeTypeApiOps for a reference implementation
* @since 4.2.0
*/
trait ClientTypeOps { self: TypeApiOps =>

  // ==================== Utilities ====================

  /**
   * Null-safe conversion helper.
   *
   * Returns null for null input; otherwise applies the partial function, falling back to null
   * for values the partial function does not match.
   *
   * @param input
   *   the value to convert (may be null)
   * @param f
   *   partial function performing the actual conversion
   * @return
   *   the converted value, or null when the input is null or unmatched
   */
  protected def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = {
    if (input == null) {
      null
    } else {
      f.applyOrElse(input, (_: Any) => null)
    }
  }

  // ==================== Arrow Conversion ====================

  /**
   * Converts this DataType to its Arrow representation.
   *
   * Used by ArrowUtils.toArrowType.
   *
   * @param timeZoneId
   *   the session timezone (needed by some temporal types)
   * @return
   *   the corresponding ArrowType
   */
  def toArrowType(timeZoneId: String): ArrowType

  // ==================== Python Interop ====================

  /**
   * Returns true if values of this type need conversion when passed to/from Python.
   *
   * Used by EvaluatePython.needConversionInPython.
   */
  def needConversionInPython: Boolean

  /**
   * Creates a converter function for Python/Py4J interop.
   *
   * Used by EvaluatePython.makeFromJava. The returned function handles null-safe conversion of
   * Java/Py4J values to the internal Catalyst representation.
   *
   * @return
   *   a function that converts a Java value to the internal representation
   */
  def makeFromJava: Any => Any

  // ==================== Hive Formatting ====================

  /**
   * Formats an external-type value for Hive output.
   *
   * Used by HiveResult.toHiveString. The input is an external-type value (e.g.,
   * java.time.LocalTime for TimeType), NOT the internal representation. Most types override this
   * simple version. Types that need different formatting when nested (e.g., quoting) should
   * override the 2-param overload instead.
   */
  def formatExternal(value: Any): String

  /**
   * Formats an external-type value for Hive output with nesting context. Default delegates to
   * the simple version. Override if nesting affects formatting.
   */
  def formatExternal(value: Any, nested: Boolean): String = formatExternal(value)

  // ==================== Thrift Mapping ====================

  /**
   * Returns the Thrift TTypeId name for this type.
   *
   * Used by SparkExecuteStatementOperation.toTTypeId. Returns a String that maps to a TTypeId
   * enum value (e.g., "STRING_TYPE") since TTypeId is only available in the hive-thriftserver
   * module.
   *
   * @return
   *   TTypeId enum name (e.g., "STRING_TYPE")
   */
  def thriftTypeName: String
}

/**
 * Factory object for ClientTypeOps lookup.
 *
 * Delegates to TypeApiOps and narrows via collect to find implementations that mix in
 * ClientTypeOps.
 */
object ClientTypeOps {

  /**
   * Returns a ClientTypeOps instance for the given DataType, if available.
   *
   * Delegates to TypeApiOps and narrows: a type must implement TypeApiOps AND mix in
   * ClientTypeOps to be found here. No separate registration needed - the collect filter
   * handles incremental trait adoption automatically.
   *
   * @param dt
   *   the DataType to get operations for
   * @return
   *   Some(ClientTypeOps) if supported, None otherwise
   */
  def apply(dt: DataType): Option[ClientTypeOps] =
    TypeApiOps(dt).collect { case co: ClientTypeOps => co }

  /**
   * Reverse lookup: converts an Arrow type to a Spark DataType, if it belongs to a
   * framework-managed type.
   *
   * Used by ArrowUtils.fromArrowType. Returns None if the Arrow type doesn't correspond to any
   * framework-managed type, or the framework is disabled.
   *
   * @param at
   *   the ArrowType to convert
   * @return
   *   Some(DataType) if recognized, None otherwise
   */
  def fromArrowType(at: ArrowType): Option[DataType] = {
    import org.apache.arrow.vector.types.TimeUnit
    // Expression-structured instead of an early `return`, which is non-idiomatic in Scala.
    if (!SqlApiConf.get.typesFrameworkEnabled) {
      None
    } else {
      at match {
        // 64-bit nanosecond Arrow time corresponds to TimeType (see TimeTypeApiOps.toArrowType).
        case t: ArrowType.Time if t.getUnit == TimeUnit.NANOSECOND && t.getBitWidth == 64 =>
          Some(TimeType(TimeType.MICROS_PRECISION))
        // Add new framework types here
        case _ => None
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@

package org.apache.spark.sql.types.ops

import java.time.LocalTime

import org.apache.arrow.vector.types.TimeUnit
import org.apache.arrow.vector.types.pojo.ArrowType

import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.LocalTimeEncoder
import org.apache.spark.sql.catalyst.util.{FractionTimeFormatter, TimeFormatter}
import org.apache.spark.sql.catalyst.util.{FractionTimeFormatter, SparkDateTimeUtils, TimeFormatter}
import org.apache.spark.sql.types.{DataType, TimeType}

/**
Expand All @@ -29,14 +34,21 @@ import org.apache.spark.sql.types.{DataType, TimeType}
* - String formatting: uses FractionTimeFormatter for consistent output
* - Row encoding: uses LocalTimeEncoder for java.time.LocalTime
*
* Additionally, it implements ClientTypeOps for:
* - Arrow conversion (ArrowUtils)
* - JDBC mapping (JdbcUtils)
* - Python interop (EvaluatePython)
* - Hive formatting (HiveResult)
* - Thrift type mapping (SparkExecuteStatementOperation)
*
* RELATIONSHIP TO TimeTypeOps: TimeTypeOps (in catalyst package) extends this class to inherit
* client-side operations while adding server-side operations (physical type, literals, etc.).
*
* @param t
* The TimeType with precision information
* @since 4.2.0
*/
class TimeTypeApiOps(val t: TimeType) extends TypeApiOps {
class TimeTypeApiOps(val t: TimeType) extends TypeApiOps with ClientTypeOps {

// The TimeType instance these operations are bound to.
override def dataType: DataType = t

Expand All @@ -56,4 +68,26 @@ class TimeTypeApiOps(val t: TimeType) extends TypeApiOps {
// ==================== Row Encoding ====================

// Row encoding uses LocalTimeEncoder: the external representation is java.time.LocalTime.
override def getEncoder: AgnosticEncoder[_] = LocalTimeEncoder

// ==================== Client Type Operations (ClientTypeOps) ====================

override def toArrowType(timeZoneId: String): ArrowType = {
  // TIME maps to Arrow's nanosecond-unit time type with 64-bit storage; the timeZoneId
  // parameter is not referenced here.
  new ArrowType.Time(TimeUnit.NANOSECOND, java.lang.Long.SIZE)
}

// TIME values require conversion when crossing the Python boundary (see makeFromJava).
override def needConversionInPython: Boolean = true

override def makeFromJava: Any => Any = { obj: Any =>
  // Null-safely normalize Py4J-delivered values to the internal Long representation.
  nullSafeConvert(obj) {
    // Py4J serializes values between MIN_INT and MAX_INT as Ints rather than Longs.
    case i: Int => i.toLong
    case l: Long => l
  }
}

override def formatExternal(value: Any): String = {
  // The input is the external representation (java.time.LocalTime), not the internal one.
  val localTime = value.asInstanceOf[LocalTime]
  timeFormatter.format(SparkDateTimeUtils.localTimeToNanos(localTime))
}

// NOTE(review): TIME is surfaced to Thrift clients as STRING_TYPE — presumably because
// TTypeId has no dedicated TIME entry; confirm against the hive-thriftserver module.
override def thriftTypeName: String = "STRING_TYPE"
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.spark.SparkException
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.ops.ClientTypeOps
import org.apache.spark.util.ArrayImplicits._

private[sql] object ArrowUtils {
Expand All @@ -39,6 +40,14 @@ private[sql] object ArrowUtils {

/** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType =
  // Framework-managed types resolve via ClientTypeOps; everything else takes the default path.
  ClientTypeOps(dt).fold(toArrowTypeDefault(dt, timeZoneId, largeVarTypes)) { ops =>
    ops.toArrowType(timeZoneId)
  }

private def toArrowTypeDefault(
dt: DataType,
timeZoneId: String,
largeVarTypes: Boolean): ArrowType =
dt match {
case BooleanType => ArrowType.Bool.INSTANCE
case ByteType => new ArrowType.Int(8, true)
Expand Down Expand Up @@ -67,7 +76,10 @@ private[sql] object ArrowUtils {
throw ExecutionErrors.unsupportedDataTypeError(dt)
}

def fromArrowType(dt: ArrowType): DataType = dt match {
def fromArrowType(dt: ArrowType): DataType =
  // Try the framework's reverse lookup first; fall back to the built-in Arrow mapping.
  ClientTypeOps.fromArrowType(dt) match {
    case Some(sparkType) => sparkType
    case None => fromArrowTypeDefault(dt)
  }

private def fromArrowTypeDefault(dt: ArrowType): DataType = dt match {
case ArrowType.Bool.INSTANCE => BooleanType
case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, Bo
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{dataTypeForClass, externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
import org.apache.spark.sql.catalyst.types.ops.CatalystTypeOps
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils, STUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -289,6 +290,16 @@ object DeserializerBuildHelper {
* @param isTopLevel true if we are creating a deserializer for the top level value.
*/
private def createDeserializer(
    enc: AgnosticEncoder[_],
    path: Expression,
    walkedTypePath: WalkedTypePath,
    isTopLevel: Boolean = false): Expression = {
  // Framework dispatch runs before encoder-type checks in the default path. Safe because
  // framework types use dedicated leaf encoders, never migration shims or native primitives.
  val frameworkDeserializer =
    CatalystTypeOps(enc.dataType).map(_.createDeserializer(path, walkedTypePath, isTopLevel))
  frameworkDeserializer.getOrElse(createDeserializerDefault(enc, path, walkedTypePath, isTopLevel))
}

private def createDeserializerDefault(
enc: AgnosticEncoder[_],
path: Expression,
walkedTypePath: WalkedTypePath,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, Bo
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.types.ops.CatalystTypeOps
import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateTimeUtils, GenericArrayData, IntervalUtils, STUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -332,7 +333,15 @@ object SerializerBuildHelper {
* representation. The mapping between the external and internal representations is described
* by encoder `enc`.
*/
private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = {
  // Framework dispatch runs before encoder-type checks (AgnosticExpressionPathEncoder,
  // isNativeEncoder) in the default path. This is safe because framework types use dedicated
  // leaf encoders (e.g., LocalTimeEncoder), never migration shims or native primitives.
  CatalystTypeOps(enc.dataType) match {
    case Some(ops) => ops.createSerializer(input)
    case None => createSerializerDefault(enc, input)
  }
}

private def createSerializerDefault(
enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input)
case _ if isNativeEncoder(enc) => input
case BoxedBooleanEncoder => createSerializerForBoolean(input)
Expand Down
Loading