Skip to content

Commit 6b3ba80

Browse files
committed
support_ansi_avg_wip
1 parent 12b8fb2 commit 6b3ba80

File tree

3 files changed: 53 additions and 37 deletions

native/core/src/execution/planner.rs

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -113,10 +113,10 @@ use datafusion_comet_proto::{
113113
};
114114
use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId;
115115
use datafusion_comet_spark_expr::{
116-
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
117-
GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike,
118-
RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, SumDecimal, TimestampTruncExpr,
119-
ToJson, UnboundColumn, Variance, AvgInt
116+
ArrayInsert, Avg, AvgDecimal, AvgInt, Cast, CheckOverflow, Correlation, Covariance,
117+
CreateNamedStruct, GetArrayStructFields, GetStructField, IfExpr, ListExtract,
118+
NormalizeNaNAndZero, RLike, RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr,
119+
SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance,
120120
};
121121
use itertools::Itertools;
122122
use jni::objects::GlobalRef;
@@ -1854,8 +1854,11 @@ impl PhysicalPlanner {
18541854
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
18551855
let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
18561856
let builder = match datatype {
1857-
1858-
DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16 | DataType::Int32 => {
1857+
DataType::Int8
1858+
| DataType::UInt8
1859+
| DataType::Int16
1860+
| DataType::UInt16
1861+
| DataType::Int32 => {
18591862
let func =
18601863
AggregateUDF::new_from_impl(AvgInt::new(datatype, input_datatype));
18611864
AggregateExprBuilder::new(Arc::new(func), vec![child])

native/spark-expr/src/agg_funcs/avg_int.rs

Lines changed: 42 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,17 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::any::Any;
18+
use crate::{AvgDecimal, EvalMode};
1919
use arrow::array::{ArrayRef, BooleanArray};
2020
use arrow::datatypes::{DataType, FieldRef};
2121
use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
22-
use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature};
2322
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
2423
use datafusion::logical_expr::type_coercion::aggregates::avg_return_type;
2524
use datafusion::logical_expr::Volatility::Immutable;
26-
use crate::{AvgDecimal, EvalMode};
25+
use datafusion::logical_expr::{
26+
Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
27+
};
28+
use std::any::Any;
2729

2830
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
2931
pub struct AvgInt {
@@ -34,13 +36,13 @@ pub struct AvgInt {
3436
impl AvgInt {
3537
pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> {
3638
match data_type {
37-
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
38-
Ok(Self {
39+
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self {
3940
signature: Signature::user_defined(Immutable),
40-
eval_mode
41-
})
42-
},
43-
_ => {Err(DataFusionError::Internal("inalid data type for AvgInt".to_string()))}
41+
eval_mode,
42+
}),
43+
_ => Err(DataFusionError::Internal(
44+
"inalid data type for AvgInt".to_string(),
45+
)),
4446
}
4547
}
4648
}
@@ -58,7 +60,7 @@ impl AggregateUDFImpl for AvgInt {
5860
ReversedUDAF::Identical
5961
}
6062

61-
fn signature(&self) -> &Signature {
63+
fn signature(&self) -> &Signature {
6264
&self.signature
6365
}
6466

@@ -70,7 +72,10 @@ impl AggregateUDFImpl for AvgInt {
7072
true
7173
}
7274

73-
fn accumulator(&self, acc_args: AccumulatorArgs) -> datafusion::common::Result<Box<dyn Accumulator>> {
75+
fn accumulator(
76+
&self,
77+
acc_args: AccumulatorArgs,
78+
) -> datafusion::common::Result<Box<dyn Accumulator>> {
7479
todo!()
7580
}
7681

@@ -82,7 +87,10 @@ impl AggregateUDFImpl for AvgInt {
8287
false
8388
}
8489

85-
fn create_groups_accumulator(&self, _args: AccumulatorArgs) -> datafusion::common::Result<Box<dyn GroupsAccumulator>> {
90+
fn create_groups_accumulator(
91+
&self,
92+
_args: AccumulatorArgs,
93+
) -> datafusion::common::Result<Box<dyn GroupsAccumulator>> {
8694
Ok(Box::new(AvgIntGroupsAccumulator::new(self.eval_mode)))
8795
}
8896

@@ -91,37 +99,36 @@ impl AggregateUDFImpl for AvgInt {
9199
}
92100
}
93101

94-
struct AvgIntegerAccumulator{
102+
struct AvgIntegerAccumulator {
95103
sum: Option<i64>,
96104
count: u64,
97105
eval_mode: EvalMode,
98106
}
99107

100108
impl AvgIntegerAccumulator {
101109
fn new(eval_mode: EvalMode) -> Self {
102-
Self{
103-
sum : Some(0),
110+
Self {
111+
sum: Some(0),
104112
count: 0,
105-
eval_mode
113+
eval_mode,
106114
}
107115
}
108116
}
109117

110-
impl Accumulator for AvgIntegerAccumulator {
111-
112-
}
118+
impl Accumulator for AvgIntegerAccumulator {}
113119

114-
struct AvgIntGroupsAccumulator {
115-
116-
}
117-
118-
impl AvgIntGroupsAccumulator {
119-
120-
}
120+
struct AvgIntGroupsAccumulator {}
121121

122+
impl AvgIntGroupsAccumulator {}
122123

123124
impl GroupsAccumulator for AvgIntGroupsAccumulator {
124-
fn update_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion::common::Result<()> {
125+
fn update_batch(
126+
&mut self,
127+
values: &[ArrayRef],
128+
group_indices: &[usize],
129+
opt_filter: Option<&BooleanArray>,
130+
total_num_groups: usize,
131+
) -> datafusion::common::Result<()> {
125132
todo!()
126133
}
127134

@@ -133,11 +140,17 @@ impl GroupsAccumulator for AvgIntGroupsAccumulator {
133140
todo!()
134141
}
135142

136-
fn merge_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion::common::Result<()> {
143+
fn merge_batch(
144+
&mut self,
145+
values: &[ArrayRef],
146+
group_indices: &[usize],
147+
opt_filter: Option<&BooleanArray>,
148+
total_num_groups: usize,
149+
) -> datafusion::common::Result<()> {
137150
todo!()
138151
}
139152

140153
fn size(&self) -> usize {
141154
todo!()
142155
}
143-
}
156+
}

native/spark-expr/src/agg_funcs/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,18 +17,18 @@
1717

1818
mod avg;
1919
mod avg_decimal;
20+
mod avg_int;
2021
mod correlation;
2122
mod covariance;
2223
mod stddev;
2324
mod sum_decimal;
2425
mod variance;
25-
mod avg_int;
2626

2727
pub use avg::Avg;
2828
pub use avg_decimal::AvgDecimal;
29+
pub use avg_int::AvgInt;
2930
pub use correlation::Correlation;
3031
pub use covariance::Covariance;
3132
pub use stddev::Stddev;
3233
pub use sum_decimal::SumDecimal;
3334
pub use variance::Variance;
34-
pub use avg_int::AvgInt;

0 commit comments

Comments
 (0)