Commit dbe1e4c

support_ansi_sum_decimal_input
1 parent 5227115 commit dbe1e4c

File tree: 3 files changed, +150 -16 lines

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 3 additions & 13 deletions
@@ -29,7 +29,8 @@ import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType
 import org.apache.comet.CometConf
 import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
 import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
+import org.apache.comet.shims.CometEvalModeUtil
 
 object CometMin extends CometAggregateExpressionSerde[Min] {
 
@@ -212,17 +213,6 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
 
 object CometSum extends CometAggregateExpressionSerde[Sum] {
 
-  override def getSupportLevel(sum: Sum): SupportLevel = {
-    sum.evalMode match {
-      case EvalMode.ANSI =>
-        Incompatible(Some("ANSI mode is not supported"))
-      case EvalMode.TRY =>
-        Incompatible(Some("TRY mode is not supported"))
-      case _ =>
-        Compatible()
-    }
-  }
-
   override def convert(
       aggExpr: AggregateExpression,
       sum: Sum,
@@ -242,7 +232,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
     val builder = ExprOuterClass.Sum.newBuilder()
     builder.setChild(childExpr.get)
    builder.setDatatype(dataType.get)
-    builder.setFailOnError(sum.evalMode == EvalMode.ANSI)
+    builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(sum.evalMode)))
 
     Some(
       ExprOuterClass.AggExpr
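
The net effect of this serde change is that CometSum now forwards Spark's full eval mode to the native plan rather than a single fail-on-error flag. As a minimal, illustrative sketch (assuming Spark's standard EvalMode enumeration; the describe helper below is hypothetical and not part of Comet), the three modes correspond to the sum-overflow behaviours exercised by the new tests:

import org.apache.spark.sql.catalyst.expressions.EvalMode

object EvalModeBehaviour {
  // Hypothetical helper: summarises what each Spark eval mode implies for a
  // decimal SUM whose result exceeds the result precision.
  def describe(mode: EvalMode.Value): String = mode match {
    case EvalMode.ANSI => "overflow raises an ARITHMETIC_OVERFLOW error"
    case EvalMode.TRY => "overflow yields NULL (try_sum)"
    case EvalMode.LEGACY => "overflow yields NULL without raising an error"
  }
}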

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 146 additions & 1 deletion
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast
 import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
 import org.apache.spark.sql.comet.CometHashAggregateExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.functions.{avg, count_distinct, sum}
+import org.apache.spark.sql.functions.{avg, col, count_distinct, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
 
@@ -1471,6 +1471,151 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("ANSI support for decimal sum - null test") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "null_tbl") {
+          val res = sql("SELECT sum(_1) FROM null_tbl")
+          checkSparkAnswerAndOperator(res)
+          assert(res.collect() === Array(Row(null)))
+        }
+      }
+    }
+  }
+
+  test("ANSI support for try_sum decimal - null test") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "null_tbl") {
+          val res = sql("SELECT try_sum(_1) FROM null_tbl")
+          checkSparkAnswerAndOperator(res)
+          assert(res.collect() === Array(Row(null)))
+        }
+      }
+    }
+  }
+
+  test("ANSI support for decimal sum - null test (group by)") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "tbl") {
+          val res = sql("SELECT _2, sum(_1) FROM tbl group by 1")
+          checkSparkAnswerAndOperator(res)
+          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
+        }
+      }
+    }
+  }
+
+  test("ANSI support for try_sum decimal - null test (group by)") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "tbl") {
+          val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1")
+          checkSparkAnswerAndOperator(res)
+          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
+        }
+      }
+    }
+  }
+
+  protected def generateOverflowDecimalInputs: Seq[(java.math.BigDecimal, Int)] = {
+    val maxDec38_0 = new java.math.BigDecimal("99999999999999999999")
+    (1 to 50).flatMap(_ => Seq((maxDec38_0, 1)))
+  }
+
+  test("ANSI support - decimal SUM function") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(generateOverflowDecimalInputs, "tbl") {
+          val input = sql("SELECT _1 FROM tbl")
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          if (ansiEnabled) {
+            checkSparkAnswerMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ =>
+                fail("Exception should be thrown for decimal overflow in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+      }
+    }
+  }
+
+  test("ANSI support for decimal SUM - GROUP BY") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(generateOverflowDecimalInputs, "tbl") {
+          val res =
+            sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2)
+          if (ansiEnabled) {
+            checkSparkAnswerMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ =>
+                fail("Exception should be thrown for decimal overflow with GROUP BY in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+      }
+    }
+  }
+
+  test("try_sum decimal overflow") {
+    withParquetTable(generateOverflowDecimalInputs, "tbl") {
+      val res = sql("SELECT try_sum(_1) FROM tbl")
+      checkSparkAnswerAndOperator(res)
+    }
+  }
+
+  test("try_sum decimal overflow - with GROUP BY") {
+    withParquetTable(generateOverflowDecimalInputs, "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      checkSparkAnswerAndOperator(res)
+    }
+  }
+
+  test("try_sum decimal partial overflow - with GROUP BY") {
+    // Group 1 overflows, Group 2 succeeds
+    val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq(
+      (new java.math.BigDecimal(300), 2),
+      (new java.math.BigDecimal(200), 2))
+    withParquetTable(data, "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2")
+      // Group 1 should be NULL, Group 2 should be 500
+      checkSparkAnswerAndOperator(res)
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)
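
For reference, a self-contained sketch of the behaviour these tests pin down, runnable outside the suite (the local session, table name, and column names below are illustrative, not taken from the suite):

import org.apache.spark.sql.SparkSession

object AnsiDecimalSumDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ansi-sum-demo").getOrCreate()
    import spark.implicits._

    // Mirrors generateOverflowDecimalInputs: 50 copies of a 20-digit decimal,
    // large enough that their sum exceeds the result precision.
    val big = new java.math.BigDecimal("99999999999999999999")
    Seq.fill(50)((big, 1)).toDF("_1", "_2").createOrReplaceTempView("tbl")

    spark.conf.set("spark.sql.ansi.enabled", "true")
    // ANSI sum: the query below fails with an ARITHMETIC_OVERFLOW error,
    // so it is left commented out in this sketch.
    // spark.sql("SELECT SUM(_1) FROM tbl").show()

    // try_sum: the overflowing group collapses to NULL instead of failing.
    spark.sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").show()

    spark.stop()
  }
}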

spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala

Lines changed: 1 addition & 2 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
 import org.apache.spark.sql.TPCDSBase
 import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Average
 import org.apache.spark.sql.catalyst.util.resourceToString
 import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
@@ -226,7 +226,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
       CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
       // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support
       CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true",
-      CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true",
       // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64
      CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {
