
Commit 0ddd89f

feat: Support ANSI mode integral divide (#2421)
1 parent 9caeec1 commit 0ddd89f

6 files changed: +58 −26 lines changed


native/core/src/execution/planner.rs

Lines changed: 2 additions & 3 deletions
```diff
@@ -287,8 +287,6 @@ impl PhysicalPlanner {
                 )
             }
             ExprStruct::IntegralDivide(expr) => {
-                // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/533
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr_with_options(
                     expr.left.as_ref().unwrap(),
@@ -987,11 +985,12 @@ impl PhysicalPlanner {
             } else {
                 "decimal_div"
             };
-            let fun_expr = create_comet_physical_fun(
+            let fun_expr = create_comet_physical_fun_with_eval_mode(
                 func_name,
                 data_type.clone(),
                 &self.session_ctx.state(),
                 None,
+                eval_mode,
             )?;
             Ok(Arc::new(ScalarFunctionExpr::new(
                 func_name,
```
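This removes the TODO referencing issue #533: the eval mode serialized in the protobuf plan now reaches the scalar-function factory instead of being dropped. A minimal sketch of the types involved follows; the variant names mirror `datafusion_comet_spark_expr::EvalMode`, but the decoder is a hypothetical stand-in for `from_protobuf_eval_mode`, not the actual Comet code:

```rust
// Sketch only. Variant names mirror datafusion_comet_spark_expr::EvalMode;
// the protobuf tag values are assumptions for illustration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EvalMode {
    Legacy, // Spark default: divide-by-zero yields NULL
    Try,    // ANSI semantics, but errors degrade to NULL
    Ansi,   // raise a runtime error, as ANSI SQL requires
}

// Hypothetical stand-in for from_protobuf_eval_mode: map the i32 enum tag
// carried in the serialized expression to an EvalMode.
fn from_protobuf_eval_mode(value: i32) -> Result<EvalMode, String> {
    match value {
        0 => Ok(EvalMode::Legacy),
        1 => Ok(EvalMode::Try),
        2 => Ok(EvalMode::Ansi),
        other => Err(format!("unsupported eval mode: {other}")),
    }
}
```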

native/spark-expr/benches/decimal_div.rs

Lines changed: 3 additions & 1 deletion
```diff
@@ -20,7 +20,7 @@ use arrow::compute::cast;
 use arrow::datatypes::DataType;
 use criterion::{criterion_group, criterion_main, Criterion};
 use datafusion::physical_plan::ColumnarValue;
-use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div};
+use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div, EvalMode};
 use std::hint::black_box;
 use std::sync::Arc;

@@ -48,6 +48,7 @@ fn criterion_benchmark(c: &mut Criterion) {
             black_box(spark_decimal_div(
                 black_box(&args),
                 black_box(&DataType::Decimal128(10, 4)),
+                EvalMode::Legacy,
             ))
         })
     });
@@ -57,6 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) {
             black_box(spark_decimal_integral_div(
                 black_box(&args),
                 black_box(&DataType::Decimal128(10, 4)),
+                EvalMode::Legacy,
             ))
         })
     });
```
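Both benchmarks pin `EvalMode::Legacy` so they keep measuring the pre-ANSI code path. Outside Criterion, a call to the updated API looks roughly like this (a sketch that assumes the datafusion-comet spark-expr crate and arrow as dependencies):

```rust
// A sketch of calling the updated API outside the benchmark harness.
use std::sync::Arc;

use arrow::array::Decimal128Array;
use arrow::datatypes::DataType;
use datafusion::physical_plan::ColumnarValue;
use datafusion_comet_spark_expr::{spark_decimal_div, EvalMode};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 1.00 / 0.50 and 2.00 / 0.25, as Decimal128(10, 2) values.
    let left = Decimal128Array::from(vec![100_i128, 200])
        .with_precision_and_scale(10, 2)?;
    let right = Decimal128Array::from(vec![50_i128, 25])
        .with_precision_and_scale(10, 2)?;
    let args = [
        ColumnarValue::Array(Arc::new(left)),
        ColumnarValue::Array(Arc::new(right)),
    ];
    // EvalMode::Legacy preserves the pre-ANSI behavior; the benchmarks pin
    // it so their numbers stay comparable with earlier runs.
    let result = spark_decimal_div(&args, &DataType::Decimal128(10, 4), EvalMode::Legacy)?;
    println!("{result:?}");
    Ok(())
}
```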

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 3 additions & 2 deletions
```diff
@@ -137,13 +137,14 @@ pub fn create_comet_physical_fun_with_eval_mode(
             make_comet_scalar_udf!("unhex", func, without data_type)
         }
         "decimal_div" => {
-            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
+            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type, eval_mode)
         }
         "decimal_integral_div" => {
             make_comet_scalar_udf!(
                 "decimal_integral_div",
                 spark_decimal_integral_div,
-                data_type
+                data_type,
+                eval_mode
             )
         }
         "checked_add" => {
```

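The four-argument form of `make_comet_scalar_udf!` forwards `eval_mode` into the wrapped function. Roughly speaking, the macro arm closes over the mode so every invocation evaluates under the mode fixed at plan time. The sketch below is a simplified illustration, not the actual expansion:

```rust
// Simplified illustration of what the four-argument macro arm produces;
// not the actual expansion. The captured eval_mode rides along in the
// closure, so every UDF invocation evaluates under the mode fixed when
// the plan was created. Assumes EvalMode: Copy, as in the crate.
use std::sync::Arc;

use arrow::datatypes::DataType;
use datafusion::common::Result;
use datafusion::physical_plan::ColumnarValue;
use datafusion_comet_spark_expr::{spark_decimal_div, EvalMode};

type ScalarFn = Arc<dyn Fn(&[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>;

fn decimal_div_udf(data_type: DataType, eval_mode: EvalMode) -> ScalarFn {
    Arc::new(move |args| spark_decimal_div(args, &data_type, eval_mode))
}
```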
native/spark-expr/src/math_funcs/div.rs

Lines changed: 18 additions & 7 deletions
```diff
@@ -16,29 +16,33 @@
 // under the License.

 use crate::math_funcs::utils::get_precision_scale;
+use crate::{divide_by_zero_error, EvalMode};
 use arrow::array::{Array, Decimal128Array};
 use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION};
+use arrow::error::ArrowError;
 use arrow::{
     array::{ArrayRef, AsArray},
     datatypes::Decimal128Type,
 };
 use datafusion::common::DataFusionError;
 use datafusion::physical_plan::ColumnarValue;
-use num::{BigInt, Signed, ToPrimitive};
+use num::{BigInt, Signed, ToPrimitive, Zero};
 use std::sync::Arc;

 pub fn spark_decimal_div(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    spark_decimal_div_internal(args, data_type, false)
+    spark_decimal_div_internal(args, data_type, false, eval_mode)
 }

 pub fn spark_decimal_integral_div(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    spark_decimal_div_internal(args, data_type, true)
+    spark_decimal_div_internal(args, data_type, true, eval_mode)
 }

 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
@@ -50,6 +54,7 @@ fn spark_decimal_div_internal(
     args: &[ColumnarValue],
     data_type: &DataType,
     is_integral_div: bool,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
     let left = &args[0];
     let right = &args[1];
@@ -80,9 +85,12 @@ fn spark_decimal_div_internal(
         let r_mul = ten.pow(r_exp);
         let five = BigInt::from(5);
         let zero = BigInt::from(0);
-        arrow::compute::kernels::arity::binary(left, right, |l, r| {
+        arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
             let l = BigInt::from(l) * &l_mul;
             let r = BigInt::from(r) * &r_mul;
+            if eval_mode == EvalMode::Ansi && is_integral_div && r.is_zero() {
+                return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
+            }
             let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
             let res = if is_integral_div {
                 div
@@ -91,14 +99,17 @@ fn spark_decimal_div_internal(
             } else {
                 div + &five
             } / &ten;
-            res.to_i128().unwrap_or(i128::MAX)
+            Ok(res.to_i128().unwrap_or(i128::MAX))
         })?
     } else {
         let l_mul = 10_i128.pow(l_exp);
         let r_mul = 10_i128.pow(r_exp);
-        arrow::compute::kernels::arity::binary(left, right, |l, r| {
+        arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
             let l = l * l_mul;
             let r = r * r_mul;
+            if eval_mode == EvalMode::Ansi && is_integral_div && r.is_zero() {
+                return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
+            }
             let div = if r == 0 { 0 } else { l / r };
             let res = if is_integral_div {
                 div
@@ -107,7 +118,7 @@ fn spark_decimal_div_internal(
             } else {
                 div + 5
             } / 10;
-            res.to_i128().unwrap_or(i128::MAX)
+            Ok(res.to_i128().unwrap_or(i128::MAX))
         })?
     };
     let result = result.with_data_type(DataType::Decimal128(p3, s3));
```
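The behavioral change is confined to ANSI integral division: both kernels switch from `binary` to `try_binary` so the per-element closure can fail, and a zero divisor now raises `divide_by_zero_error` before the zero-substitution fallback runs. A small demonstration, sketched under the assumption that the datafusion-comet spark-expr crate is available:

```rust
// A sketch showing the mode-dependent behavior: the same zero divisor
// succeeds under Legacy but errors under Ansi for integral division.
use std::sync::Arc;

use arrow::array::Decimal128Array;
use arrow::datatypes::DataType;
use datafusion::physical_plan::ColumnarValue;
use datafusion_comet_spark_expr::{spark_decimal_integral_div, EvalMode};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let dividend = Decimal128Array::from(vec![100_i128]).with_precision_and_scale(10, 2)?;
    let divisor = Decimal128Array::from(vec![0_i128]).with_precision_and_scale(10, 2)?;
    let args = [
        ColumnarValue::Array(Arc::new(dividend)),
        ColumnarValue::Array(Arc::new(divisor)),
    ];
    let out_type = DataType::Decimal128(19, 0);

    // Legacy: the kernel substitutes 0 for the division and succeeds
    // (the Spark-side NullIf rewrite is what turns this into NULL).
    assert!(spark_decimal_integral_div(&args, &out_type, EvalMode::Legacy).is_ok());

    // Ansi: try_binary propagates divide_by_zero_error as a compute error.
    assert!(spark_decimal_integral_div(&args, &out_type, EvalMode::Ansi).is_err());
    Ok(())
}
```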

spark/src/main/scala/org/apache/comet/serde/arithmetic.scala

Lines changed: 2 additions & 10 deletions
```diff
@@ -180,14 +180,6 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {

 object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {

-  override def getSupportLevel(expr: IntegralDivide): SupportLevel = {
-    if (expr.evalMode == EvalMode.ANSI) {
-      Incompatible(Some("ANSI mode is not supported"))
-    } else {
-      Compatible(None)
-    }
-  }
-
   override def convert(
       expr: IntegralDivide,
       inputs: Seq[Attribute],
@@ -206,9 +198,9 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {
     if (expr.right.dataType.isInstanceOf[DecimalType]) expr.right
     else Cast(expr.right, DecimalType(19, 0))

-    val rightExpr = nullIfWhenPrimitive(right)
+    val rightExpr = if (expr.evalMode != EvalMode.ANSI) nullIfWhenPrimitive(right) else right

-    val dataType = (left.dataType, right.dataType) match {
+    val dataType = (left.dataType, rightExpr.dataType) match {
       case (l: DecimalType, r: DecimalType) =>
         // copy from IntegralDivide.resultDecimalType
         val intDig = l.precision - l.scale + r.scale
```
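Two details in this hunk carry the semantics. First, removing the `getSupportLevel` override stops marking ANSI `IntegralDivide` as `Incompatible`, so these plans are now offloaded to Comet. Second, `nullIfWhenPrimitive` (as the name suggests) rewrites a primitive divisor so that zero becomes NULL, producing Spark's non-ANSI answer of NULL; under ANSI mode that rewrite would mask the error the query is required to raise, so the divisor is passed through untouched and the zero reaches the native kernel, which now reports the divide-by-zero error itself. The result type is likewise derived from `rightExpr` rather than `right`, so it reflects whichever expression is actually serialized.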

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 30 additions & 3 deletions
```diff
@@ -58,7 +58,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   val ARITHMETIC_OVERFLOW_EXCEPTION_MSG =
     """org.apache.comet.CometNativeException: [ARITHMETIC_OVERFLOW] integer overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error"""
   val DIVIDE_BY_ZERO_EXCEPTION_MSG =
-    """org.apache.comet.CometNativeException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
+    """Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""

   test("compare true/false to negative zero") {
     Seq(false, true).foreach { dictionary =>
@@ -2949,7 +2949,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }

   test("ANSI support for divide (division by zero)") {
-    // TODO : Support ANSI mode in Integral divide -
     val data = Seq((Integer.MIN_VALUE, 0))
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
       withParquetTable(data, "tbl") {
@@ -2970,7 +2969,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }

   test("ANSI support for divide (division by zero) float division") {
-    // TODO : Support ANSI mode in Integral divide -
     val data = Seq((Float.MinPositiveValue, 0.0))
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
       withParquetTable(data, "tbl") {
@@ -2990,6 +2988,35 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

+  test("ANSI support for integral divide (division by zero)") {
+    val data = Seq((Integer.MAX_VALUE, 0))
+    Seq("true", "false").foreach { p =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> p) {
+        withParquetTable(data, "tbl") {
+          val res = spark.sql("""
+              |SELECT
+              |  _1 div _2
+              |  from tbl
+              | """.stripMargin)
+
+          checkSparkMaybeThrows(res) match {
+            case (Some(sparkException), Some(cometException)) =>
+              assert(sparkException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
+              assert(cometException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
+            case (None, None) => checkSparkAnswerAndOperator(res)
+            case (None, Some(ex)) =>
+              fail(
+                "Comet threw an exception but Spark did not. Comet exception: " + ex.getMessage)
+            case (Some(sparkException), None) =>
+              fail(
+                "Spark threw an exception but Comet did not. Spark exception: " +
+                  sparkException.getMessage)
+          }
+        }
+      }
+    }
+  }
+
   test("test integral divide overflow for decimal") {
     if (isSpark40Plus) {
       Seq(true, false)
```
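The new test covers both settings of `spark.sql.ansi.enabled`. With ANSI on, Spark and Comet must both throw, and each message must contain the shared divide-by-zero text; with ANSI off, the answers must match and the plan must stay on Comet (`checkSparkAnswerAndOperator`). This is also why `DIVIDE_BY_ZERO_EXCEPTION_MSG` was shortened at the top of the file: dropping the `org.apache.comet.CometNativeException` prefix lets one constant match both Spark's own exception message and Comet's native one.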
