@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast
 import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
 import org.apache.spark.sql.comet.CometHashAggregateExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.functions.{avg, count_distinct, sum}
+import org.apache.spark.sql.functions.{avg, col, count_distinct, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
 
@@ -1471,6 +1471,151 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
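+  // A SUM over an all-NULL decimal column is NULL regardless of the ANSI setting; the
+  // following tests pin that down for sum and try_sum, with and without grouping.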
1474+ test(" ANSI support for decimal sum - null test" ) {
1475+ Seq (true , false ).foreach { ansiEnabled =>
1476+ withSQLConf(SQLConf .ANSI_ENABLED .key -> ansiEnabled.toString) {
1477+ withParquetTable(
1478+ Seq (
1479+ (null .asInstanceOf [java.math.BigDecimal ], " a" ),
1480+ (null .asInstanceOf [java.math.BigDecimal ], " b" )),
1481+ " null_tbl" ) {
1482+ val res = sql(" SELECT sum(_1) FROM null_tbl" )
1483+ checkSparkAnswerAndOperator(res)
1484+ assert(res.collect() === Array (Row (null )))
1485+ }
1486+ }
1487+ }
1488+ }
1489+
1490+ test(" ANSI support for try_sum decimal - null test" ) {
1491+ Seq (true , false ).foreach { ansiEnabled =>
1492+ withSQLConf(SQLConf .ANSI_ENABLED .key -> ansiEnabled.toString) {
1493+ withParquetTable(
1494+ Seq (
1495+ (null .asInstanceOf [java.math.BigDecimal ], " a" ),
1496+ (null .asInstanceOf [java.math.BigDecimal ], " b" )),
1497+ " null_tbl" ) {
1498+ val res = sql(" SELECT try_sum(_1) FROM null_tbl" )
1499+ checkSparkAnswerAndOperator(res)
1500+ assert(res.collect() === Array (Row (null )))
1501+ }
1502+ }
1503+ }
1504+ }
1505+
1506+ test(" ANSI support for decimal sum - null test (group by)" ) {
1507+ Seq (true , false ).foreach { ansiEnabled =>
1508+ withSQLConf(SQLConf .ANSI_ENABLED .key -> ansiEnabled.toString) {
1509+ withParquetTable(
1510+ Seq (
1511+ (null .asInstanceOf [java.math.BigDecimal ], " a" ),
1512+ (null .asInstanceOf [java.math.BigDecimal ], " a" ),
1513+ (null .asInstanceOf [java.math.BigDecimal ], " b" ),
1514+ (null .asInstanceOf [java.math.BigDecimal ], " b" ),
1515+ (null .asInstanceOf [java.math.BigDecimal ], " b" )),
1516+ " tbl" ) {
1517+ val res = sql(" SELECT _2, sum(_1) FROM tbl group by 1" )
1518+ checkSparkAnswerAndOperator(res)
1519+ assert(res.orderBy(col(" _2" )).collect() === Array (Row (" a" , null ), Row (" b" , null )))
1520+ }
1521+ }
1522+ }
1523+ }
1524+
1525+ test(" ANSI support for try_sum decimal - null test (group by)" ) {
1526+ Seq (true , false ).foreach { ansiEnabled =>
1527+ withSQLConf(SQLConf .ANSI_ENABLED .key -> ansiEnabled.toString) {
1528+ withParquetTable(
1529+ Seq (
1530+ (null .asInstanceOf [java.math.BigDecimal ], " a" ),
1531+ (null .asInstanceOf [java.math.BigDecimal ], " a" ),
1532+ (null .asInstanceOf [java.math.BigDecimal ], " b" ),
1533+ (null .asInstanceOf [java.math.BigDecimal ], " b" ),
1534+ (null .asInstanceOf [java.math.BigDecimal ], " b" )),
1535+ " tbl" ) {
1536+ val res = sql(" SELECT _2, try_sum(_1) FROM tbl group by 1" )
1537+ checkSparkAnswerAndOperator(res)
1538+ assert(res.orderBy(col(" _2" )).collect() === Array (Row (" a" , null ), Row (" b" , null )))
1539+ }
1540+ }
1541+ }
1542+ }
1543+
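+  // Builds 50 rows of a decimal at the top of the inferred decimal type's range, all in
+  // group 1, so that SUM over column _1 overflows the aggregate's result precision.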
+  protected def generateOverflowDecimalInputs: Seq[(java.math.BigDecimal, Int)] = {
+    val maxDec38_0 = new java.math.BigDecimal("99999999999999999999")
+    Seq.fill(50)((maxDec38_0, 1))
+  }
+
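+  // With ANSI enabled, Spark and Comet must both raise ARITHMETIC_OVERFLOW when the sum
+  // exceeds the result type's precision; with ANSI disabled, the overflowed sum is NULL
+  // and the two engines must still agree.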
1549+ test(" ANSI support - decimal SUM function" ) {
1550+ Seq (true , false ).foreach { ansiEnabled =>
1551+ withSQLConf(SQLConf .ANSI_ENABLED .key -> ansiEnabled.toString) {
1552+ withParquetTable(generateOverflowDecimalInputs, " tbl" ) {
1553+ val input = sql(" SELECT _1 FROM tbl" )
1554+ val res = sql(" SELECT SUM(_1) FROM tbl" )
1555+ if (ansiEnabled) {
1556+ checkSparkAnswerMaybeThrows(res) match {
1557+ case (Some (sparkExc), Some (cometExc)) =>
1558+ assert(sparkExc.getMessage.contains(" ARITHMETIC_OVERFLOW" ))
1559+ assert(cometExc.getMessage.contains(" ARITHMETIC_OVERFLOW" ))
1560+ case _ =>
1561+ fail(" Exception should be thrown for decimal overflow in ANSI mode" )
1562+ }
1563+ } else {
1564+ checkSparkAnswerAndOperator(res)
1565+ }
1566+ }
1567+ }
1568+ }
1569+ }
1570+
1571+ test(" ANSI support for decimal SUM - GROUP BY" ) {
1572+ Seq (true , false ).foreach { ansiEnabled =>
1573+ withSQLConf(SQLConf .ANSI_ENABLED .key -> ansiEnabled.toString) {
1574+ withParquetTable(generateOverflowDecimalInputs, " tbl" ) {
1575+ val res =
1576+ sql(" SELECT _2, SUM(_1) FROM tbl GROUP BY _2" ).repartition(2 )
1577+ if (ansiEnabled) {
1578+ checkSparkAnswerMaybeThrows(res) match {
1579+ case (Some (sparkExc), Some (cometExc)) =>
1580+ assert(sparkExc.getMessage.contains(" ARITHMETIC_OVERFLOW" ))
1581+ assert(cometExc.getMessage.contains(" ARITHMETIC_OVERFLOW" ))
1582+ case _ =>
1583+ fail(" Exception should be thrown for decimal overflow with GROUP BY in ANSI mode" )
1584+ }
1585+ } else {
1586+ checkSparkAnswerAndOperator(res)
1587+ }
1588+ }
1589+ }
1590+ }
1591+ }
1592+
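+  // try_sum never raises on overflow: an overflowed group yields NULL while other groups
+  // keep their sums, so the result must match Spark without throwing.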
1593+ test(" try_sum decimal overflow" ) {
1594+ withParquetTable(generateOverflowDecimalInputs, " tbl" ) {
1595+ val res = sql(" SELECT try_sum(_1) FROM tbl" )
1596+ checkSparkAnswerAndOperator(res)
1597+ }
1598+ }
1599+
1600+ test(" try_sum decimal overflow - with GROUP BY" ) {
1601+ withParquetTable(generateOverflowDecimalInputs, " tbl" ) {
1602+ val res = sql(" SELECT _2, try_sum(_1) FROM tbl GROUP BY _2" ).repartition(2 , col(" _2" ))
1603+ checkSparkAnswerAndOperator(res)
1604+ }
1605+ }
1606+
1607+ test(" try_sum decimal partial overflow - with GROUP BY" ) {
1608+ // Group 1 overflows, Group 2 succeeds
1609+ val data : Seq [(java.math.BigDecimal , Int )] = generateOverflowDecimalInputs ++ Seq (
1610+ (new java.math.BigDecimal (300 ), 2 ),
1611+ (new java.math.BigDecimal (200 ), 2 ))
1612+ withParquetTable(data, " tbl" ) {
1613+ val res = sql(" SELECT _2, try_sum(_1) FROM tbl GROUP BY _2" )
1614+ // Group 1 should be NULL, Group 2 should be 500
1615+ checkSparkAnswerAndOperator(res)
1616+ }
1617+ }
1618+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)