diff --git a/dev/diffs/3.4.3.diff b/dev/diffs/3.4.3.diff index 1c0ca867d6..ab9ac08886 100644 --- a/dev/diffs/3.4.3.diff +++ b/dev/diffs/3.4.3.diff @@ -193,6 +193,19 @@ index 41fd4de2a09..44cd244d3b0 100644 -- Test aggregate operator with codegen on and off. --CONFIG_DIM1 spark.sql.codegen.wholeStage=true --CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY +diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql +index 3a409eea348..38fed024c98 100644 +--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql ++++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql +@@ -69,6 +69,8 @@ SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1 + -- any evens + SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0'); + ++-- https://github.com/apache/datafusion-comet/issues/2215 ++--SET spark.comet.exec.enabled=false + -- [SPARK-28024] Incorrect value when out of range + SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i; + diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql index fac23b4a26f..2b73732c33f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql @@ -881,7 +894,7 @@ index b5b34922694..a72403780c4 100644 protected val baseResourcePath = { // use the same way as `SQLQueryTestSuite` to get the resource path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala -index 525d97e4998..5e04319dd97 100644 +index 525d97e4998..843f0472c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1508,7 +1508,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark @@ -894,7 +907,27 @@ index 525d97e4998..5e04319dd97 100644 AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect() } -@@ -4467,7 +4468,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark +@@ -4429,7 +4430,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark + } + + test("SPARK-39166: Query context of binary arithmetic should be serialized to executors" + +- " when WSCG is off") { ++ " when WSCG is off", ++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.ANSI_ENABLED.key -> "true") { + withTable("t") { +@@ -4450,7 +4452,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark + } + + test("SPARK-39175: Query context of Cast should be serialized to executors" + +- " when WSCG is off") { ++ " when WSCG is off", ++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.ANSI_ENABLED.key -> "true") { + withTable("t") { +@@ -4467,14 +4470,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark val msg = intercept[SparkException] { sql(query).collect() }.getMessage @@ -907,6 +940,15 @@ index 525d97e4998..5e04319dd97 100644 } } } + } + + test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal 
overflow error should " + +- "be serialized to executors when WSCG is off") { ++ "be serialized to executors when WSCG is off", ++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.ANSI_ENABLED.key -> "true") { + withTable("t") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 48ad10992c5..51d1ee65422 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala diff --git a/dev/diffs/3.5.6.diff b/dev/diffs/3.5.6.diff index f3909d074a..63f0d3eb0d 100644 --- a/dev/diffs/3.5.6.diff +++ b/dev/diffs/3.5.6.diff @@ -172,6 +172,19 @@ index 41fd4de2a09..44cd244d3b0 100644 -- Test aggregate operator with codegen on and off. --CONFIG_DIM1 spark.sql.codegen.wholeStage=true --CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY +diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql +index 3a409eea348..38fed024c98 100644 +--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql ++++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql +@@ -69,6 +69,8 @@ SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1 + -- any evens + SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0'); + ++-- https://github.com/apache/datafusion-comet/issues/2215 ++--SET spark.comet.exec.enabled=false + -- [SPARK-28024] Incorrect value when out of range + SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i; + diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql index fac23b4a26f..2b73732c33f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql @@ -866,7 +879,7 @@ index c26757c9cff..d55775f09d7 100644 protected val baseResourcePath = { // use the same way as `SQLQueryTestSuite` to get the resource path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala -index 793a0da6a86..e48e74091cb 100644 +index 793a0da6a86..181bfc16e4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1521,7 +1521,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark @@ -879,7 +892,27 @@ index 793a0da6a86..e48e74091cb 100644 AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect() } -@@ -4497,7 +4498,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark +@@ -4459,7 +4460,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark + } + + test("SPARK-39166: Query context of binary arithmetic should be serialized to executors" + +- " when WSCG is off") { ++ " when WSCG is off", ++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.ANSI_ENABLED.key -> "true") { + withTable("t") { +@@ -4480,7 +4482,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark + } + + test("SPARK-39175: Query 
context of Cast should be serialized to executors" + +- " when WSCG is off") { ++ " when WSCG is off", ++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.ANSI_ENABLED.key -> "true") { + withTable("t") { +@@ -4497,14 +4500,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark val msg = intercept[SparkException] { sql(query).collect() }.getMessage @@ -892,6 +925,15 @@ index 793a0da6a86..e48e74091cb 100644 } } } + } + + test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " + +- "be serialized to executors when WSCG is off") { ++ "be serialized to executors when WSCG is off", ++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.ANSI_ENABLED.key -> "true") { + withTable("t") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index fa1a64460fc..1d2e215d6a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 64efa31d52..517c037e93 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -62,8 +62,9 @@ use datafusion::{ prelude::SessionContext, }; use datafusion_comet_spark_expr::{ - create_comet_physical_fun, create_modulo_expr, create_negate_expr, BinaryOutputStyle, - BloomFilterAgg, BloomFilterMightContain, EvalMode, SparkHour, SparkMinute, SparkSecond, + create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, create_modulo_expr, + create_negate_expr, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, EvalMode, + SparkHour, SparkMinute, SparkSecond, }; use crate::execution::operators::ExecutionError::GeneralError; @@ -242,8 +243,6 @@ impl PhysicalPlanner { ) -> Result, ExecutionError> { match spark_expr.expr_struct.as_ref().unwrap() { ExprStruct::Add(expr) => { - // TODO respect ANSI eval mode - // https://github.com/apache/datafusion-comet/issues/536 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; self.create_binary_expr( expr.left.as_ref().unwrap(), @@ -255,8 +254,6 @@ impl PhysicalPlanner { ) } ExprStruct::Subtract(expr) => { - // TODO respect ANSI eval mode - // https://github.com/apache/datafusion-comet/issues/535 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; self.create_binary_expr( expr.left.as_ref().unwrap(), @@ -268,8 +265,6 @@ impl PhysicalPlanner { ) } ExprStruct::Multiply(expr) => { - // TODO respect ANSI eval mode - // https://github.com/apache/datafusion-comet/issues/534 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; self.create_binary_expr( expr.left.as_ref().unwrap(), @@ -281,8 +276,6 @@ impl PhysicalPlanner { ) } ExprStruct::Divide(expr) => { - // TODO respect ANSI eval mode - // https://github.com/apache/datafusion-comet/issues/533 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; self.create_binary_expr( expr.left.as_ref().unwrap(), @@ -1010,21 +1003,25 @@ impl PhysicalPlanner { } _ => { let data_type = return_type.map(to_arrow_datatype).unwrap(); - if eval_mode == EvalMode::Try && data_type.is_integer() { + if [EvalMode::Try, EvalMode::Ansi].contains(&eval_mode) + && (data_type.is_integer() + || (data_type.is_floating() && op == DataFusionOperator::Divide)) + { let op_str = 
match op { DataFusionOperator::Plus => "checked_add", DataFusionOperator::Minus => "checked_sub", DataFusionOperator::Multiply => "checked_mul", DataFusionOperator::Divide => "checked_div", _ => { - todo!("Operator yet to be implemented!"); + todo!("ANSI mode for Operator yet to be implemented!"); } }; - let fun_expr = create_comet_physical_fun( + let fun_expr = create_comet_physical_fun_with_eval_mode( op_str, data_type.clone(), &self.session_ctx.state(), None, + eval_mode, )?; Ok(Arc::new(ScalarFunctionExpr::new( op_str, diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 93a820ba9a..f96ddffce9 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -21,7 +21,7 @@ use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal, - spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, + spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace, }; use arrow::datatypes::DataType; @@ -64,6 +64,15 @@ macro_rules! make_comet_scalar_udf { ); Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func))) }}; + ($name:expr, $func:ident, $data_type:ident, $eval_mode:ident) => {{ + let scalar_func = CometScalarFunction::new( + $name.to_string(), + Signature::variadic_any(Volatility::Immutable), + $data_type.clone(), + Arc::new(move |args| $func(args, &$data_type, $eval_mode)), + ); + Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func))) + }}; } /// Create a physical scalar function. @@ -72,6 +81,23 @@ pub fn create_comet_physical_fun( data_type: DataType, registry: &dyn FunctionRegistry, fail_on_error: Option, +) -> Result, DataFusionError> { + create_comet_physical_fun_with_eval_mode( + fun_name, + data_type, + registry, + fail_on_error, + EvalMode::Legacy, + ) +} + +/// Create a physical scalar function with eval mode. 
Goal is to deprecate above function once all the operators have ANSI support +pub fn create_comet_physical_fun_with_eval_mode( + fun_name: &str, + data_type: DataType, + registry: &dyn FunctionRegistry, + fail_on_error: Option, + eval_mode: EvalMode, ) -> Result, DataFusionError> { match fun_name { "ceil" => { @@ -117,16 +143,16 @@ pub fn create_comet_physical_fun( ) } "checked_add" => { - make_comet_scalar_udf!("checked_add", checked_add, data_type) + make_comet_scalar_udf!("checked_add", checked_add, data_type, eval_mode) } "checked_sub" => { - make_comet_scalar_udf!("checked_sub", checked_sub, data_type) + make_comet_scalar_udf!("checked_sub", checked_sub, data_type, eval_mode) } "checked_mul" => { - make_comet_scalar_udf!("checked_mul", checked_mul, data_type) + make_comet_scalar_udf!("checked_mul", checked_mul, data_type, eval_mode) } "checked_div" => { - make_comet_scalar_udf!("checked_div", checked_div, data_type) + make_comet_scalar_udf!("checked_div", checked_div, data_type, eval_mode) } "murmur3_hash" => { let func = Arc::new(spark_murmur3_hash); diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index af5677a9bf..7bdc7ff515 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -64,7 +64,10 @@ pub use conditional_funcs::*; pub use conversion_funcs::*; pub use nondetermenistic_funcs::*; -pub use comet_scalar_funcs::{create_comet_physical_fun, register_all_comet_functions}; +pub use comet_scalar_funcs::{ + create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, + register_all_comet_functions, +}; pub use datetime_funcs::{ spark_date_add, spark_date_sub, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr, diff --git a/native/spark-expr/src/math_funcs/checked_arithmetic.rs b/native/spark-expr/src/math_funcs/checked_arithmetic.rs index 0312cdb0b0..bb4118f868 100644 --- a/native/spark-expr/src/math_funcs/checked_arithmetic.rs +++ b/native/spark-expr/src/math_funcs/checked_arithmetic.rs @@ -18,7 +18,11 @@ use arrow::array::{Array, ArrowNativeTypeOp, PrimitiveArray, PrimitiveBuilder}; use arrow::array::{ArrayRef, AsArray}; -use arrow::datatypes::{ArrowPrimitiveType, DataType, Int32Type, Int64Type}; +use crate::{divide_by_zero_error, EvalMode, SparkError}; +use arrow::datatypes::{ + ArrowPrimitiveType, DataType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, + Int64Type, Int8Type, +}; use datafusion::common::DataFusionError; use datafusion::physical_plan::ColumnarValue; use std::sync::Arc; @@ -27,6 +31,7 @@ pub fn try_arithmetic_kernel( left: &PrimitiveArray, right: &PrimitiveArray, op: &str, + is_ansi_mode: bool, ) -> Result where T: ArrowPrimitiveType, @@ -39,7 +44,19 @@ where if left.is_null(i) || right.is_null(i) { builder.append_null(); } else { - builder.append_option(left.value(i).add_checked(right.value(i)).ok()); + match left.value(i).add_checked(right.value(i)) { + Ok(v) => builder.append_value(v), + Err(_e) => { + if is_ansi_mode { + return Err(SparkError::ArithmeticOverflow { + from_type: String::from("integer"), + } + .into()); + } else { + builder.append_null(); + } + } + } } } } @@ -48,7 +65,19 @@ where if left.is_null(i) || right.is_null(i) { builder.append_null(); } else { - builder.append_option(left.value(i).sub_checked(right.value(i)).ok()); + match left.value(i).sub_checked(right.value(i)) { + Ok(v) => builder.append_value(v), + Err(_e) => { + if is_ansi_mode { + return Err(SparkError::ArithmeticOverflow { + from_type: String::from("integer"), + } + .into()); + } else { + 
builder.append_null(); + } + } + } } } } @@ -57,7 +86,19 @@ where if left.is_null(i) || right.is_null(i) { builder.append_null(); } else { - builder.append_option(left.value(i).mul_checked(right.value(i)).ok()); + match left.value(i).mul_checked(right.value(i)) { + Ok(v) => builder.append_value(v), + Err(_e) => { + if is_ansi_mode { + return Err(SparkError::ArithmeticOverflow { + from_type: String::from("integer"), + } + .into()); + } else { + builder.append_null(); + } + } + } } } } @@ -66,7 +107,23 @@ where if left.is_null(i) || right.is_null(i) { builder.append_null(); } else { - builder.append_option(left.value(i).div_checked(right.value(i)).ok()); + match left.value(i).div_checked(right.value(i)) { + Ok(v) => builder.append_value(v), + Err(_e) => { + if is_ansi_mode { + return if right.value(i).is_zero() { + Err(divide_by_zero_error().into()) + } else { + return Err(SparkError::ArithmeticOverflow { + from_type: String::from("integer"), + } + .into()); + }; + } else { + builder.append_null(); + } + } + } } } } @@ -84,39 +141,55 @@ where pub fn checked_add( args: &[ColumnarValue], data_type: &DataType, + eval_mode: EvalMode, ) -> Result { - checked_arithmetic_internal(args, data_type, "checked_add") + checked_arithmetic_internal(args, data_type, "checked_add", eval_mode) } pub fn checked_sub( args: &[ColumnarValue], data_type: &DataType, + eval_mode: EvalMode, ) -> Result { - checked_arithmetic_internal(args, data_type, "checked_sub") + checked_arithmetic_internal(args, data_type, "checked_sub", eval_mode) } pub fn checked_mul( args: &[ColumnarValue], data_type: &DataType, + eval_mode: EvalMode, ) -> Result { - checked_arithmetic_internal(args, data_type, "checked_mul") + checked_arithmetic_internal(args, data_type, "checked_mul", eval_mode) } pub fn checked_div( args: &[ColumnarValue], data_type: &DataType, + eval_mode: EvalMode, ) -> Result { - checked_arithmetic_internal(args, data_type, "checked_div") + checked_arithmetic_internal(args, data_type, "checked_div", eval_mode) } fn checked_arithmetic_internal( args: &[ColumnarValue], data_type: &DataType, op: &str, + eval_mode: EvalMode, ) -> Result { let left = &args[0]; let right = &args[1]; + let is_ansi_mode = match eval_mode { + EvalMode::Try => false, + EvalMode::Ansi => true, + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported mode : {:?}", + eval_mode + ))) + } + }; + let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (left, right) { (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)), (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => { @@ -128,17 +201,50 @@ fn checked_arithmetic_internal( (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?), }; - // Rust only supports checked_arithmetic on Int32 and Int64 + // Rust only supports checked_arithmetic on numeric types let result_array = match data_type { + DataType::Int8 => try_arithmetic_kernel::( + left_arr.as_primitive::(), + right_arr.as_primitive::(), + op, + is_ansi_mode, + ), + DataType::Int16 => try_arithmetic_kernel::( + left_arr.as_primitive::(), + right_arr.as_primitive::(), + op, + is_ansi_mode, + ), DataType::Int32 => try_arithmetic_kernel::( left_arr.as_primitive::(), right_arr.as_primitive::(), op, + is_ansi_mode, ), DataType::Int64 => try_arithmetic_kernel::( left_arr.as_primitive::(), right_arr.as_primitive::(), op, + is_ansi_mode, + ), + // Spark always casts division operands to floats + DataType::Float16 if (op == "checked_div") => try_arithmetic_kernel::( + 
left_arr.as_primitive::(), + right_arr.as_primitive::(), + op, + is_ansi_mode, + ), + DataType::Float32 if (op == "checked_div") => try_arithmetic_kernel::( + left_arr.as_primitive::(), + right_arr.as_primitive::(), + op, + is_ansi_mode, + ), + DataType::Float64 if (op == "checked_div") => try_arithmetic_kernel::( + left_arr.as_primitive::(), + right_arr.as_primitive::(), + op, + is_ansi_mode, ), _ => Err(DataFusionError::Internal(format!( "Unsupported data type: {:?}", diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala index 0f1eeb758a..4507dc1073 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala @@ -87,14 +87,6 @@ trait MathBase { object CometAdd extends CometExpressionSerde[Add] with MathBase { - override def getSupportLevel(expr: Add): SupportLevel = { - if (expr.evalMode == EvalMode.ANSI) { - Incompatible(Some("ANSI mode is not supported")) - } else { - Compatible(None) - } - } - override def convert( expr: Add, inputs: Seq[Attribute], @@ -117,14 +109,6 @@ object CometAdd extends CometExpressionSerde[Add] with MathBase { object CometSubtract extends CometExpressionSerde[Subtract] with MathBase { - override def getSupportLevel(expr: Subtract): SupportLevel = { - if (expr.evalMode == EvalMode.ANSI) { - Incompatible(Some("ANSI mode is not supported")) - } else { - Compatible(None) - } - } - override def convert( expr: Subtract, inputs: Seq[Attribute], @@ -147,14 +131,6 @@ object CometSubtract extends CometExpressionSerde[Subtract] with MathBase { object CometMultiply extends CometExpressionSerde[Multiply] with MathBase { - override def getSupportLevel(expr: Multiply): SupportLevel = { - if (expr.evalMode == EvalMode.ANSI) { - Incompatible(Some("ANSI mode is not supported")) - } else { - Compatible(None) - } - } - override def convert( expr: Multiply, inputs: Seq[Attribute], @@ -177,14 +153,6 @@ object CometMultiply extends CometExpressionSerde[Multiply] with MathBase { object CometDivide extends CometExpressionSerde[Divide] with MathBase { - override def getSupportLevel(expr: Divide): SupportLevel = { - if (expr.evalMode == EvalMode.ANSI) { - Incompatible(Some("ANSI mode is not supported")) - } else { - Compatible(None) - } - } - override def convert( expr: Divide, inputs: Seq[Attribute], @@ -192,7 +160,8 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase { // Datafusion now throws an exception for dividing by zero // See https://github.com/apache/arrow-datafusion/pull/6792 // For now, use NullIf to swap zeros with nulls. - val rightExpr = nullIfWhenPrimitive(expr.right) + val rightExpr = + if (expr.evalMode != EvalMode.ANSI) nullIfWhenPrimitive(expr.right) else expr.right if (!supportedDataType(expr.left.dataType)) { withInfo(expr, s"Unsupported datatype ${expr.left.dataType}") return None diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 600f4e45b5..daf0e45cc8 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -55,6 +55,11 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + val ARITHMETIC_OVERFLOW_EXCEPTION_MSG = + """org.apache.comet.CometNativeException: [ARITHMETIC_OVERFLOW] integer overflow. 
If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error""" + val DIVIDE_BY_ZERO_EXCEPTION_MSG = + """org.apache.comet.CometNativeException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead""" + test("compare true/false to negative zero") { Seq(false, true).foreach { dictionary => withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { @@ -2864,6 +2869,107 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("ANSI support for add") { + val data = Seq((Integer.MAX_VALUE, 1), (Integer.MIN_VALUE, -1)) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + withParquetTable(data, "tbl") { + val res = spark.sql(""" + |SELECT + | _1 + _2 + | from tbl + | """.stripMargin) + + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG)) + assert(sparkExc.getMessage.contains("overflow")) + case _ => fail("Exception should be thrown") + } + } + } + } + + test("ANSI support for subtract") { + val data = Seq((Integer.MIN_VALUE, 1)) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + withParquetTable(data, "tbl") { + val res = spark.sql(""" + |SELECT + | _1 - _2 + | from tbl + | """.stripMargin) + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG)) + assert(sparkExc.getMessage.contains("overflow")) + case _ => fail("Exception should be thrown") + } + } + } + } + + test("ANSI support for multiply") { + val data = Seq((Integer.MAX_VALUE, 10)) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + withParquetTable(data, "tbl") { + val res = spark.sql(""" + |SELECT + | _1 * _2 + | from tbl + | """.stripMargin) + + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG)) + assert(sparkExc.getMessage.contains("overflow")) + case _ => fail("Exception should be thrown") + } + } + } + } + + test("ANSI support for divide (division by zero)") { + // TODO : Support ANSI mode in Integral divide - + val data = Seq((Integer.MIN_VALUE, 0)) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + withParquetTable(data, "tbl") { + val res = spark.sql(""" + |SELECT + | _1 / _2 + | from tbl + | """.stripMargin) + + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG)) + assert(sparkExc.getMessage.contains("Division by zero")) + case _ => fail("Exception should be thrown") + } + } + } + } + + test("ANSI support for divide (division by zero) float division") { + // TODO : Support ANSI mode in Integral divide - + val data = Seq((Float.MinPositiveValue, 0.0)) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + withParquetTable(data, "tbl") { + val res = spark.sql(""" + |SELECT + | _1 / _2 + | from tbl + | """.stripMargin) + + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG)) + assert(sparkExc.getMessage.contains("Division by zero")) + case _ => fail("Exception should be thrown") + } + } + } + } + test("test integral divide overflow for decimal") { if (isSpark40Plus) { Seq(true, false)