Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions datafusion/expr-common/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,37 @@ pub trait Accumulator: Send + Sync + Debug {
/// running sum.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;

/// Returns the final aggregate value, consuming the internal state.
/// Returns the final aggregate value.
///
/// For example, the `SUM` accumulator maintains a running sum,
/// and `evaluate` will produce that running sum as its output.
///
/// This function should not be called twice, otherwise it will
/// result in potentially non-deterministic behavior.
///
/// This function gets `&mut self` to allow for the accumulator to build
/// arrow-compatible internal state that can be returned without copying
/// when possible (for example distinct strings)
/// when possible (for example distinct strings).
///
/// # Window Frame Queries
///
/// When used in a window context without [`Self::supports_retract_batch`],
/// `evaluate()` may be called multiple times on the same accumulator instance
/// (once per row in the partition). In this case, implementations **must not**
/// consume or modify internal state. Use references or clones to preserve state:
///
/// ```ignore
/// // GOOD: Preserves state for subsequent calls
/// fn evaluate(&mut self) -> Result<ScalarValue> {
/// calculate_result(&self.values) // Use reference
/// }
///
/// // BAD: Consumes state, breaks window queries
/// fn evaluate(&mut self) -> Result<ScalarValue> {
/// calculate_result(std::mem::take(&mut self.values))
/// }
/// ```
///
/// For efficient sliding window calculations, consider implementing
/// [`Self::retract_batch`] which allows DataFusion to incrementally
/// update state rather than calling `evaluate()` repeatedly.
fn evaluate(&mut self) -> Result<ScalarValue>;

/// Returns the allocated size required for this accumulator, in
Expand Down
107 changes: 104 additions & 3 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ impl AggregateUDF {
self.inner.window_function_display_name(params)
}

#[allow(deprecated)]
pub fn is_nullable(&self) -> bool {
self.inner.is_nullable()
}
Expand Down Expand Up @@ -528,10 +529,32 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {

/// Whether the aggregate function is nullable.
///
/// **DEPRECATED**: This method is deprecated and will be removed in a future version.
/// Nullability should instead be specified in [`Self::return_field`] which can provide
/// more context-aware nullability based on input field properties.
///
/// Nullable means that the function could return `null` for any inputs.
/// For example, aggregate functions like `COUNT` always return a non null value
/// but others like `MIN` will return `NULL` if there is nullable input.
/// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null`
///
/// # Migration Guide
///
/// If you need to override nullability, implement [`Self::return_field`] instead:
///
/// ```ignore
/// fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
/// let arg_types: Vec<_> = arg_fields.iter().map(|f| f.data_type()).cloned().collect();
/// let data_type = self.return_type(&arg_types)?;
/// // Specify nullability based on your function's logic
/// let nullable = arg_fields.iter().any(|f| f.is_nullable());
/// Ok(Arc::new(Field::new(self.name(), data_type, nullable)))
/// }
/// ```
#[deprecated(
since = "52.0.0",
note = "Use `return_field` to specify nullability instead of `is_nullable`"
)]
fn is_nullable(&self) -> bool {
true
}
Expand Down Expand Up @@ -1091,17 +1114,27 @@ pub fn udaf_default_window_function_display_name<F: AggregateUDFImpl + ?Sized>(
}

/// Encapsulates default implementation of [`AggregateUDFImpl::return_field`].
///
/// This function computes nullability based on input field nullability:
/// - The result is nullable if ANY input field is nullable
/// - The result is non-nullable only if ALL input fields are non-nullable
///
/// This replaces the previous behavior of always deferring to `is_nullable()`,
/// providing more accurate nullability inference for aggregate functions.
pub fn udaf_default_return_field<F: AggregateUDFImpl + ?Sized>(
func: &F,
arg_fields: &[FieldRef],
) -> Result<FieldRef> {
let arg_types: Vec<_> = arg_fields.iter().map(|f| f.data_type()).cloned().collect();
let data_type = func.return_type(&arg_types)?;

// Determine nullability: result is nullable if any input is nullable
let is_nullable = arg_fields.iter().any(|f| f.is_nullable());

Ok(Arc::new(Field::new(
func.name(),
data_type,
func.is_nullable(),
is_nullable,
)))
}

