Skip to content

Commit 272b661

Browse files
committed
support_ansi_sum_decimal_input
1 parent 5df4821 commit 272b661

File tree

1 file changed

+41
-33
lines changed

1 file changed

+41
-33
lines changed

native/spark-expr/src/agg_funcs/sum_decimal.rs

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,10 @@ impl Accumulator for SumDecimalAccumulator {
221221
Some(sum_value) if is_valid_decimal_precision(sum_value, self.precision) => {
222222
ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale)
223223
}
224-
_ => {
225-
ScalarValue::new_primitive::<Decimal128Type>(
226-
None,
227-
&DataType::Decimal128(self.precision, self.scale),
228-
)
229-
}
224+
_ => ScalarValue::new_primitive::<Decimal128Type>(
225+
None,
226+
&DataType::Decimal128(self.precision, self.scale),
227+
),
230228
}
231229
}
232230
}
@@ -240,12 +238,10 @@ impl Accumulator for SumDecimalAccumulator {
240238
Some(sum_value) => {
241239
ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale)?
242240
}
243-
None => {
244-
ScalarValue::new_primitive::<Decimal128Type>(
245-
None,
246-
&DataType::Decimal128(self.precision, self.scale),
247-
)?
248-
}
241+
None => ScalarValue::new_primitive::<Decimal128Type>(
242+
None,
243+
&DataType::Decimal128(self.precision, self.scale),
244+
)?,
249245
};
250246

251247
// For decimal sum, always return 2 state values regardless of eval_mode
@@ -392,33 +388,45 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
392388
fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
393389
match emit_to {
394390
EmitTo::All => {
395-
let result = Decimal128Array::from_iter(self.sum.iter().zip(self.is_empty.iter()).map(|(&sum, &empty)| {
396-
if empty {
397-
None
398-
} else {
399-
match sum {
400-
Some(v) if is_valid_decimal_precision(v, self.precision) => Some(v),
401-
_ => None,
402-
}
403-
}
404-
}))
405-
.with_data_type(self.result_type.clone());
391+
let result =
392+
Decimal128Array::from_iter(self.sum.iter().zip(self.is_empty.iter()).map(
393+
|(&sum, &empty)| {
394+
if empty {
395+
None
396+
} else {
397+
match sum {
398+
Some(v) if is_valid_decimal_precision(v, self.precision) => {
399+
Some(v)
400+
}
401+
_ => None,
402+
}
403+
}
404+
},
405+
))
406+
.with_data_type(self.result_type.clone());
406407

407408
self.sum.clear();
408409
self.is_empty.clear();
409410
Ok(Arc::new(result))
410411
}
411412
EmitTo::First(n) => {
412-
let result = Decimal128Array::from_iter(self.sum.drain(..n).zip(self.is_empty.drain(..n)).map(|(sum, empty)| {
413-
if empty {
414-
None
415-
} else {
416-
match sum {
417-
Some(v) if is_valid_decimal_precision(v, self.precision) => Some(v),
418-
_ => None,
419-
}
420-
}
421-
}))
413+
let result = Decimal128Array::from_iter(
414+
self.sum
415+
.drain(..n)
416+
.zip(self.is_empty.drain(..n))
417+
.map(|(sum, empty)| {
418+
if empty {
419+
None
420+
} else {
421+
match sum {
422+
Some(v) if is_valid_decimal_precision(v, self.precision) => {
423+
Some(v)
424+
}
425+
_ => None,
426+
}
427+
}
428+
}),
429+
)
422430
.with_data_type(self.result_type.clone());
423431

424432
Ok(Arc::new(result))

0 commit comments

Comments
 (0)