[SPARK-49960][SQL] Provide extension point for custom AgnosticEncoder serde #48477

Closed
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
@@ -270,6 +270,8 @@ object DeserializerBuildHelper {
enc: AgnosticEncoder[_],
path: Expression,
walkedTypePath: WalkedTypePath): Expression = enc match {
case ae: AgnosticExpressionPathEncoder[_] =>
ae.fromCatalyst(path)
case _ if isNativeEncoder(enc) =>
path
case _: BoxedLeafEncoder[_, _] =>
@@ -21,7 +21,7 @@ import scala.language.existentials

import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
@@ -306,6 +306,7 @@ object SerializerBuildHelper {
* by encoder `enc`.
*/
private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input)
case _ if isNativeEncoder(enc) => input
case BoxedBooleanEncoder => createSerializerForBoolean(input)
case BoxedByteEncoder => createSerializerForByte(input)
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.encoders

import scala.collection.Map

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder, VariantEncoder}
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -26,6 +27,29 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, VariantType, YearMonthIntervalType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}

/**
* :: DeveloperApi ::
* Extensible [[AgnosticEncoder]] providing conversion extension points over type T.
* @tparam T the external type handled by this encoder
*/
@DeveloperApi
trait AgnosticExpressionPathEncoder[T]
extends AgnosticEncoder[T] {
/**
* Converts from T to InternalRow.
* @param input the starting input path
* @return the Catalyst (serializer) expression
*/
def toCatalyst(input: Expression): Expression

/**
* Converts from InternalRow to T.
* @param inputPath path expression from InternalRow
* @return the deserializer expression
*/
def fromCatalyst(inputPath: Expression): Expression
}

hvanhovell (Contributor):
@chris-twiner can you give me an example of what exactly you are missing from the agnostic encoder framework? I'd rather solve this problem at that level than create an escape hatch to raw catalyst expressions. I am not saying that we should not do this, but I'd like to have a (small) discussion first.

My rationale for pushing for agnostic encoders is that I want to create a situation where the Classic and Connect Spark SQL interfaces are on par. Bespoke Catalyst encoders - sort of - defeat that.

chris-twiner (Contributor, Author), Oct 16, 2024:
@hvanhovell thanks for getting back to me. Per the JIRA, this is existing pre-4.0 functionality that is no longer fully working.
Frameless, for example, uses an extensible encoder derivation with injection/ADT support to provide type-safe usage at compile time. Quality uses injections to store a result ADT efficiently, and this SO question has a similar, often-occurring example that can be solved. Lastly, as the built-in encoders are not extensible, you can bump into their derivation limitations (java.util.Calendar, for example).

With regard to a fully unified interface implementation, that's understood, but this change is a minimal requirement to re-enable Frameless-style usage. I don't have any direct way to provide parity for Connect yet (although your unification work provides a clear basis); I track it under frameless #701, although to go further down that route I'd also need custom expression support in Connect (but that's off topic, and I know it's there to be used).
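
As an illustration of the java.util.Calendar case, here is a minimal sketch of an encoder built on the extension point added by this PR. CalendarCodec and CalendarEncoder are hypothetical names (they exist neither in Spark nor in Frameless); only AgnosticExpressionPathEncoder and Catalyst's StaticInvoke come from Spark itself.

import java.util.Calendar

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.encoders.AgnosticExpressionPathEncoder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.types.{DataType, LongType, ObjectType}

// Hypothetical value-level codec invoked from the generated code; not part of Spark.
object CalendarCodec {
  def toMillis(c: Calendar): Long = c.getTimeInMillis
  def fromMillis(ms: Long): Calendar = {
    val c = Calendar.getInstance()
    c.setTimeInMillis(ms)
    c
  }
}

// Sketch: encode java.util.Calendar as a plain LongType column (epoch millis).
object CalendarEncoder extends AgnosticExpressionPathEncoder[Calendar] {
  override def isPrimitive: Boolean = false
  override def dataType: DataType = LongType
  override def clsTag: ClassTag[Calendar] = implicitly[ClassTag[Calendar]]

  // Serializer path: Calendar object -> long
  override def toCatalyst(input: Expression): Expression =
    StaticInvoke(CalendarCodec.getClass, LongType, "toMillis",
      input :: Nil, returnNullable = false)

  // Deserializer path: long -> Calendar object
  override def fromCatalyst(inputPath: Expression): Expression =
    StaticInvoke(CalendarCodec.getClass, ObjectType(classOf[Calendar]), "fromMillis",
      inputPath :: Nil, returnNullable = false)
}

Such an encoder could then be wrapped in an ExpressionEncoder, or used as a field encoder inside a ProductEncoder, much as the new DatasetSuite test further down does with its custom String encoder.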

hvanhovell (Contributor):
Yeah, that makes sense. However, I do want to call out that this is mostly internal API; we do not guarantee any compatibility between (minor) releases. For that reason, historically, most Spark libraries have had to create per-Spark-version releases. The issue here IMO falls into that category.

I understand that this is a somewhat frustrating and impractical stance. I am open to having this interface for now, provided that in the future we can migrate towards AgnosticEncoders. The latter probably requires us to add additional encoders to the agnostic framework (e.g. an encoder for union types...).
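
Purely to make the union-type suggestion concrete, a hypothetical sketch of what such an agnostic encoder could look like; no UnionEncoder exists in Spark, and the SerializerBuildHelper/DeserializerBuildHelper matches would also need corresponding cases before anything like this could work.

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}

// Hypothetical agnostic encoder for a union/ADT type T; illustrative shape only.
case class UnionEncoder[T](
    alternatives: Seq[(Class[_], AgnosticEncoder[_])],
    override val clsTag: ClassTag[T]) extends AgnosticEncoder[T] {
  override def isPrimitive: Boolean = false
  override def isStruct: Boolean = true
  // A tag column selecting the active alternative, plus one nullable field per alternative.
  override def dataType: DataType = StructType(
    StructField("_tag", IntegerType, nullable = false) +:
      alternatives.zipWithIndex.map { case ((_, enc), i) =>
        StructField(s"_alt$i", enc.dataType, nullable = true)
      })
}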

chris-twiner (Contributor, Author):
Wrt the internal API: very much understood, it's the price paid for the functionality and performance gains. As I target Databricks as well there is yet more fun, hence shim's complicated version support.

hvanhovell (Contributor), Oct 16, 2024:
(deleted my previous comment) I thought GH had lost it....

hvanhovell (Contributor):
If Databricks compatibility is something you want, then Agnostic Encoders are your friend.

/**
* Helper class for Generating [[ExpressionEncoder]]s.
*/
55 changes: 52 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -35,9 +35,11 @@ import org.apache.spark.TestUtils.withListener
import org.apache.spark.internal.config.MAX_RESULT_SIZE
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoders, ExpressionEncoder, OuterScopes}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.BoxedIntEncoder
-import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.DeserializerBuildHelper.createDeserializerForString
+import org.apache.spark.sql.catalyst.SerializerBuildHelper.createSerializerForString
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoders, AgnosticExpressionPathEncoder, ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, ProductEncoder}
+import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, Expression, GenericRowWithSchema}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.trees.DataFrameQueryContext
import org.apache.spark.sql.catalyst.util.sideBySide
@@ -2803,6 +2805,21 @@ class DatasetSuite extends QueryTest
}
}

test("SPARK-49960: joinWith custom encoder") {
/*
test based on "joinWith class with primitive, toDF", but with a "custom" encoder.
Removing the use of AgnosticExpressionPathEncoder within SerializerBuildHelper
and DeserializerBuildHelper will trigger MatchErrors.
*/
val ds1 = Seq(1, 1, 2).toDS()
val ds2 = SparkSession.active.createDataset[ClassData](Seq(ClassData("a", 1),
ClassData("b", 2)))(CustomPathEncoder.custClassDataEnc)

checkAnswer(
ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", $"_2.b"),
Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil)
}

test("SPARK-49961: transform type should be consistent (classic)") {
val ds = Seq(1, 2).toDS()
val f: classic.Dataset[Int] => classic.Dataset[Int] =
@@ -2828,6 +2845,38 @@
}
}

/**
* SPARK-49960 - Mimic a custom encoder such as those provided by typelevel Frameless
*/
object CustomPathEncoder {

val realClassDataEnc: ProductEncoder[ClassData] =
Encoders.product[ClassData].asInstanceOf[ProductEncoder[ClassData]]

val custStringEnc: AgnosticExpressionPathEncoder[String] =
new AgnosticExpressionPathEncoder[String] {

override def toCatalyst(input: Expression): Expression =
createSerializerForString(input)

override def fromCatalyst(inputPath: Expression): Expression =
createDeserializerForString(inputPath, returnNullable = false)

override def isPrimitive: Boolean = false

override def dataType: DataType = StringType

override def clsTag: ClassTag[String] = implicitly[ClassTag[String]]

override def isStruct: Boolean = true
}

val custClassDataEnc: ProductEncoder[ClassData] = realClassDataEnc.copy(fields =
Seq(realClassDataEnc.fields.head.copy(enc = custStringEnc),
realClassDataEnc.fields.last)
)
}

class DatasetLargeResultCollectingSuite extends QueryTest
with SharedSparkSession {
