Skip to content

Commit 12b8fb2

Browse files
committed
support_ansi_avg_wip
1 parent 94a8711 commit 12b8fb2

File tree

5 files changed

+156
-15
lines changed

5 files changed

+156
-15
lines changed

native/core/src/execution/planner.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ use datafusion_comet_spark_expr::{
116116
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
117117
GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike,
118118
RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, SumDecimal, TimestampTruncExpr,
119-
ToJson, UnboundColumn, Variance,
119+
ToJson, UnboundColumn, Variance, AvgInt
120120
};
121121
use itertools::Itertools;
122122
use jni::objects::GlobalRef;
@@ -1854,6 +1854,12 @@ impl PhysicalPlanner {
18541854
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
18551855
let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
18561856
let builder = match datatype {
1857+
1858+
DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16 | DataType::Int32 => {
1859+
let func =
1860+
AggregateUDF::new_from_impl(AvgInt::new(datatype, input_datatype));
1861+
AggregateExprBuilder::new(Arc::new(func), vec![child])
1862+
}
18571863
DataType::Decimal128(_, _) => {
18581864
let func =
18591865
AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype));

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ message Avg {
137137
Expr child = 1;
138138
DataType datatype = 2;
139139
DataType sum_datatype = 3;
140-
bool fail_on_error = 4; // currently unused (useful for deciding Ansi vs Legacy mode)
140+
EvalMode eval_mode = 4;
141141
}
142142

143143
message First {
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
use arrow::array::{ArrayRef, BooleanArray};
20+
use arrow::datatypes::{DataType, FieldRef};
21+
use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
22+
use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature};
23+
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
24+
use datafusion::logical_expr::type_coercion::aggregates::avg_return_type;
25+
use datafusion::logical_expr::Volatility::Immutable;
26+
use crate::{AvgDecimal, EvalMode};
27+
28+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
29+
pub struct AvgInt {
30+
signature: Signature,
31+
eval_mode: EvalMode,
32+
}
33+
34+
impl AvgInt {
35+
pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> {
36+
match data_type {
37+
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
38+
Ok(Self {
39+
signature: Signature::user_defined(Immutable),
40+
eval_mode
41+
})
42+
},
43+
_ => {Err(DataFusionError::Internal("inalid data type for AvgInt".to_string()))}
44+
}
45+
}
46+
}
47+
48+
impl AggregateUDFImpl for AvgInt {
49+
fn as_any(&self) -> &dyn Any {
50+
self
51+
}
52+
53+
fn name(&self) -> &str {
54+
"avg"
55+
}
56+
57+
fn reverse_expr(&self) -> ReversedUDAF {
58+
ReversedUDAF::Identical
59+
}
60+
61+
fn signature(&self) -> &Signature {
62+
&self.signature
63+
}
64+
65+
fn return_type(&self, arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
66+
avg_return_type(self.name(), &arg_types[0])
67+
}
68+
69+
fn is_nullable(&self) -> bool {
70+
true
71+
}
72+
73+
fn accumulator(&self, acc_args: AccumulatorArgs) -> datafusion::common::Result<Box<dyn Accumulator>> {
74+
todo!()
75+
}
76+
77+
fn state_fields(&self, args: StateFieldsArgs) -> datafusion::common::Result<Vec<FieldRef>> {
78+
todo!()
79+
}
80+
81+
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
82+
false
83+
}
84+
85+
fn create_groups_accumulator(&self, _args: AccumulatorArgs) -> datafusion::common::Result<Box<dyn GroupsAccumulator>> {
86+
Ok(Box::new(AvgIntGroupsAccumulator::new(self.eval_mode)))
87+
}
88+
89+
fn default_value(&self, data_type: &DataType) -> datafusion::common::Result<ScalarValue> {
90+
todo!()
91+
}
92+
}
93+
94+
struct AvgIntegerAccumulator{
95+
sum: Option<i64>,
96+
count: u64,
97+
eval_mode: EvalMode,
98+
}
99+
100+
impl AvgIntegerAccumulator {
101+
fn new(eval_mode: EvalMode) -> Self {
102+
Self{
103+
sum : Some(0),
104+
count: 0,
105+
eval_mode
106+
}
107+
}
108+
}
109+
110+
impl Accumulator for AvgIntegerAccumulator {
111+
112+
}
113+
114+
struct AvgIntGroupsAccumulator {
115+
116+
}
117+
118+
impl AvgIntGroupsAccumulator {
119+
120+
}
121+
122+
123+
impl GroupsAccumulator for AvgIntGroupsAccumulator {
124+
fn update_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion::common::Result<()> {
125+
todo!()
126+
}
127+
128+
fn evaluate(&mut self, emit_to: EmitTo) -> datafusion::common::Result<ArrayRef> {
129+
todo!()
130+
}
131+
132+
fn state(&mut self, emit_to: EmitTo) -> datafusion::common::Result<Vec<ArrayRef>> {
133+
todo!()
134+
}
135+
136+
fn merge_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion::common::Result<()> {
137+
todo!()
138+
}
139+
140+
fn size(&self) -> usize {
141+
todo!()
142+
}
143+
}

native/spark-expr/src/agg_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ mod covariance;
2222
mod stddev;
2323
mod sum_decimal;
2424
mod variance;
25+
mod avg_int;
2526

2627
pub use avg::Avg;
2728
pub use avg_decimal::AvgDecimal;
@@ -30,3 +31,4 @@ pub use covariance::Covariance;
3031
pub use stddev::Stddev;
3132
pub use sum_decimal::SumDecimal;
3233
pub use variance::Variance;
34+
pub use avg_int::AvgInt;

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType
2929
import org.apache.comet.CometConf
3030
import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
3131
import org.apache.comet.CometSparkSessionExtensions.withInfo
32-
import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
32+
import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
33+
import org.apache.comet.shims.CometEvalModeUtil
3334

3435
object CometMin extends CometAggregateExpressionSerde[Min] {
3536

@@ -150,17 +151,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] {
150151

151152
object CometAverage extends CometAggregateExpressionSerde[Average] {
152153

153-
override def getSupportLevel(avg: Average): SupportLevel = {
154-
avg.evalMode match {
155-
case EvalMode.ANSI =>
156-
Incompatible(Some("ANSI mode is not supported"))
157-
case EvalMode.TRY =>
158-
Incompatible(Some("TRY mode is not supported"))
159-
case _ =>
160-
Compatible()
161-
}
162-
}
163-
164154
override def convert(
165155
aggExpr: AggregateExpression,
166156
avg: Average,
@@ -192,7 +182,7 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
192182
val builder = ExprOuterClass.Avg.newBuilder()
193183
builder.setChild(childExpr.get)
194184
builder.setDatatype(dataType.get)
195-
builder.setFailOnError(avg.evalMode == EvalMode.ANSI)
185+
builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(avg.evalMode)))
196186
builder.setSumDatatype(sumDataType.get)
197187

198188
Some(

0 commit comments

Comments
 (0)