Skip to content

Commit e77998a

Browse files
authored
fix: checkSparkMaybeThrows should compare Spark and Comet results in success case (apache#2728)
1 parent 4cfceb7 commit e77998a

File tree

6 files changed

+26
-24
lines changed

6 files changed

+26
-24
lines changed

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
116116

117117
(fromType, toType) match {
118118
case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
119+
case (dt: ArrayType, DataTypes.StringType) if dt.elementType == DataTypes.BinaryType =>
120+
Incompatible()
119121
case (dt: ArrayType, DataTypes.StringType) =>
120122
isSupported(dt.elementType, DataTypes.StringType, timeZoneId, evalMode)
121123
case (dt: ArrayType, dt1: ArrayType) =>

spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class CometBitwiseExpressionSuite extends CometTestBase with AdaptiveSparkPlanHe
7676

7777
test("bitwise_get - throws exceptions") {
7878
def checkSparkAndCometEqualThrows(query: String): Unit = {
79-
checkSparkMaybeThrows(sql(query)) match {
79+
checkSparkAnswerMaybeThrows(sql(query)) match {
8080
case (Some(sparkExc), Some(cometExc)) =>
8181
assert(sparkExc.getMessage == cometExc.getMessage)
8282
case _ => fail("Exception should be thrown")

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast
3131
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
3232
import org.apache.spark.sql.functions.col
3333
import org.apache.spark.sql.internal.SQLConf
34-
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
34+
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
3535

3636
import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
3737
import org.apache.comet.expressions.{CometCast, CometEvalMode}
@@ -1035,7 +1035,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
10351035

10361036
test("cast between decimals with negative precision") {
10371037
// cast to negative scale
1038-
checkSparkMaybeThrows(
1038+
checkSparkAnswerMaybeThrows(
10391039
spark.sql("select a, cast(a as DECIMAL(10,-4)) from t order by a")) match {
10401040
case (expected, actual) =>
10411041
assert(expected.contains("PARSE_SYNTAX_ERROR") === actual.contains("PARSE_SYNTAX_ERROR"))
@@ -1062,11 +1062,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
10621062
IntegerType,
10631063
LongType,
10641064
ShortType,
1065-
// FloatType,
1066-
// DoubleType,
1065+
// FloatType,
1066+
// DoubleType,
1067+
// BinaryType
10671068
DecimalType(10, 2),
1068-
DecimalType(38, 18),
1069-
BinaryType).foreach { dt =>
1069+
DecimalType(38, 18)).foreach { dt =>
10701070
val input = generateArrays(100, dt)
10711071
castTest(input, StringType, hasIncompatibleType = hasIncompatibleType(input.schema))
10721072
}
@@ -1272,7 +1272,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12721272

12731273
// cast() should throw exception on invalid inputs when ansi mode is enabled
12741274
val df = data.withColumn("converted", col("a").cast(toType))
1275-
checkSparkMaybeThrows(df) match {
1275+
checkSparkAnswerMaybeThrows(df) match {
12761276
case (None, None) =>
12771277
// neither system threw an exception
12781278
case (None, Some(e)) =>

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
312312
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
313313
withParquetTable(path.toString, "tbl") {
314314
val (sparkErr, cometErr) =
315-
checkSparkMaybeThrows(sql(s"SELECT _20 + ${Int.MaxValue} FROM tbl"))
315+
checkSparkAnswerMaybeThrows(sql(s"SELECT _20 + ${Int.MaxValue} FROM tbl"))
316316
if (isSpark40Plus) {
317317
assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
318318
} else {
@@ -359,7 +359,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
359359
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
360360
withParquetTable(path.toString, "tbl") {
361361
val (sparkErr, cometErr) =
362-
checkSparkMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM tbl"))
362+
checkSparkAnswerMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM tbl"))
363363
if (isSpark40Plus) {
364364
assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
365365
assert(cometErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
@@ -2022,7 +2022,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
20222022
val expectedDivideByZeroError =
20232023
"[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead."
20242024

2025-
checkSparkMaybeThrows(sql(query)) match {
2025+
checkSparkAnswerMaybeThrows(sql(query)) match {
20262026
case (Some(sparkException), Some(cometException)) =>
20272027
assert(sparkException.getMessage.contains(expectedDivideByZeroError))
20282028
assert(cometException.getMessage.contains(expectedDivideByZeroError))
@@ -2174,7 +2174,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
21742174
}
21752175

21762176
def checkOverflow(query: String, dtype: String): Unit = {
2177-
checkSparkMaybeThrows(sql(query)) match {
2177+
checkSparkAnswerMaybeThrows(sql(query)) match {
21782178
case (Some(sparkException), Some(cometException)) =>
21792179
assert(sparkException.getMessage.contains(dtype + " overflow"))
21802180
assert(cometException.getMessage.contains(dtype + " overflow"))
@@ -2700,7 +2700,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
27002700

27012701
test("ListExtract") {
27022702
def assertBothThrow(df: DataFrame): Unit = {
2703-
checkSparkMaybeThrows(df) match {
2703+
checkSparkAnswerMaybeThrows(df) match {
27042704
case (Some(_), Some(_)) => ()
27052705
case (spark, comet) =>
27062706
fail(
@@ -2850,7 +2850,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
28502850
| from tbl
28512851
| """.stripMargin)
28522852

2853-
checkSparkMaybeThrows(res) match {
2853+
checkSparkAnswerMaybeThrows(res) match {
28542854
case (Some(sparkExc), Some(cometExc)) =>
28552855
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
28562856
assert(sparkExc.getMessage.contains("overflow"))
@@ -2869,7 +2869,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
28692869
| _1 - _2
28702870
| from tbl
28712871
| """.stripMargin)
2872-
checkSparkMaybeThrows(res) match {
2872+
checkSparkAnswerMaybeThrows(res) match {
28732873
case (Some(sparkExc), Some(cometExc)) =>
28742874
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
28752875
assert(sparkExc.getMessage.contains("overflow"))
@@ -2889,7 +2889,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
28892889
| from tbl
28902890
| """.stripMargin)
28912891

2892-
checkSparkMaybeThrows(res) match {
2892+
checkSparkAnswerMaybeThrows(res) match {
28932893
case (Some(sparkExc), Some(cometExc)) =>
28942894
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
28952895
assert(sparkExc.getMessage.contains("overflow"))
@@ -2909,7 +2909,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
29092909
| from tbl
29102910
| """.stripMargin)
29112911

2912-
checkSparkMaybeThrows(res) match {
2912+
checkSparkAnswerMaybeThrows(res) match {
29132913
case (Some(sparkExc), Some(cometExc)) =>
29142914
assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
29152915
assert(sparkExc.getMessage.contains("Division by zero"))
@@ -2929,7 +2929,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
29292929
| from tbl
29302930
| """.stripMargin)
29312931

2932-
checkSparkMaybeThrows(res) match {
2932+
checkSparkAnswerMaybeThrows(res) match {
29332933
case (Some(sparkExc), Some(cometExc)) =>
29342934
assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
29352935
assert(sparkExc.getMessage.contains("Division by zero"))
@@ -2950,7 +2950,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
29502950
| from tbl
29512951
| """.stripMargin)
29522952

2953-
checkSparkMaybeThrows(res) match {
2953+
checkSparkAnswerMaybeThrows(res) match {
29542954
case (Some(sparkException), Some(cometException)) =>
29552955
assert(sparkException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
29562956
assert(cometException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
@@ -2985,7 +2985,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
29852985
Seq(true, false).foreach { ansi =>
29862986
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi.toString) {
29872987
val res = spark.sql(s"SELECT round(_1, $scale) from tbl")
2988-
checkSparkMaybeThrows(res) match {
2988+
checkSparkAnswerMaybeThrows(res) match {
29892989
case (Some(sparkException), Some(cometException)) =>
29902990
assert(sparkException.getMessage.contains("ARITHMETIC_OVERFLOW"))
29912991
assert(cometException.getMessage.contains("ARITHMETIC_OVERFLOW"))

spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelpe
5656
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
5757
for (field <- df.schema.fields) {
5858
val col = field.name
59-
checkSparkMaybeThrows(sql(s"SELECT $col, abs($col) FROM tbl ORDER BY $col")) match {
59+
checkSparkAnswerMaybeThrows(sql(s"SELECT $col, abs($col) FROM tbl ORDER BY $col")) match {
6060
case (Some(sparkExc), Some(cometExc)) =>
6161
val cometErrorPattern =
6262
""".+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ abstract class CometTestBase
306306
* This method does not check that Comet replaced any operators or that the results match in the
307307
* case where the query is successful against both Spark and Comet.
308308
*/
309-
protected def checkSparkMaybeThrows(
309+
protected def checkSparkAnswerMaybeThrows(
310310
df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
311311
var expected: Try[Array[Row]] = null
312312
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
@@ -316,8 +316,8 @@ abstract class CometTestBase
316316

317317
(expected, actual) match {
318318
case (Success(_), Success(_)) =>
319-
// TODO compare results and confirm that they match
320-
// https://github.com/apache/datafusion-comet/issues/2657
319+
// compare results and confirm that they match
320+
checkSparkAnswer(df)
321321
(None, None)
322322
case _ =>
323323
(expected.failed.toOption, actual.failed.toOption)

0 commit comments

Comments (0)