-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Deprecate AggregateUDFImpl::is_nullable in favor of return_field nullability inference #19688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
af2e94f
c5fe87b
86e4e03
3d4eeee
3b1e671
4d9cbef
15c0d12
1e487fb
26e2261
fa355e8
5bfae7a
5a220b6
e928c23
9590099
f263d91
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -208,6 +208,7 @@ impl AggregateUDF { | |||||
| self.inner.window_function_display_name(params) | ||||||
| } | ||||||
|
|
||||||
| #[allow(deprecated)] | ||||||
| pub fn is_nullable(&self) -> bool { | ||||||
| self.inner.is_nullable() | ||||||
| } | ||||||
|
|
@@ -528,10 +529,32 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync { | |||||
|
|
||||||
| /// Whether the aggregate function is nullable. | ||||||
| /// | ||||||
| /// **DEPRECATED**: This method is deprecated and will be removed in a future version. | ||||||
| /// Nullability should instead be specified in [`Self::return_field`] which can provide | ||||||
| /// more context-aware nullability based on input field properties. | ||||||
| /// | ||||||
| /// Nullable means that the function could return `null` for any inputs. | ||||||
| /// For example, aggregate functions like `COUNT` always return a non null value | ||||||
| /// but others like `MIN` will return `NULL` if there is nullable input. | ||||||
| /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is non-null. | ||||||
| /// | ||||||
| /// # Migration Guide | ||||||
| /// | ||||||
| /// If you need to override nullability, implement [`Self::return_field`] instead: | ||||||
| /// | ||||||
| /// ```ignore | ||||||
| /// fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> { | ||||||
| /// let arg_types: Vec<_> = arg_fields.iter().map(|f| f.data_type()).cloned().collect(); | ||||||
| /// let data_type = self.return_type(&arg_types)?; | ||||||
| /// // Specify nullability based on your function's logic | ||||||
| /// let nullable = arg_fields.iter().any(|f| f.is_nullable()); | ||||||
| /// Ok(Arc::new(Field::new(self.name(), data_type, nullable))) | ||||||
| /// } | ||||||
| /// ``` | ||||||
| #[deprecated( | ||||||
| since = "52.0.0", | ||||||
| note = "Use `return_field` to specify nullability instead of `is_nullable`" | ||||||
| )] | ||||||
| fn is_nullable(&self) -> bool { | ||||||
| true | ||||||
| } | ||||||
|
|
@@ -1091,17 +1114,27 @@ pub fn udaf_default_window_function_display_name<F: AggregateUDFImpl + ?Sized>( | |||||
| } | ||||||
|
|
||||||
| /// Encapsulates default implementation of [`AggregateUDFImpl::return_field`]. | ||||||
| /// | ||||||
| /// This function computes nullability based on input field nullability: | ||||||
| /// - The result is nullable if ANY input field is nullable | ||||||
| /// - The result is non-nullable only if ALL input fields are non-nullable | ||||||
| /// | ||||||
| /// This replaces the previous behavior of always deferring to `is_nullable()`, | ||||||
| /// providing more accurate nullability inference for aggregate functions. | ||||||
| pub fn udaf_default_return_field<F: AggregateUDFImpl + ?Sized>( | ||||||
| func: &F, | ||||||
| arg_fields: &[FieldRef], | ||||||
| ) -> Result<FieldRef> { | ||||||
| let arg_types: Vec<_> = arg_fields.iter().map(|f| f.data_type()).cloned().collect(); | ||||||
| let data_type = func.return_type(&arg_types)?; | ||||||
|
|
||||||
| // Determine nullability: result is nullable if any input is nullable | ||||||
| let is_nullable = arg_fields.iter().any(|f| f.is_nullable()); | ||||||
|
|
||||||
| Ok(Arc::new(Field::new( | ||||||
| func.name(), | ||||||
| data_type, | ||||||
| func.is_nullable(), | ||||||
| is_nullable, | ||||||
| ))) | ||||||
| } | ||||||
|
|
||||||
|
|
@@ -1247,6 +1280,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { | |||||
| self.inner.return_field(arg_fields) | ||||||
| } | ||||||
|
|
||||||
| #[allow(deprecated)] | ||||||
| fn is_nullable(&self) -> bool { | ||||||
| self.inner.is_nullable() | ||||||
| } | ||||||
|
|
@@ -1343,7 +1377,7 @@ mod test { | |||||
| &self.signature | ||||||
| } | ||||||
| fn return_type(&self, _args: &[DataType]) -> Result<DataType> { | ||||||
| unimplemented!() | ||||||
| Ok(DataType::Float64) | ||||||
| } | ||||||
| fn accumulator( | ||||||
| &self, | ||||||
|
|
@@ -1383,7 +1417,7 @@ mod test { | |||||
| &self.signature | ||||||
| } | ||||||
| fn return_type(&self, _args: &[DataType]) -> Result<DataType> { | ||||||
| unimplemented!() | ||||||
| Ok(DataType::Float64) | ||||||
| } | ||||||
| fn accumulator( | ||||||
| &self, | ||||||
|
|
@@ -1424,4 +1458,71 @@ mod test { | |||||
| value.hash(hasher); | ||||||
| hasher.finish() | ||||||
| } | ||||||
|
|
||||||
| #[test] | ||||||
| fn test_return_field_nullability_from_nullable_input() { | ||||||
| // Test that return_field derives nullability from input field nullability | ||||||
| use arrow::datatypes::Field; | ||||||
| use std::sync::Arc; | ||||||
|
|
||||||
| let udf = AggregateUDF::from(AMeanUdf::new()); | ||||||
|
|
||||||
| // Create a nullable input field | ||||||
| let nullable_field = Arc::new(Field::new("col", DataType::Float64, true)); | ||||||
| let return_field = udf.return_field(&[nullable_field]).unwrap(); | ||||||
|
|
||||||
| // When input is nullable, output should be nullable | ||||||
| assert!(return_field.is_nullable()); | ||||||
| } | ||||||
|
|
||||||
| #[test] | ||||||
| fn test_return_field_nullability_from_non_nullable_input() { | ||||||
| // Test that return_field respects non-nullable input fields | ||||||
| use arrow::datatypes::Field; | ||||||
| use std::sync::Arc; | ||||||
|
|
||||||
| let udf = AggregateUDF::from(AMeanUdf::new()); | ||||||
|
|
||||||
| // Create a non-nullable input field | ||||||
| let non_nullable_field = Arc::new(Field::new("col", DataType::Float64, false)); | ||||||
| let return_field = udf.return_field(&[non_nullable_field]).unwrap(); | ||||||
|
|
||||||
| // When input is non-nullable, output should also be non-nullable | ||||||
| assert!(!return_field.is_nullable()); | ||||||
| } | ||||||
|
|
||||||
| #[test] | ||||||
| fn test_return_field_nullability_with_mixed_inputs() { | ||||||
| // Test that return_field is nullable if ANY input is nullable | ||||||
| use arrow::datatypes::Field; | ||||||
| use std::sync::Arc; | ||||||
|
|
||||||
| let a = AggregateUDF::from(AMeanUdf::new()); | ||||||
|
|
||||||
| // With multiple inputs (typical for aggregates in more complex scenarios) | ||||||
| let nullable_field = Arc::new(Field::new("col1", DataType::Float64, true)); | ||||||
| let non_nullable_field = Arc::new(Field::new("col2", DataType::Float64, false)); | ||||||
|
|
||||||
| let return_field = a.return_field(&[non_nullable_field, nullable_field]).unwrap(); | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| // If ANY input is nullable, result should be nullable | ||||||
| assert!(return_field.is_nullable()); | ||||||
| } | ||||||
|
|
||||||
| #[test] | ||||||
| fn test_return_field_preserves_return_type() { | ||||||
| // Test that return_field correctly preserves the return type | ||||||
| use arrow::datatypes::Field; | ||||||
| use std::sync::Arc; | ||||||
|
|
||||||
| let udf = AggregateUDF::from(AMeanUdf::new()); | ||||||
|
|
||||||
| let nullable_field = Arc::new(Field::new("col", DataType::Float64, true)); | ||||||
| let return_field = udf.return_field(&[nullable_field]).unwrap(); | ||||||
|
|
||||||
| // Verify data type is preserved | ||||||
| assert_eq!(*return_field.data_type(), DataType::Float64); | ||||||
| // Verify name matches function name | ||||||
| assert_eq!(return_field.name(), "a"); | ||||||
| } | ||||||
| } | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| use std::collections::HashMap; | ||
| use std::fmt::Debug; | ||
| use std::mem::{size_of, size_of_val}; | ||
| use std::sync::Arc; | ||
|
|
@@ -52,7 +53,7 @@ use datafusion_expr::{ | |
| }; | ||
| use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate; | ||
| use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; | ||
| use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer; | ||
| use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable}; | ||
| use datafusion_macros::user_doc; | ||
|
|
||
| use crate::utils::validate_percentile_expr; | ||
|
|
@@ -427,14 +428,51 @@ impl<T: ArrowNumericType + Debug> Accumulator for PercentileContAccumulator<T> { | |
| } | ||
|
|
||
| fn evaluate(&mut self) -> Result<ScalarValue> { | ||
| let d = std::mem::take(&mut self.all_values); | ||
| let value = calculate_percentile::<T>(d, self.percentile); | ||
| let value = calculate_percentile::<T>(&mut self.all_values, self.percentile); | ||
| ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE) | ||
| } | ||
|
|
||
| fn size(&self) -> usize { | ||
| size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>() | ||
| } | ||
|
|
||
| fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { | ||
GaneshPatil7517 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if values.is_empty() { | ||
| return Ok(()); | ||
| } | ||
| let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new(); | ||
| for i in 0..values[0].len() { | ||
| let v = ScalarValue::try_from_array(&values[0], i)?; | ||
| if !v.is_null() { | ||
| *to_remove.entry(v).or_default() += 1; | ||
| } | ||
| } | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Return early here if `to_remove` is empty. |
||
| let mut i = 0; | ||
| while i < self.all_values.len() { | ||
| let k = | ||
| ScalarValue::new_primitive::<T>(Some(self.all_values[i]), &T::DATA_TYPE)?; | ||
| if let Some(count) = to_remove.get_mut(&k) | ||
| && *count > 0 | ||
| { | ||
| self.all_values.swap_remove(i); | ||
| *count -= 1; | ||
| if *count == 0 { | ||
| to_remove.remove(&k); | ||
| if to_remove.is_empty() { | ||
| break; | ||
| } | ||
| } | ||
| } else { | ||
| i += 1; | ||
| } | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| fn supports_retract_batch(&self) -> bool { | ||
| true | ||
| } | ||
| } | ||
|
|
||
| /// The percentile_cont groups accumulator accumulates the raw input values | ||
|
|
@@ -549,13 +587,13 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator | |
|
|
||
| fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> { | ||
| // Emit values | ||
| let emit_group_values = emit_to.take_needed(&mut self.group_values); | ||
| let mut emit_group_values = emit_to.take_needed(&mut self.group_values); | ||
|
|
||
| // Calculate percentile for each group | ||
| let mut evaluate_result_builder = | ||
| PrimitiveBuilder::<T>::with_capacity(emit_group_values.len()); | ||
| for values in emit_group_values { | ||
| let value = calculate_percentile::<T>(values, self.percentile); | ||
| for values in &mut emit_group_values { | ||
| let value = calculate_percentile::<T>(values.as_mut_slice(), self.percentile); | ||
| evaluate_result_builder.append_option(value); | ||
| } | ||
|
|
||
|
|
@@ -652,17 +690,31 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula | |
| } | ||
|
|
||
| fn evaluate(&mut self) -> Result<ScalarValue> { | ||
| let d = std::mem::take(&mut self.distinct_values.values) | ||
| .into_iter() | ||
| .map(|v| v.0) | ||
| .collect::<Vec<_>>(); | ||
| let value = calculate_percentile::<T>(d, self.percentile); | ||
| let mut values: Vec<T::Native> = | ||
| self.distinct_values.values.iter().map(|v| v.0).collect(); | ||
| let value = calculate_percentile::<T>(&mut values, self.percentile); | ||
| ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE) | ||
| } | ||
|
|
||
| fn size(&self) -> usize { | ||
| size_of_val(self) + self.distinct_values.size() | ||
| } | ||
|
|
||
| fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { | ||
| if values.is_empty() { | ||
| return Ok(()); | ||
| } | ||
|
|
||
| let arr = values[0].as_primitive::<T>(); | ||
| for value in arr.iter().flatten() { | ||
| self.distinct_values.values.remove(&Hashable(value)); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there a .slt test for this? |
||
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| fn supports_retract_batch(&self) -> bool { | ||
| true | ||
| } | ||
| } | ||
|
|
||
| /// Calculate the percentile value for a given set of values. | ||
|
|
@@ -672,8 +724,12 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula | |
| /// For percentile p and n values: | ||
| /// - If p * (n-1) is an integer, return the value at that position | ||
| /// - Otherwise, interpolate between the two closest values | ||
| /// | ||
| /// Note: This function takes a mutable slice and sorts it in place, but does not | ||
| /// consume the data. This is important for window frame queries where evaluate() | ||
| /// may be called multiple times on the same accumulator state. | ||
| fn calculate_percentile<T: ArrowNumericType>( | ||
| mut values: Vec<T::Native>, | ||
| values: &mut [T::Native], | ||
| percentile: f64, | ||
| ) -> Option<T::Native> { | ||
| let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.