Skip to content

Commit 8aba4be

Browse files
committed
init
1 parent eea40ca commit 8aba4be

File tree

6 files changed

+40
-28
lines changed

6 files changed

+40
-28
lines changed

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ message ScalarFunc {
308308
string func = 1;
309309
repeated Expr args = 2;
310310
DataType return_type = 3;
311+
bool fail_on_error = 4;
311312
}
312313

313314
message CaseWhen {

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
115115
make_comet_scalar_udf!("rpad", func, without data_type)
116116
}
117117
"round" => {
118-
make_comet_scalar_udf!("round", spark_round, data_type)
118+
make_comet_scalar_udf!("round", spark_round, data_type, eval_mode)
119119
}
120120
"unscaled_value" => {
121121
let func = Arc::new(spark_unscaled_value);

native/spark-expr/src/math_funcs/round.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
19+
use crate::EvalMode;
1920
use arrow::array::{Array, ArrowNativeTypeOp};
2021
use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
2122
use arrow::datatypes::DataType;
@@ -68,6 +69,7 @@ macro_rules! round_integer_scalar {
6869
pub fn spark_round(
6970
args: &[ColumnarValue],
7071
data_type: &DataType,
72+
eval_mode: EvalMode,
7173
) -> Result<ColumnarValue, DataFusionError> {
7274
let value = &args[0];
7375
let point = &args[1];
@@ -141,7 +143,7 @@ fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
141143
mod test {
142144
use std::sync::Arc;
143145

144-
use crate::spark_round;
146+
use crate::{spark_round, EvalMode};
145147

146148
use arrow::array::{Float32Array, Float64Array};
147149
use arrow::datatypes::DataType;
@@ -158,7 +160,9 @@ mod test {
158160
]))),
159161
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
160162
];
161-
let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float32)? else {
163+
let ColumnarValue::Array(result) =
164+
spark_round(&args, &DataType::Float32, EvalMode::Legacy)?
165+
else {
162166
unreachable!()
163167
};
164168
let floats = as_float32_array(&result)?;
@@ -176,7 +180,9 @@ mod test {
176180
]))),
177181
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
178182
];
179-
let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float64)? else {
183+
let ColumnarValue::Array(result) =
184+
spark_round(&args, &DataType::Float64, EvalMode::Legacy)?
185+
else {
180186
unreachable!()
181187
};
182188
let floats = as_float64_array(&result)?;
@@ -193,7 +199,7 @@ mod test {
193199
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
194200
];
195201
let ColumnarValue::Scalar(ScalarValue::Float32(Some(result))) =
196-
spark_round(&args, &DataType::Float32)?
202+
spark_round(&args, &DataType::Float32, EvalMode::Legacy)?
197203
else {
198204
unreachable!()
199205
};
@@ -209,7 +215,7 @@ mod test {
209215
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
210216
];
211217
let ColumnarValue::Scalar(ScalarValue::Float64(Some(result))) =
212-
spark_round(&args, &DataType::Float64)?
218+
spark_round(&args, &DataType::Float64, EvalMode::Legacy)?
213219
else {
214220
unreachable!()
215221
};

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
634634
* The input attributes
635635
* @param binding
636636
* Whether to bind the expression to the input attributes
637+
* @param failOnError
638+
Whether the native side should fail when an exception occurs during evaluation. Defaults to false unless ANSI mode is enabled.
637639
* @return
638640
* The protobuf representation of the expression, or None if the expression is not supported.
639641
* In the case where None is returned, the expression will be tagged with the reason(s) why it
@@ -642,7 +644,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
642644
def exprToProtoInternal(
643645
expr: Expression,
644646
inputs: Seq[Attribute],
645-
binding: Boolean): Option[Expr] = {
647+
binding: Boolean,
648+
failOnError: Boolean = false): Option[Expr] = {
646649
val conf = SQLConf.get
647650

648651
def convert[T <: Expression](expr: T, handler: CometExpressionSerde[T]): Option[Expr] = {
@@ -681,7 +684,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
681684
if (notes.isDefined) {
682685
logWarning(s"Comet supports $expr but has notes: ${notes.get}")
683686
}
684-
handler.convert(expr, inputs, binding)
687+
handler.convert(expr, inputs, binding, failOnError)
685688
}
686689
}
687690

@@ -1019,7 +1022,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
10191022
builder: ScalarFunc.Builder,
10201023
args: Option[Expr]*): Option[Expr] = {
10211024
args.foreach {
1022-
case Some(a) => builder.addArgs(a)
1025+
case Some(a) =>
1026+
builder.addArgs(a)
10231027
case _ =>
10241028
return None
10251029
}
@@ -1918,7 +1922,7 @@ trait CometExpressionSerde[T <: Expression] {
19181922
* case it is expected that the input expression will have been tagged with reasons why it
19191923
* could not be converted.
19201924
*/
1921-
def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr]
1925+
def convert(expr: T, inputs: Seq[Attribute], binding: Boolean, failOnError: Boolean = false): Option[ExprOuterClass.Expr]
19221926
}
19231927

