diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala index 2795911da3..042fd9ced3 100644 --- a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala +++ b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala @@ -261,7 +261,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim { val subExprsCode = ctx.subexprFunctionsCode val (cls, setup, snippet) = CometBatchKernelCodegenOutput.emitOutputWriter(boundExpr.dataType, ev.value, ctx) - (cls, setup, defaultBody(boundExpr, ev, snippet, subExprsCode)) + (cls, setup, defaultBody(boundExpr, inputSchema, ev, snippet, subExprsCode)) } val typedFieldDecls = CometBatchKernelCodegenInput.emitInputFieldDecls(inputSchema) @@ -343,6 +343,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim { */ private def defaultBody( boundExpr: Expression, + inputSchema: Seq[ArrowColumnSpec], ev: ExprCode, writeSnippet: String, subExprsCode: String): String = { @@ -353,9 +354,17 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim { // make this incorrect (`coalesce(null, x)` is `x`); `allNullIntolerant` rejects those. val inputOrdinals = boundExpr.collect { case b: BoundReference => b.ordinal }.distinct + // Primitive Arrow vectors are wrapped in `CometPlainVector` at input-cast time, which + // exposes `isNullAt(int)` rather than the raw Arrow `isNull(int)`. Pick the right method + // per ordinal so the short-circuit compiles for timestamp / int / float columns too, + // not just VarChar / Decimal vectors that stay as raw Arrow types. + def nullCheckCall(ord: Int): String = { + val method = CometBatchKernelCodegenInput.nullCheckMethod(inputSchema(ord)) + s"this.col$ord.$method(i)" + } val nullCheck = if (inputOrdinals.isEmpty) "false" - else inputOrdinals.map(ord => s"this.col$ord.isNull(i)").mkString(" || ") + else inputOrdinals.map(nullCheckCall).mkString(" || ") s""" |if ($nullCheck) { | output.setNull(i); diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala index 79a2af6837..74e4881de0 100644 --- a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala +++ b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala @@ -404,8 +404,10 @@ private[codegen] object CometBatchKernelCodegenInput { /** * Java method name for the per-column null check. Primitive scalars wrapped in * [[CometPlainVector]] expose `isNullAt`; Arrow typed fields expose `isNull`. Same semantics. + * Used both by `emitTypedGetters` (for the kernel's `isNullAt` switch) and by + * `CometBatchKernelCodegen.defaultBody` (for the `NullIntolerant` short-circuit). */ - private def nullCheckMethod(spec: ArrowColumnSpec): String = spec match { + def nullCheckMethod(spec: ArrowColumnSpec): String = spec match { case sc: ScalarColumnSpec if wrapsInCometPlainVector(sc.vectorClass) => "isNullAt" case _ => "isNull" } diff --git a/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala b/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala index bf636f7221..010e3dd402 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala @@ -20,7 +20,7 @@ package org.apache.comet.serde import org.apache.spark.SparkEnv -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Literal, ScalaUDF} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Expression, Literal, ScalaUDF} import org.apache.spark.sql.types.BinaryType import org.apache.comet.CometConf @@ -45,15 +45,35 @@ import org.apache.comet.udf.codegen.CometScalaUDFCodegen * * Gated by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When disabled, plans containing a * `ScalaUDF` fall back to Spark for the enclosing operator. + * + * [[emitJvmCodegenDispatch]] exposes the same closure-serialize + dispatcher-proto path to other + * serdes that want to keep a built-in Spark expression inside the Comet pipeline when no native + * lowering is viable. See [[CometDateFormat]] for an example. */ object CometScalaUDF extends CometExpressionSerde[ScalaUDF] { - override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = + emitJvmCodegenDispatch(expr, inputs, binding) + + /** + * Bind `expr`, closure-serialize it, and emit a `JvmScalarUdf` proto routed through + * [[CometScalaUDFCodegen]] so that native execution evaluates the expression inside the + * Arrow-direct codegen dispatcher. The dispatcher will Janino-compile `expr.doGenCode` into a + * batch kernel on first invocation per task. + * + * Returns `None` (with `withInfo` tagging the reason) when the dispatcher is disabled via + * [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]] or when [[CometBatchKernelCodegen.canHandle]] + * refuses the expression tree. Callers should treat `None` as a clean Spark-fallback signal. + */ + def emitJvmCodegenDispatch( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { if (!CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) { withInfo( expr, - s"${CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key}=false; ScalaUDF has no native path " + - "so the plan falls back to Spark") + s"${CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key}=false; expression has no native " + + "path so the plan falls back to Spark") return None } diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala index b57b1e4e56..24ca862fb9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala +++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DateType, DoubleType, FloatType, IntegerType, LongType, StringType, TimestampNTZType, TimestampType} import org.apache.spark.unsafe.types.UTF8String +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.CometGetDateField.CometGetDateField @@ -593,17 +594,23 @@ object CometTruncTimestamp extends CometExpressionSerde[TruncTimestamp] { } /** - * Converts Spark DateFormatClass expression to DataFusion's to_char function. + * Converts Spark `DateFormatClass` to DataFusion's `to_char` when format and timezone are + * mappable, otherwise routes the expression through the Arrow-direct codegen dispatcher so that + * Spark's own `DateFormatClass.doGenCode` runs inside the Comet pipeline. * - * Spark uses Java SimpleDateFormat patterns while DataFusion uses strftime patterns. This - * implementation supports a whitelist of common format strings that can be reliably mapped - * between the two systems. + * Routing: + * - format is a literal in `supportedFormats` AND timezone is UTC -> native `to_char` + * - format is a literal in `supportedFormats` AND timezone is non-UTC, with the per-expression + * `allowIncompatible` flag set -> native `to_char` (results may differ from Spark) + * - all other cases -> JVM codegen dispatcher ([[CometScalaUDF.emitJvmCodegenDispatch]]), gated + * by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When that flag is disabled the operator + * falls back to Spark. */ object CometDateFormat extends CometExpressionSerde[DateFormatClass] { /** * Mapping from Spark SimpleDateFormat patterns to strftime patterns. Only formats in this map - * are supported. + * are supported by the native path. */ val supportedFormats: Map[String, String] = Map( // Full date formats @@ -637,66 +644,50 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] { // ISO formats "yyyy-MM-dd'T'HH:mm:ss" -> "%Y-%m-%dT%H:%M:%S") - override def getIncompatibleReasons(): Seq[String] = Seq( - "Non-UTC timezones may produce different results than Spark") + // Compatibility is decided inside `convert`: the native path covers a subset, and the codegen + // dispatcher covers everything else when enabled. Plan-time tagging happens via `withInfo` on + // the path that returns None. + override def getSupportLevel(expr: DateFormatClass): SupportLevel = Compatible() - override def getUnsupportedReasons(): Seq[String] = Seq( - "Only the following formats are supported:" + - supportedFormats.keys.toSeq.sorted - .map(k => s"`$k`") - .mkString("\n - ", "\n - ", "")) - - override def getSupportLevel(expr: DateFormatClass): SupportLevel = { - // Check timezone - only UTC is fully compatible - val timezone = expr.timeZoneId.getOrElse("UTC") - val isUtc = timezone == "UTC" || timezone == "Etc/UTC" - - expr.right match { - case Literal(fmt: UTF8String, _) => - val format = fmt.toString - if (supportedFormats.contains(format)) { - if (isUtc) { - Compatible() - } else { - Incompatible(Some(s"Non-UTC timezone '$timezone' may produce different results")) - } - } else { - Unsupported( - Some( - s"Format '$format' is not supported. Supported formats: " + - supportedFormats.keys.mkString(", "))) - } - case _ => - Unsupported(Some("Only literal format strings are supported")) - } - } + override def getCompatibleNotes(): Seq[String] = Seq( + "Format strings in a curated allow-list run natively via DataFusion's `to_char` for UTC " + + "sessions. Other format strings (including non-literal formats), as well as non-UTC " + + "sessions, route through Spark's own `DateFormatClass.doGenCode` via the Arrow-direct " + + "codegen dispatcher when `spark.comet.exec.scalaUDF.codegen.enabled=true`. When the " + + "codegen dispatcher is disabled (default) the operator falls back to Spark in those " + + "cases.") override def convert( expr: DateFormatClass, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - // Get the format string - must be a literal for us to map it - val strftimeFormat = expr.right match { - case Literal(fmt: UTF8String, _) => - supportedFormats.get(fmt.toString) + val timezone = expr.timeZoneId.getOrElse("UTC") + val isUtc = timezone == "UTC" || timezone == "Etc/UTC" + + val nativeFormat: Option[String] = expr.right match { + case Literal(fmt: UTF8String, _) => supportedFormats.get(fmt.toString) case _ => None } - strftimeFormat match { - case Some(format) => - val childExpr = exprToProtoInternal(expr.left, inputs, binding) - val formatExpr = exprToProtoInternal(Literal(format), inputs, binding) - - val optExpr = scalarFunctionExprToProtoWithReturnType( - "to_char", - StringType, - false, - childExpr, - formatExpr) - optExprWithInfo(optExpr, expr, expr.left, expr.right) - case None => - withInfo(expr, expr.left, expr.right) - None + val canUseNative = nativeFormat.isDefined && { + isUtc || CometConf.isExprAllowIncompat(getExprConfigName(expr)) + } + + if (canUseNative) { + val childExpr = exprToProtoInternal(expr.left, inputs, binding) + val formatExpr = exprToProtoInternal(Literal(nativeFormat.get), inputs, binding) + val optExpr = scalarFunctionExprToProtoWithReturnType( + "to_char", + StringType, + false, + childExpr, + formatExpr) + optExprWithInfo(optExpr, expr, expr.left, expr.right) + } else { + // Hand the full `DateFormatClass` (with `timeZoneId` already stamped by `ResolveTimeZone`) + // to the codegen dispatcher. It closure-serializes the bound tree, so non-UTC timezones + // and non-whitelisted / non-literal format strings produce Spark-identical results. + CometScalaUDF.emitJvmCodegenDispatch(expr, inputs, binding) } } } diff --git a/spark/src/test/resources/sql-tests/expressions/datetime/date_format.sql b/spark/src/test/resources/sql-tests/expressions/datetime/date_format.sql index 09333f44d3..dec690cb6a 100644 --- a/spark/src/test/resources/sql-tests/expressions/datetime/date_format.sql +++ b/spark/src/test/resources/sql-tests/expressions/datetime/date_format.sql @@ -15,21 +15,27 @@ -- specific language governing permissions and limitations -- under the License. +-- Pin the session timezone so the test exercises the non-UTC path regardless of the JVM +-- default. Enable the codegen dispatcher so non-UTC and non-whitelisted formats stay inside +-- Comet via Spark's own DateFormatClass.doGenCode instead of falling back to Spark. +-- Config: spark.sql.session.timeZone=America/Los_Angeles +-- Config: spark.comet.exec.scalaUDF.codegen.enabled=true + statement CREATE TABLE test_date_format(ts timestamp) USING parquet statement INSERT INTO test_date_format VALUES (timestamp('2024-06-15 10:30:45')), (timestamp('1970-01-01 00:00:00')), (NULL) -query expect_fallback(Non-UTC timezone) +query SELECT date_format(ts, 'yyyy-MM-dd') FROM test_date_format -query expect_fallback(Non-UTC timezone) +query SELECT date_format(ts, 'HH:mm:ss') FROM test_date_format -query expect_fallback(Non-UTC timezone) +query SELECT date_format(ts, 'yyyy-MM-dd HH:mm:ss') FROM test_date_format -- literal arguments -query expect_fallback(Non-UTC timezone) +query SELECT date_format(timestamp('2024-06-15 10:30:45'), 'yyyy-MM-dd') diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala index 27a5830c6d..274b70bce1 100644 --- a/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala @@ -22,9 +22,10 @@ package org.apache.comet import org.scalatest.funsuite.AnyFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Coalesce, Concat, CreateArray, CreateMap, ElementAt, Expression, GetStructField, LeafExpression, Length, Literal, Nondeterministic, Rand, Size, Unevaluable, Upper} +import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Coalesce, Concat, CreateArray, CreateMap, DateFormatClass, ElementAt, Expression, GetStructField, LeafExpression, Length, Literal, Nondeterministic, Rand, Size, Unevaluable, Upper} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.codegen.CometBatchKernelCodegen import org.apache.comet.codegen.CometBatchKernelCodegen.{ArrayColumnSpec, ArrowColumnSpec, MapColumnSpec, ScalarColumnSpec, StructColumnSpec, StructFieldSpec} @@ -61,6 +62,26 @@ class CometCodegenSourceSuite extends AnyFunSuite { specs: ArrowColumnSpec*): String = CometBatchKernelCodegen.generateSource(expr, specs.toIndexedSeq).body + test("NullIntolerant short-circuit uses isNullAt for CometPlainVector-wrapped columns") { + // Primitive Arrow vectors (timestamp / int / float / ...) are wrapped in `CometPlainVector` + // at input-cast time. The short-circuit must call `isNullAt(i)`, not `isNull(i)`, otherwise + // Janino fails to compile the kernel with "method isNull not declared". Verified end-to-end + // by `CometTemporalExpressionSuite` date_format tests over `TimeStampMicroTZVector` inputs. + val tsVec = CometBatchKernelCodegen.vectorClassBySimpleName("TimeStampMicroTZVector") + val spec = ArrowColumnSpec(tsVec, nullable = true) + val expr = DateFormatClass( + BoundReference(0, TimestampType, nullable = true), + Literal(UTF8String.fromString("yyyy-MM-dd EEEE"), StringType), + Some("UTC")) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)).body + assert( + src.contains("if (this.col0.isNullAt(i))"), + s"expected short-circuit to use isNullAt for CometPlainVector-wrapped col0; got:\n$src") + assert( + !src.contains("if (this.col0.isNull(i))"), + s"expected no raw Arrow isNull on the CometPlainVector-wrapped col0; got:\n$src") + } + test("non-nullable column emits literal-false isNullAt case") { val expr = Length(BoundReference(0, StringType, nullable = false)) val src = gen(expr, nonNullableString) diff --git a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala index a8147089d9..20ad90a91c 100644 --- a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala @@ -214,26 +214,21 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH } test("date_format - timestamp_ntz input") { - // TimestampNTZ is timezone-independent, so date_format should produce the same - // formatted string regardless of session timezone. Comet currently only runs this - // natively for UTC; for non-UTC it falls back to Spark. We verify correctness - // (matching Spark's output) in all cases. + // TimestampNTZ is timezone-independent, so date_format must produce the same string + // regardless of session timezone. With the codegen dispatcher enabled, non-UTC sessions + // stay in Comet by running Spark's own `DateFormatClass.doGenCode` via the dispatcher. val r = new Random(42) val ntzSchema = StructType(Seq(StructField("ts_ntz", DataTypes.TimestampNTZType, true))) val ntzDF = FuzzDataGenerator.generateDataFrame(r, spark, ntzSchema, 100, DataGenOptions()) ntzDF.createOrReplaceTempView("ntz_tbl") val supportedFormats = CometDateFormat.supportedFormats.keys.toSeq.filterNot(_.contains("'")) - for (tz <- crossTimezones) { - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) { - for (format <- supportedFormats) { - if (tz == "UTC") { + withSQLConf(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "true") { + for (tz <- crossTimezones) { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) { + for (format <- supportedFormats) { checkSparkAnswerAndOperator( s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz") - } else { - // Non-UTC falls back to Spark but should still produce correct results - checkSparkAnswer( - s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz") } } } @@ -476,34 +471,62 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH } } - test("date_format unsupported format falls back to Spark") { + test("date_format unsupported format routes via codegen dispatcher") { createTimestampTestData.createOrReplaceTempView("tbl") - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { - // Unsupported format string + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC", + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "true") { + checkSparkAnswerAndOperator( + "SELECT c0, date_format(c0, 'yyyy-MM-dd EEEE') from tbl order by c0") + } + } + + test("date_format unsupported format falls back when codegen dispatcher disabled") { + createTimestampTestData.createOrReplaceTempView("tbl") + + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC", + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") { checkSparkAnswerAndFallbackReason( "SELECT c0, date_format(c0, 'yyyy-MM-dd EEEE') from tbl order by c0", - "Format 'yyyy-MM-dd EEEE' is not supported") + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key) } } - test("date_format with non-UTC timezone falls back to Spark") { + test("date_format with non-UTC timezone routes via codegen dispatcher") { createTimestampTestData.createOrReplaceTempView("tbl") val nonUtcTimezones = Seq("America/New_York", "America/Los_Angeles", "Europe/London", "Asia/Tokyo") for (tz <- nonUtcTimezones) { - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) { - // Non-UTC timezones should fall back to Spark as Incompatible + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz, + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "true") { + checkSparkAnswerAndOperator( + "SELECT c0, date_format(c0, 'yyyy-MM-dd HH:mm:ss') from tbl order by c0") + } + } + } + + test("date_format with non-UTC timezone falls back when codegen dispatcher disabled") { + createTimestampTestData.createOrReplaceTempView("tbl") + + val nonUtcTimezones = Seq("America/New_York", "Europe/London") + + for (tz <- nonUtcTimezones) { + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz, + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") { checkSparkAnswerAndFallbackReason( "SELECT c0, date_format(c0, 'yyyy-MM-dd HH:mm:ss') from tbl order by c0", - s"Non-UTC timezone '$tz' may produce different results") + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key) } } } - test("date_format with non-UTC timezone works when allowIncompatible is enabled") { + test("date_format with non-UTC timezone takes native path when allowIncompatible is enabled") { createTimestampTestData.createOrReplaceTempView("tbl") val nonUtcTimezones = Seq("America/New_York", "Europe/London", "Asia/Tokyo") @@ -511,10 +534,13 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH for (tz <- nonUtcTimezones) { withSQLConf( SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz, - "spark.comet.expr.DateFormatClass.allowIncompatible" -> "true") { - // With allowIncompatible enabled, Comet will execute the expression - // Results may differ from Spark but should not throw errors - checkSparkAnswer("SELECT c0, date_format(c0, 'yyyy-MM-dd') from tbl order by c0") + "spark.comet.expression.DateFormatClass.allowIncompatible" -> "true") { + // Native to_char results may diverge from Spark for non-UTC timezones (the reason the + // JVM UDF is the default), so we only check that execution stays inside Comet. ORDER BY + // is omitted to keep the plan free of AQEShuffleRead. + val df = sql("SELECT c0, date_format(c0, 'yyyy-MM-dd') from tbl") + df.collect() + checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan)) } } }