Commit d72a4b5

Improve ANSI fallback

1 parent 1b344de commit d72a4b5

File tree

13 files changed: +89 −49 lines changed

common/src/main/scala/org/apache/comet/CometConf.scala
Lines changed: 3 additions & 6 deletions

@@ -600,14 +600,11 @@ object CometConf extends ShimCometConf {
       .toSequence
       .createWithDefault(Seq("Range,InMemoryTableScan"))
 
-  val COMET_ANSI_MODE_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.ansi.enabled")
+  val COMET_IGNORE_ANSI_MODE: ConfigEntry[Boolean] = conf("spark.comet.ansi.ignore")
     .internal()
-    .doc(
-      "Comet does not respect ANSI mode in most cases and by default will not accelerate " +
-        "queries when ansi mode is enabled. Enable this setting to test Comet's experimental " +
-        "support for ANSI mode. This should not be used in production.")
+    .doc("Internal config to avoid falling back to Spark when ANSI is enabled. Used for testing.")
     .booleanConf
-    .createWithDefault(COMET_ANSI_MODE_ENABLED_DEFAULT)
+    .createWithDefault(false)
 
   val COMET_CASE_CONVERSION_ENABLED: ConfigEntry[Boolean] =
     conf("spark.comet.caseConversion.enabled")

common/src/main/spark-3.x/org/apache/comet/shims/ShimCometConf.scala
Lines changed: 0 additions & 1 deletion

@@ -21,5 +21,4 @@ package org.apache.comet.shims
 
 trait ShimCometConf {
   protected val COMET_SCHEMA_EVOLUTION_ENABLED_DEFAULT = false
-  protected val COMET_ANSI_MODE_ENABLED_DEFAULT = false
 }

common/src/main/spark-4.0/org/apache/comet/shims/ShimCometConf.scala
Lines changed: 0 additions & 1 deletion

@@ -21,5 +21,4 @@ package org.apache.comet.shims
 
 trait ShimCometConf {
   protected val COMET_SCHEMA_EVOLUTION_ENABLED_DEFAULT = true
-  protected val COMET_ANSI_MODE_ENABLED_DEFAULT = true
 }

dev/diffs/3.4.3.diff
Lines changed: 15 additions & 2 deletions

@@ -1,5 +1,5 @@
 diff --git a/pom.xml b/pom.xml
-index d3544881af1..5cc127f064d 100644
+index d3544881af1..9c174496a4b 100644
 --- a/pom.xml
 +++ b/pom.xml
 @@ -148,6 +148,8 @@
@@ -881,7 +881,7 @@ index b5b34922694..a72403780c4 100644
    protected val baseResourcePath = {
      // use the same way as `SQLQueryTestSuite` to get the resource path
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-index 525d97e4998..8a3e7457618 100644
+index 525d97e4998..5e04319dd97 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 @@ -1508,7 +1508,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
@@ -894,6 +894,19 @@ index 525d97e4998..8a3e7457618 100644
      AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
        sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
      }
+@@ -4467,7 +4468,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+       val msg = intercept[SparkException] {
+         sql(query).collect()
+       }.getMessage
+-      assert(msg.contains(query))
++      if (!isCometEnabled) {
++        // Comet's error message does not include the original SQL query
++        // https://github.com/apache/datafusion-comet/issues/2215
++        assert(msg.contains(query))
++      }
+     }
+   }
+ }
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
 index 48ad10992c5..51d1ee65422 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
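All three dev/diffs files in this commit (3.4.3 above, 3.5.6 and 4.0.0 below) apply the same guard to Spark's SQLQuerySuite. Unfolded from the nested diff, the patched test body reads as follows; `isCometEnabled` is the helper these diff files make available to Spark's test harness:

    // SQLQuerySuite after the patch is applied (indentation approximate).
    val msg = intercept[SparkException] {
      sql(query).collect()
    }.getMessage
    if (!isCometEnabled) {
      // Comet's error message does not include the original SQL query
      // https://github.com/apache/datafusion-comet/issues/2215
      assert(msg.contains(query))
    }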

dev/diffs/3.5.6.diff
Lines changed: 15 additions & 2 deletions

@@ -1,5 +1,5 @@
 diff --git a/pom.xml b/pom.xml
-index 68e2c422a24..fb9c2e88fac 100644
+index 68e2c422a24..d971894ffe6 100644
 --- a/pom.xml
 +++ b/pom.xml
 @@ -152,6 +152,8 @@
@@ -866,7 +866,7 @@ index c26757c9cff..d55775f09d7 100644
    protected val baseResourcePath = {
      // use the same way as `SQLQueryTestSuite` to get the resource path
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-index 793a0da6a86..6ccb9d62582 100644
+index 793a0da6a86..e48e74091cb 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 @@ -1521,7 +1521,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
@@ -879,6 +879,19 @@ index 793a0da6a86..6ccb9d62582 100644
      AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
        sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
      }
+@@ -4497,7 +4498,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+       val msg = intercept[SparkException] {
+         sql(query).collect()
+       }.getMessage
+-      assert(msg.contains(query))
++      if (!isCometEnabled) {
++        // Comet's error message does not include the original SQL query
++        // https://github.com/apache/datafusion-comet/issues/2215
++        assert(msg.contains(query))
++      }
+     }
+   }
+ }
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
 index fa1a64460fc..1d2e215d6a3 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala

dev/diffs/4.0.0.diff
Lines changed: 16 additions & 3 deletions

@@ -1057,7 +1057,7 @@ index ad424b3a7cc..4ece0117a34 100644
    protected val baseResourcePath = {
      // use the same way as `SQLQueryTestSuite` to get the resource path
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-index b3fce19979e..345acb4811a 100644
+index b3fce19979e..67edf5eb91c 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 @@ -1524,7 +1524,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
@@ -1086,11 +1086,24 @@ index b3fce19979e..345acb4811a 100644
    test("SPARK-39175: Query context of Cast should be serialized to executors" +
 -    " when WSCG is off") {
 +    " when WSCG is off",
-+    IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/551")) {
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2218")) {
      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
        SQLConf.ANSI_ENABLED.key -> "true") {
        withTable("t") {
-@@ -4497,7 +4500,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+@@ -4490,14 +4493,20 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+       assert(ex.isInstanceOf[SparkNumberFormatException] ||
+         ex.isInstanceOf[SparkDateTimeException] ||
+         ex.isInstanceOf[SparkRuntimeException])
+-      assert(ex.getMessage.contains(query))
++
++      if (!isCometEnabled) {
++        // Comet's error message does not include the original SQL query
++        // https://github.com/apache/datafusion-comet/issues/2215
++        assert(ex.getMessage.contains(query))
++      }
+     }
+   }
+ }
  }
 
   test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " +

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
Lines changed: 6 additions & 2 deletions

@@ -120,7 +120,11 @@ object CometCast {
       Compatible()
     case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
         DataTypes.LongType =>
-      Compatible()
+      if (evalMode == CometEvalMode.ANSI) {
+        Incompatible(Some("ANSI mode not supported"))
+      } else {
+        Compatible()
+      }
     case DataTypes.BinaryType =>
       Compatible()
     case DataTypes.FloatType | DataTypes.DoubleType =>
@@ -139,7 +143,7 @@
       Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
     case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") =>
       Incompatible(Some(s"Cast will use UTC instead of $timeZoneId"))
-    case DataTypes.TimestampType if evalMode == "ANSI" =>
+    case DataTypes.TimestampType if evalMode == CometEvalMode.ANSI =>
       Incompatible(Some("ANSI mode not supported"))
     case DataTypes.TimestampType =>
       // https://github.com/apache/datafusion-comet/issues/328
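The second hunk here fixes a latent bug rather than style: `evalMode == "ANSI"` compared a `CometEvalMode` value against a `String`, which Scala's universal equality accepts but which can never be true, so the timestamp-cast ANSI guard was dead code. A two-line illustration (assuming `CometEvalMode` is an enum-style value, per the identifiers in the diff):

    CometEvalMode.ANSI == "ANSI"      // compiles, but always false: enum value vs. String
    evalMode == CometEvalMode.ANSI    // the fixed comparison can actually match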

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
Lines changed: 1 addition & 14 deletions

@@ -39,7 +39,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}
 
 import org.apache.comet.{CometConf, ExtendedExplainInfo}
-import org.apache.comet.CometConf.{COMET_ANSI_MODE_ENABLED, COMET_EXEC_SHUFFLE_ENABLED}
+import org.apache.comet.CometConf.COMET_EXEC_SHUFFLE_ENABLED
 import org.apache.comet.CometSparkSessionExtensions._
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
@@ -605,19 +605,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   }
 
   private def _apply(plan: SparkPlan): SparkPlan = {
-    // DataFusion doesn't have ANSI mode. For now we just disable CometExec if ANSI mode is
-    // enabled.
-    if (isANSIEnabled(conf)) {
-      if (COMET_ANSI_MODE_ENABLED.get()) {
-        if (!isSpark40Plus) {
-          logWarning("Using Comet's experimental support for ANSI mode.")
-        }
-      } else {
-        logInfo("Comet extension disabled for ANSI mode")
-        return plan
-      }
-    }
-
     // We shouldn't transform Spark query plan if Comet is not loaded.
     if (!isCometLoaded(conf)) return plan
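This removal is the heart of the commit: instead of disabling Comet for the entire plan when ANSI mode is on, unsupported expressions are now rejected individually in the serde layer, so the rest of an ANSI query can still be accelerated. A hedged sketch of observing the new behavior; `spark.comet.explainFallback.enabled` is assumed to be Comet's fallback-logging switch, and the message wording is approximate:

    // Only expressions that record withInfo(..., "ANSI mode not supported")
    // fall back; everything else in the plan remains eligible for Comet.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.conf.set("spark.comet.explainFallback.enabled", "true")
    spark.sql("SELECT round(_1, 2) FROM tbl").collect()
    // expected log (approximate): "round is not supported: ANSI mode not supported"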

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Lines changed: 5 additions & 0 deletions

@@ -1161,6 +1161,11 @@ object QueryPlanSerde extends Logging with CometExprShim {
       optExprWithInfo(optExpr, expr, left, right)
 
     case r: Round =>
+      if (r.ansiEnabled && !CometConf.COMET_IGNORE_ANSI_MODE.get()) {
+        withInfo(r, "ANSI mode not supported")
+        return None
+      }
+
       // _scale s a constant, copied from Spark's RoundBase because it is a protected val
       val scaleV: Any = r.scale.eval(EmptyRow)
       val _scale: Int = scaleV.asInstanceOf[Int]
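Two different guard styles appear in this commit because Spark models ANSI differently across expressions: `Round` carries a plain `ansiEnabled` Boolean, while the binary arithmetic operators expose an `EvalMode`. Side by side:

    r.ansiEnabled                     // Round: Boolean captured from SQLConf at construction
    expr.evalMode == EvalMode.ANSI    // Add, Subtract, ...: Spark 3.4+ EvalMode enum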

spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
Lines changed: 25 additions & 0 deletions

@@ -24,6 +24,7 @@ import scala.math.min
 import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Subtract}
 import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType}
 
+import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.expressions.CometEvalMode
 import org.apache.comet.serde.QueryPlanSerde.{castToProto, evalModeToProto, exprToProtoInternal, serializeDataType}
@@ -90,6 +91,10 @@ object CometAdd extends CometExpressionSerde[Add] with MathBase {
       expr: Add,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (expr.evalMode == EvalMode.ANSI && !CometConf.COMET_IGNORE_ANSI_MODE.get()) {
+      withInfo(expr, "ANSI mode not supported")
+      return None
+    }
     if (!supportedDataType(expr.left.dataType)) {
       withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
       return None
@@ -111,6 +116,10 @@ object CometSubtract extends CometExpressionSerde[Subtract] with MathBase {
       expr: Subtract,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (expr.evalMode == EvalMode.ANSI && !CometConf.COMET_IGNORE_ANSI_MODE.get()) {
+      withInfo(expr, "ANSI mode not supported")
+      return None
+    }
     if (!supportedDataType(expr.left.dataType)) {
       withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
       return None
@@ -132,6 +141,10 @@ object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
       expr: Multiply,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (expr.evalMode == EvalMode.ANSI && !CometConf.COMET_IGNORE_ANSI_MODE.get()) {
+      withInfo(expr, "ANSI mode not supported")
+      return None
+    }
     if (!supportedDataType(expr.left.dataType)) {
       withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
       return None
@@ -153,6 +166,10 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
       expr: Divide,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (expr.evalMode == EvalMode.ANSI && !CometConf.COMET_IGNORE_ANSI_MODE.get()) {
+      withInfo(expr, "ANSI mode not supported")
+      return None
+    }
     // Datafusion now throws an exception for dividing by zero
     // See https://github.com/apache/arrow-datafusion/pull/6792
     // For now, use NullIf to swap zeros with nulls.
@@ -178,6 +195,10 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {
       expr: IntegralDivide,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (expr.evalMode == EvalMode.ANSI && !CometConf.COMET_IGNORE_ANSI_MODE.get()) {
+      withInfo(expr, "ANSI mode not supported")
+      return None
+    }
     if (!supportedDataType(expr.left.dataType)) {
       withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
       return None
@@ -241,6 +262,10 @@ object CometRemainder extends CometExpressionSerde[Remainder] with MathBase {
       expr: Remainder,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (expr.evalMode == EvalMode.ANSI && !CometConf.COMET_IGNORE_ANSI_MODE.get()) {
+      withInfo(expr, "ANSI mode not supported")
+      return None
+    }
     if (!supportedDataType(expr.left.dataType)) {
       withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
       return None
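The same four-line guard is now repeated in all six serde objects. Since each guarded expression (`Add` through `Remainder`) extends Spark's `BinaryArithmetic`, which exposes `evalMode`, a follow-up could hoist the check into the shared `MathBase` trait — a hypothetical refactor sketched here, not part of this commit:

    // Hypothetical consolidation (not in this commit): one ANSI guard for all
    // six arithmetic serde objects, relying on BinaryArithmetic.evalMode.
    import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, EvalMode}

    trait MathBase {
      protected def ansiFallback(expr: BinaryArithmetic): Boolean =
        if (expr.evalMode == EvalMode.ANSI && !CometConf.COMET_IGNORE_ANSI_MODE.get()) {
          withInfo(expr, "ANSI mode not supported") // surfaces in fallback explain output
          true
        } else {
          false
        }
    }

    // usage inside each convert(): if (ansiFallback(expr)) return None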