19241928
/**

spark/src/main/scala/org/apache/comet/serde/arithmetic.scala

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ object CometAdd extends CometExpressionSerde[Add] with MathBase {
9090
override def convert(
9191
expr: Add,
9292
inputs: Seq[Attribute],
93-
binding: Boolean): Option[ExprOuterClass.Expr] = {
93+
binding: Boolean,
94+
failOnError: Boolean): Option[ExprOuterClass.Expr] = {
9495
if (!supportedDataType(expr.left.dataType)) {
9596
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
9697
return None
@@ -112,7 +113,8 @@ object CometSubtract extends CometExpressionSerde[Subtract] with MathBase {
112113
override def convert(
113114
expr: Subtract,
114115
inputs: Seq[Attribute],
115-
binding: Boolean): Option[ExprOuterClass.Expr] = {
116+
binding: Boolean,
117+
failOnError: Boolean): Option[ExprOuterClass.Expr] = {
116118
if (!supportedDataType(expr.left.dataType)) {
117119
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
118120
return None
@@ -134,7 +136,8 @@ object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
134136
override def convert(
135137
expr: Multiply,
136138
inputs: Seq[Attribute],
137-
binding: Boolean): Option[ExprOuterClass.Expr] = {
139+
binding: Boolean,
140+
failOnError: Boolean): Option[ExprOuterClass.Expr] = {
138141
if (!supportedDataType(expr.left.dataType)) {
139142
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
140143
return None
@@ -156,7 +159,8 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
156159
override def convert(
157160
expr: Divide,
158161
inputs: Seq[Attribute],
159-
binding: Boolean): Option[ExprOuterClass.Expr] = {
162+
binding: Boolean,
163+
failOnError: Boolean): Option[ExprOuterClass.Expr] = {
160164
// Datafusion now throws an exception for dividing by zero
161165
// See https://github.com/apache/arrow-datafusion/pull/6792
162166
// For now, use NullIf to swap zeros with nulls.
@@ -191,7 +195,8 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with Mat
191195
override def convert(
192196
expr: IntegralDivide,
193197
inputs: Seq[Attribute],
194-
binding: Boolean): Option[ExprOuterClass.Expr] = {
198+
binding: Boolean,
199+
failOnError: Boolean): Option[ExprOuterClass.Expr] = {
195200
if (!supportedDataType(expr.left.dataType)) {
196201
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
197202
return None
@@ -263,7 +268,8 @@ object CometRemainder extends CometExpressionSerde[Remainder] with MathBase {
263268
override def convert(
264269
expr: Remainder,
265270
inputs: Seq[Attribute],
266-
binding: Boolean): Option[ExprOuterClass.Expr] = {
271+
binding: Boolean,
272+
failOnError: Boolean): Option[ExprOuterClass.Expr] = {
267273
if (!supportedDataType(expr.left.dataType)) {
268274
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
269275
return None
@@ -287,21 +293,15 @@ object CometRemainder extends CometExpressionSerde[Remainder] with MathBase {
287293

288294
object CometRound extends CometExpressionSerde[Round] {
289295

290-
override def getSupportLevel(expr: Round): SupportLevel = {
291-
if (expr.ansiEnabled) {
292-
Incompatible(Some("ANSI mode is not supported"))
293-
} else {
294-
Compatible(None)
295-
}
296-
}
297-
298296
override def convert(
299297
r: Round,
300298
inputs: Seq[Attribute],
301-
binding: Boolean): Option[ExprOuterClass.Expr] = {
299+
binding: Boolean,
300+
failOnError: Boolean): Option[ExprOuterClass.Expr] = {
302301
// _scale is a constant, copied from Spark's RoundBase because it is a protected val
303302
val scaleV: Any = r.scale.eval(EmptyRow)
304303
val _scale: Int = scaleV.asInstanceOf[Int]
304+
val isAnsiEnabled = r.ansiEnabled
305305

306306
lazy val childExpr = exprToProtoInternal(r.child, inputs, binding)
307307
r.child.dataType match {
@@ -331,7 +331,7 @@ object CometRound extends CometExpressionSerde[Round] {
331331
None
332332
case _ =>
333333
// `scale` must be Int64 type in DataFusion
334-
val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding)
334+
val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding, isAnsiEnabled)
335335
val optExpr =
336336
scalarFunctionExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
337337
optExprWithInfo(optExpr, r, r.child)
@@ -344,7 +344,8 @@ object CometUnaryMinus extends CometExpressionSerde[UnaryMinus] {
344344
override def convert(
345345
expr: UnaryMinus,
346346
inputs: Seq[Attribute],
347-
binding: Boolean): Option[ExprOuterClass.Expr] = {
347+
binding: Boolean,
348+
failOnError: Boolean): Option[ExprOuterClass.Expr] = {
348349
val childExpr = exprToProtoInternal(expr.child, inputs, binding)
349350
if (childExpr.isDefined) {
350351
val builder = ExprOuterClass.UnaryMinus.newBuilder()

spark/src/main/scala/org/apache/comet/serde/math.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ object CometHex extends CometExpressionSerde[Hex] with MathExprBase {
118118
inputs: Seq[Attribute],
119119
binding: Boolean): Option[ExprOuterClass.Expr] = {
120120
val childExpr = exprToProtoInternal(expr.child, inputs, binding)
121-
val optExpr = scalarFunctionExprToProtoWithReturnType("hex", expr.dataType, childExpr)
121+
val optExpr = scalarFunctionExprToProtoWithReturnType("hex", expr.dataType)
122122
optExprWithInfo(optExpr, expr, expr.child)
123123
}
124124
}

0 commit comments

Comments
 (0)