diff --git a/Cargo.lock b/Cargo.lock
index 2ee907b30cf02..9b55ae6675e4b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2568,6 +2568,7 @@ dependencies = [
  "datafusion-execution",
  "datafusion-expr",
  "datafusion-functions",
+ "datafusion-functions-aggregate",
  "datafusion-functions-nested",
  "log",
  "percent-encoding",
diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs
index 9b2e7429ab3b0..c07958a858ed4 100644
--- a/datafusion/functions-aggregate/src/array_agg.rs
+++ b/datafusion/functions-aggregate/src/array_agg.rs
@@ -415,7 +415,7 @@ impl Accumulator for ArrayAggAccumulator {
 }
 
 #[derive(Debug)]
-struct DistinctArrayAggAccumulator {
+pub struct DistinctArrayAggAccumulator {
     values: HashSet<ScalarValue>,
     datatype: DataType,
     sort_options: Option<SortOptions>,
diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml
index 673b62c5c3485..0dc35f4a87776 100644
--- a/datafusion/spark/Cargo.toml
+++ b/datafusion/spark/Cargo.toml
@@ -48,6 +48,7 @@ datafusion-common = { workspace = true }
 datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
 datafusion-functions = { workspace = true, features = ["crypto_expressions"] }
+datafusion-functions-aggregate = { workspace = true }
 datafusion-functions-nested = { workspace = true }
 log = { workspace = true }
 percent-encoding = "2.3.2"
diff --git a/datafusion/spark/src/function/aggregate/collect.rs b/datafusion/spark/src/function/aggregate/collect.rs
new file mode 100644
index 0000000000000..50497e2826383
--- /dev/null
+++ b/datafusion/spark/src/function/aggregate/collect.rs
@@ -0,0 +1,203 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::ArrayRef;
+use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion_common::utils::SingleRowListArrayBuilder;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
+use datafusion_functions_aggregate::array_agg::{
+    ArrayAggAccumulator, DistinctArrayAggAccumulator,
+};
+use std::{any::Any, sync::Arc};
+
+// Spark implementations of the collect_list/collect_set aggregate functions.
+// They differ from DataFusion's ArrayAgg in the following ways:
+// - they ignore NULL inputs
+// - they return an empty list when all inputs are NULL
+// - they do not support ordering
+
+// <https://spark.apache.org/docs/latest/api/sql/index.html#collect_list>
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkCollectList {
+    signature: Signature,
+}
+
+impl Default for SparkCollectList {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkCollectList {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(1, Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for SparkCollectList {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "collect_list"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::List(Arc::new(Field::new_list_field(
+            arg_types[0].clone(),
+            true,
+        ))))
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+        Ok(vec![
+            Field::new_list(
+                format_state_name(args.name, "collect_list"),
+                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
+                true,
+            )
+            .into(),
+        ])
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        let field = &acc_args.expr_fields[0];
+        let data_type = field.data_type().clone();
+        let ignore_nulls = true;
+        Ok(Box::new(NullToEmptyListAccumulator::new(
+            ArrayAggAccumulator::try_new(&data_type, ignore_nulls)?,
+            data_type,
+        )))
+    }
+}
+
+// <https://spark.apache.org/docs/latest/api/sql/index.html#collect_set>
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkCollectSet {
+    signature: Signature,
+}
+
+impl Default for SparkCollectSet {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkCollectSet {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(1, Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for SparkCollectSet {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "collect_set"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::List(Arc::new(Field::new_list_field(
+            arg_types[0].clone(),
+            true,
+        ))))
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+        Ok(vec![
+            Field::new_list(
+                format_state_name(args.name, "collect_set"),
+                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
+                true,
+            )
+            .into(),
+        ])
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        let field = &acc_args.expr_fields[0];
+        let data_type = field.data_type().clone();
+        let ignore_nulls = true;
+        Ok(Box::new(NullToEmptyListAccumulator::new(
+            DistinctArrayAggAccumulator::try_new(&data_type, None, ignore_nulls)?,
+            data_type,
+        )))
+    }
+}
+
+/// Wrapper accumulator that returns an empty list instead of NULL when all inputs are NULL.
+/// This implements Spark's behavior for collect_list and collect_set.
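+///
+/// For example, `collect_list(a)` over a group whose values are all NULL
+/// evaluates to `[]` here, where the wrapped accumulator alone would return NULL.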
+#[derive(Debug)]
+struct NullToEmptyListAccumulator<T> {
+    inner: T,
+    data_type: DataType,
+}
+
+impl<T> NullToEmptyListAccumulator<T> {
+    pub fn new(inner: T, data_type: DataType) -> Self {
+        Self { inner, data_type }
+    }
+}
+
+impl<T: Accumulator> Accumulator for NullToEmptyListAccumulator<T> {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.inner.update_batch(values)
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.inner.merge_batch(states)
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        self.inner.state()
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let result = self.inner.evaluate()?;
+        if result.is_null() {
+            let empty_array = arrow::array::new_empty_array(&self.data_type);
+            Ok(SingleRowListArrayBuilder::new(empty_array).build_list_scalar())
+        } else {
+            Ok(result)
+        }
+    }
+
+    fn size(&self) -> usize {
+        self.inner.size() + self.data_type.size()
+    }
+}
diff --git a/datafusion/spark/src/function/aggregate/mod.rs b/datafusion/spark/src/function/aggregate/mod.rs
index 3db72669d42bd..d6a2fe7a8503e 100644
--- a/datafusion/spark/src/function/aggregate/mod.rs
+++ b/datafusion/spark/src/function/aggregate/mod.rs
@@ -19,6 +19,7 @@ use datafusion_expr::AggregateUDF;
 use std::sync::Arc;
 
 pub mod avg;
+pub mod collect;
 pub mod try_sum;
 
 pub mod expr_fn {
@@ -30,6 +31,16 @@ pub mod expr_fn {
         "Returns the sum of values for a column, or NULL if overflow occurs",
         arg1
     ));
+    export_functions!((
+        collect_list,
+        "Returns a list created from the values in a column",
+        arg1
+    ));
+    export_functions!((
+        collect_set,
+        "Returns a set created from the values in a column",
+        arg1
+    ));
 }
 
 // TODO: try use something like datafusion_functions_aggregate::create_func!()
@@ -39,7 +50,13 @@ pub fn avg() -> Arc<AggregateUDF> {
 pub fn try_sum() -> Arc<AggregateUDF> {
     Arc::new(AggregateUDF::new_from_impl(try_sum::SparkTrySum::new()))
 }
+pub fn collect_list() -> Arc<AggregateUDF> {
+    Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectList::new()))
+}
+pub fn collect_set() -> Arc<AggregateUDF> {
+    Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectSet::new()))
+}
 
 pub fn functions() -> Vec<Arc<AggregateUDF>> {
-    vec![avg(), try_sum()]
+    vec![avg(), try_sum(), collect_list(), collect_set()]
 }
diff --git a/datafusion/sqllogictest/test_files/spark/aggregate/collect.slt b/datafusion/sqllogictest/test_files/spark/aggregate/collect.slt
new file mode 100644
index 0000000000000..2bd80e2e13283
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/spark/aggregate/collect.slt
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+query ?
+SELECT collect_list(a) FROM (VALUES (1), (2), (3)) AS t(a);
+----
+[1, 2, 3]
+
+query ?
+SELECT collect_list(a) FROM (VALUES (1), (2), (2), (3), (1)) AS t(a);
+----
+[1, 2, 2, 3, 1]
+
+query ?
+SELECT collect_list(a) FROM (VALUES (1), (NULL), (3)) AS t(a);
+----
+[1, 3]
+
+query ?
+SELECT collect_list(a) FROM (VALUES (CAST(NULL AS INT)), (NULL), (NULL)) AS t(a);
+----
+[]
+
+query I?
+SELECT g, collect_list(a)
+FROM (VALUES (1, 10), (1, 20), (2, 30), (2, 30), (1, 10)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10, 20, 10]
+2 [30, 30]
+
+query I?
+SELECT g, collect_list(a)
+FROM (VALUES (1, 10), (1, NULL), (2, 20), (2, NULL)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10]
+2 [20]
+
+# we need to wrap collect_set with array_sort to have consistent outputs
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (1), (2), (3)) AS t(a);
+----
+[1, 2, 3]
+
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (1), (2), (2), (3), (1)) AS t(a);
+----
+[1, 2, 3]
+
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (1), (NULL), (3)) AS t(a);
+----
+[1, 3]
+
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (CAST(NULL AS INT)), (NULL), (NULL)) AS t(a);
+----
+[]
+
+query I?
+SELECT g, array_sort(collect_set(a))
+FROM (VALUES (1, 10), (1, 20), (2, 30), (2, 30), (1, 10)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10, 20]
+2 [30]
+
+query I?
+SELECT g, array_sort(collect_set(a))
+FROM (VALUES (1, 10), (1, NULL), (1, NULL), (2, 20), (2, NULL)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10]
+2 [20]
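+
+# sanity check implied by the semantics above: collect_list keeps the duplicates that collect_set removes
+query ??
+SELECT collect_list(a), array_sort(collect_set(a)) FROM (VALUES (1), (1), (2)) AS t(a);
+----
+[1, 1, 2] [1, 2]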