From 92bdbebdcbf07f38e5f70b92b749cc3eb6975cd3 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Mon, 14 Oct 2024 20:14:17 +0300 Subject: [PATCH] Ensure that math functions fulfil the ColumnarValue contract (#275) If all UDF arguments are scalars, so should be the result. In most cases, such function calls will be constant-folded; however, if for whatever reason they are not optimized, we want to avoid an error due to array length mismatch. --- datafusion/expr/src/columnar_value.rs | 14 ++++++++++++-- datafusion/functions/src/macros.rs | 16 ++++++++-------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr/src/columnar_value.rs index bfefb37c98d7..7b614ba9c491 100644 --- a/datafusion/expr/src/columnar_value.rs +++ b/datafusion/expr/src/columnar_value.rs @@ -17,8 +17,7 @@ //! [`ColumnarValue`] represents the result of evaluating an expression. -use arrow::array::ArrayRef; -use arrow::array::NullArray; +use arrow::array::{Array, ArrayRef, NullArray}; use arrow::compute::{kernels, CastOptions}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::format::DEFAULT_CAST_OPTIONS; @@ -218,6 +217,17 @@ impl ColumnarValue { } } } + + /// Converts an [`ArrayRef`] to a [`ColumnarValue`] based on the supplied arguments. + /// This is useful for scalar UDF implementations to fulfil their contract: + /// if all arguments are scalar values, the result should also be a scalar value. + pub fn from_args_and_result(args: &[Self], result: ArrayRef) -> Result { + if result.len() == 1 && args.iter().all(|arg| matches!(arg, Self::Scalar(_))) { + Ok(Self::Scalar(ScalarValue::try_from_array(&result, 0)?)) + } else { + Ok(Self::Array(result)) + } + } } #[cfg(test)] diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index cae689b3e0cb..1aadaae6c1d6 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -222,9 +222,8 @@ macro_rules! 
make_math_unary_udf { $OUTPUT_ORDERING(input) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - + fn invoke(&self, col_args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(col_args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => { Arc::new(make_function_scalar_inputs_return_type!( @@ -251,7 +250,8 @@ macro_rules! make_math_unary_udf { ) } }; - Ok(ColumnarValue::Array(arr)) + + ColumnarValue::from_args_and_result(col_args, arr) } } } @@ -332,9 +332,8 @@ macro_rules! make_math_binary_udf { $OUTPUT_ORDERING(input) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - + fn invoke(&self, col_args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(col_args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => Arc::new(make_function_inputs2!( &args[0], @@ -360,7 +359,8 @@ macro_rules! make_math_binary_udf { ) } }; - Ok(ColumnarValue::Array(arr)) + + ColumnarValue::from_args_and_result(col_args, arr) } } }