Expand Down Expand Up @@ -1247,6 +1280,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
self.inner.return_field(arg_fields)
}

#[allow(deprecated)]
fn is_nullable(&self) -> bool {
self.inner.is_nullable()
}
Expand Down Expand Up @@ -1343,7 +1377,7 @@ mod test {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
unimplemented!()
Ok(DataType::Float64)
}
fn accumulator(
&self,
Expand Down Expand Up @@ -1383,7 +1417,7 @@ mod test {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
unimplemented!()
Ok(DataType::Float64)
}
fn accumulator(
&self,
Expand Down Expand Up @@ -1424,4 +1458,71 @@ mod test {
value.hash(hasher);
hasher.finish()
}

#[test]
fn test_return_field_nullability_from_nullable_input() {
// Test that return_field derives nullability from input field nullability
use arrow::datatypes::Field;
use std::sync::Arc;

let udf = AggregateUDF::from(AMeanUdf::new());

// Create a nullable input field
let nullable_field = Arc::new(Field::new("col", DataType::Float64, true));
let return_field = udf.return_field(&[nullable_field]).unwrap();

// When input is nullable, output should be nullable
assert!(return_field.is_nullable());
}

#[test]
fn test_return_field_nullability_from_non_nullable_input() {
// Test that return_field respects non-nullable input fields
use arrow::datatypes::Field;
use std::sync::Arc;

let udf = AggregateUDF::from(AMeanUdf::new());

// Create a non-nullable input field
let non_nullable_field = Arc::new(Field::new("col", DataType::Float64, false));
let return_field = udf.return_field(&[non_nullable_field]).unwrap();

// When input is non-nullable, output should also be non-nullable
assert!(!return_field.is_nullable());
}

#[test]
fn test_return_field_nullability_with_mixed_inputs() {
// Test that return_field is nullable if ANY input is nullable
use arrow::datatypes::Field;
use std::sync::Arc;

let a = AggregateUDF::from(AMeanUdf::new());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let a = AggregateUDF::from(AMeanUdf::new());
let udf = AggregateUDF::from(AMeanUdf::new());


// With multiple inputs (typical for aggregates in more complex scenarios)
let nullable_field = Arc::new(Field::new("col1", DataType::Float64, true));
let non_nullable_field = Arc::new(Field::new("col2", DataType::Float64, false));

let return_field = a.return_field(&[non_nullable_field, nullable_field]).unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let return_field = a.return_field(&[non_nullable_field, nullable_field]).unwrap();
let return_field = udf.return_field(&[non_nullable_field, nullable_field]).unwrap();


// If ANY input is nullable, result should be nullable
assert!(return_field.is_nullable());
}

#[test]
fn test_return_field_preserves_return_type() {
// Test that return_field correctly preserves the return type
use arrow::datatypes::Field;
use std::sync::Arc;

let udf = AggregateUDF::from(AMeanUdf::new());

let nullable_field = Arc::new(Field::new("col", DataType::Float64, true));
let return_field = udf.return_field(&[nullable_field]).unwrap();

// Verify data type is preserved
assert_eq!(*return_field.data_type(), DataType::Float64);
// Verify name matches function name
assert_eq!(return_field.name(), "a");
}
}
80 changes: 68 additions & 12 deletions datafusion/functions-aggregate/src/percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;
use std::fmt::Debug;
use std::mem::{size_of, size_of_val};
use std::sync::Arc;
Expand Down Expand Up @@ -52,7 +53,7 @@ use datafusion_expr::{
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable};
use datafusion_macros::user_doc;

use crate::utils::validate_percentile_expr;
Expand Down Expand Up @@ -427,14 +428,51 @@ impl<T: ArrowNumericType + Debug> Accumulator for PercentileContAccumulator<T> {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let d = std::mem::take(&mut self.all_values);
let value = calculate_percentile::<T>(d, self.percentile);
let value = calculate_percentile::<T>(&mut self.all_values, self.percentile);
ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
}

fn size(&self) -> usize {
size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
for i in 0..values[0].len() {
let v = ScalarValue::try_from_array(&values[0], i)?;
if !v.is_null() {
*to_remove.entry(v).or_default() += 1;
}
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Return early here is to_remove.is_empty() ?

let mut i = 0;
while i < self.all_values.len() {
let k =
ScalarValue::new_primitive::<T>(Some(self.all_values[i]), &T::DATA_TYPE)?;
if let Some(count) = to_remove.get_mut(&k)
&& *count > 0
{
self.all_values.swap_remove(i);
*count -= 1;
if *count == 0 {
to_remove.remove(&k);
if to_remove.is_empty() {
break;
}
}
} else {
i += 1;
}
}
Ok(())
}

fn supports_retract_batch(&self) -> bool {
true
}
}

/// The percentile_cont groups accumulator accumulates the raw input values
Expand Down Expand Up @@ -549,13 +587,13 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator

fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
// Emit values
let emit_group_values = emit_to.take_needed(&mut self.group_values);
let mut emit_group_values = emit_to.take_needed(&mut self.group_values);

// Calculate percentile for each group
let mut evaluate_result_builder =
PrimitiveBuilder::<T>::with_capacity(emit_group_values.len());
for values in emit_group_values {
let value = calculate_percentile::<T>(values, self.percentile);
for values in &mut emit_group_values {
let value = calculate_percentile::<T>(values.as_mut_slice(), self.percentile);
evaluate_result_builder.append_option(value);
}

Expand Down Expand Up @@ -652,17 +690,31 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let d = std::mem::take(&mut self.distinct_values.values)
.into_iter()
.map(|v| v.0)
.collect::<Vec<_>>();
let value = calculate_percentile::<T>(d, self.percentile);
let mut values: Vec<T::Native> =
self.distinct_values.values.iter().map(|v| v.0).collect();
let value = calculate_percentile::<T>(&mut values, self.percentile);
ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
}

fn size(&self) -> usize {
size_of_val(self) + self.distinct_values.size()
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}

let arr = values[0].as_primitive::<T>();
for value in arr.iter().flatten() {
self.distinct_values.values.remove(&Hashable(value));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a .slt test for this ?

}
Ok(())
}

fn supports_retract_batch(&self) -> bool {
true
}
}

/// Calculate the percentile value for a given set of values.
Expand All @@ -672,8 +724,12 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula
/// For percentile p and n values:
/// - If p * (n-1) is an integer, return the value at that position
/// - Otherwise, interpolate between the two closest values
///
/// Note: This function takes a mutable slice and sorts it in place, but does not
/// consume the data. This is important for window frame queries where evaluate()
/// may be called multiple times on the same accumulator state.
fn calculate_percentile<T: ArrowNumericType>(
mut values: Vec<T::Native>,
values: &mut [T::Native],
percentile: f64,
) -> Option<T::Native> {
let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
Expand Down
13 changes: 6 additions & 7 deletions datafusion/functions-aggregate/src/string_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,14 +384,13 @@ impl Accumulator for SimpleStringAggAccumulator {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let result = if self.has_value {
ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
if self.has_value {
Ok(ScalarValue::LargeUtf8(Some(
self.accumulated_string.clone(),
)))
} else {
ScalarValue::LargeUtf8(None)
};

self.has_value = false;
Ok(result)
Ok(ScalarValue::LargeUtf8(None))
}
}

fn size(&self) -> usize {
Expand Down
Loading