1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/array_agg.rs
@@ -415,7 +415,7 @@ impl Accumulator for ArrayAggAccumulator {
 }
 
 #[derive(Debug)]
-struct DistinctArrayAggAccumulator {
+pub struct DistinctArrayAggAccumulator {
     values: HashSet<ScalarValue>,
     datatype: DataType,
     sort_options: Option<SortOptions>,
1 change: 1 addition & 0 deletions datafusion/spark/Cargo.toml
@@ -48,6 +48,7 @@ datafusion-common = { workspace = true }
 datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
 datafusion-functions = { workspace = true, features = ["crypto_expressions"] }
+datafusion-functions-aggregate = { workspace = true }
 datafusion-functions-nested = { workspace = true }
 log = { workspace = true }
 percent-encoding = "2.3.2"
200 changes: 200 additions & 0 deletions datafusion/spark/src/function/aggregate/collect.rs
@@ -0,0 +1,200 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::utils::SingleRowListArrayBuilder;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
use datafusion_functions_aggregate::array_agg::{
    ArrayAggAccumulator, DistinctArrayAggAccumulator,
};
use std::{any::Any, sync::Arc};

// Spark implementation of the collect_list/collect_set aggregate functions.
// These differ from DataFusion's ArrayAgg in the following ways:
// - they ignore NULL inputs
// - they return an empty list when all inputs are NULL
// - they do not support ordering
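// For example, collect_list(a) over the inputs (1, NULL, 3) yields [1, 3],
// and over all-NULL input it yields [] rather than NULL (the wrapper
// accumulator below converts the inner NULL result into an empty list).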

// <https://spark.apache.org/docs/latest/api/sql/index.html#collect_list>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkCollectList {
    signature: Signature,
}

impl Default for SparkCollectList {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkCollectList {
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl AggregateUDFImpl for SparkCollectList {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "collect_list"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::List(Arc::new(Field::new_list_field(
            arg_types[0].clone(),
            true,
        ))))
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        Ok(vec![
            Field::new_list(
                format_state_name(args.name, "collect_list"),
                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
                true,
            )
            .into(),
        ])
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let field = &acc_args.expr_fields[0];
        let data_type = field.data_type().clone();
        let ignore_nulls = true;
        Ok(Box::new(NullToEmptyListAccumulator::new(
            ArrayAggAccumulator::try_new(&data_type, ignore_nulls)?,
            data_type,
        )))
    }
}

// <https://spark.apache.org/docs/latest/api/sql/index.html#collect_set>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkCollectSet {
    signature: Signature,
}

impl Default for SparkCollectSet {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkCollectSet {
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl AggregateUDFImpl for SparkCollectSet {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "collect_set"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::List(Arc::new(Field::new_list_field(
            arg_types[0].clone(),
            true,
        ))))
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        Ok(vec![
            Field::new_list(
                format_state_name(args.name, "collect_set"),
                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
                true,
            )
            .into(),
        ])
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let field = &acc_args.expr_fields[0];
        let data_type = field.data_type().clone();
        let ignore_nulls = true;
        Ok(Box::new(NullToEmptyListAccumulator::new(
            DistinctArrayAggAccumulator::try_new(&data_type, None, ignore_nulls)?,
            data_type,
        )))
    }
}

/// Wrapper accumulator that returns an empty list instead of NULL when all inputs are NULL.
/// This implements Spark's behavior for collect_list and collect_set.
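///
/// For example, when every input value is NULL the inner accumulator
/// evaluates to a NULL scalar; `evaluate` below replaces that with an
/// empty list scalar of the stored element type.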
#[derive(Debug)]
struct NullToEmptyListAccumulator<T: Accumulator> {
    inner: T,
    data_type: DataType,
}

impl<T: Accumulator> NullToEmptyListAccumulator<T> {
    pub fn new(inner: T, data_type: DataType) -> Self {
        Self { inner, data_type }
    }
}

impl<T: Accumulator> Accumulator for NullToEmptyListAccumulator<T> {
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        self.inner.update_batch(values)
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        self.inner.merge_batch(states)
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        self.inner.state()
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let result = self.inner.evaluate()?;
        if result.is_null() {
            let empty_array = arrow::array::new_empty_array(&self.data_type);
            Ok(SingleRowListArrayBuilder::new(empty_array).build_list_scalar())
        } else {
            Ok(result)
        }
    }

    fn size(&self) -> usize {
        self.inner.size() + self.data_type.size()
    }
}
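For reference, here is a minimal sketch (not part of the diff) of exercising the new UDAF from Rust. It assumes the crate is importable as datafusion_spark with the module path shown in this file, and uses DataFusion's standard SessionContext::register_udaf API; the expected output is taken from the sqllogictest cases below.

use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_expr::AggregateUDF;
use datafusion_spark::function::aggregate::collect::SparkCollectList;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Register the Spark-flavored collect_list defined above.
    ctx.register_udaf(AggregateUDF::new_from_impl(SparkCollectList::new()));
    let df = ctx
        .sql("SELECT collect_list(a) FROM (VALUES (1), (NULL), (3)) AS t(a)")
        .await?;
    // NULL inputs are ignored, so this prints one row containing [1, 3].
    df.show().await?;
    Ok(())
}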
19 changes: 18 additions & 1 deletion datafusion/spark/src/function/aggregate/mod.rs
@@ -19,6 +19,7 @@ use datafusion_expr::AggregateUDF;
 use std::sync::Arc;
 
 pub mod avg;
+pub mod collect;
 pub mod try_sum;
 
 pub mod expr_fn {
@@ -30,6 +31,16 @@ pub mod expr_fn {
         "Returns the sum of values for a column, or NULL if overflow occurs",
         arg1
     ));
+    export_functions!((
+        collect_list,
+        "Returns a list created from the values in a column",
+        arg1
+    ));
+    export_functions!((
+        collect_set,
+        "Returns a set created from the values in a column",
+        arg1
+    ));
 }
 
 // TODO: try using something like datafusion_functions_aggregate::create_func!()
@@ -39,7 +50,13 @@ pub fn avg() -> Arc<AggregateUDF> {
 pub fn try_sum() -> Arc<AggregateUDF> {
     Arc::new(AggregateUDF::new_from_impl(try_sum::SparkTrySum::new()))
 }
+pub fn collect_list() -> Arc<AggregateUDF> {
+    Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectList::new()))
+}
+pub fn collect_set() -> Arc<AggregateUDF> {
+    Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectSet::new()))
+}
 
 pub fn functions() -> Vec<Arc<AggregateUDF>> {
-    vec![avg(), try_sum()]
+    vec![avg(), try_sum(), collect_list(), collect_set()]
 }
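The expr_fn helpers generated by export_functions! can also be used from the DataFrame API. A hedged sketch, assuming a previously registered table t(g, a) and that the helpers are importable via the path below:

use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_spark::function::aggregate::expr_fn::{collect_list, collect_set};

async fn grouped_collect(ctx: &SessionContext) -> Result<()> {
    // Equivalent to: SELECT g, collect_list(a), collect_set(a) FROM t GROUP BY g
    let df = ctx.table("t").await?;
    let df = df.aggregate(
        vec![col("g")],
        vec![collect_list(col("a")), collect_set(col("a"))],
    )?;
    df.show().await?;
    Ok(())
}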
93 changes: 93 additions & 0 deletions datafusion/sqllogictest/test_files/spark/aggregate/collect.slt
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

query ?
SELECT collect_list(a) FROM (VALUES (1), (2), (3)) AS t(a);
----
[1, 2, 3]

query ?
SELECT collect_list(a) FROM (VALUES (1), (2), (2), (3), (1)) AS t(a);
----
[1, 2, 2, 3, 1]

query ?
SELECT collect_list(a) FROM (VALUES (1), (NULL), (3)) AS t(a);
----
[1, 3]

query ?
SELECT collect_list(a) FROM (VALUES (CAST(NULL AS INT)), (NULL), (NULL)) AS t(a);
----
[]

query I?
SELECT g, collect_list(a)
FROM (VALUES (1, 10), (1, 20), (2, 30), (2, 30), (1, 10)) AS t(g, a)
GROUP BY g
ORDER BY g;
----
1 [10, 20, 10]
2 [30, 30]

query I?
SELECT g, collect_list(a)
FROM (VALUES (1, 10), (1, NULL), (2, 20), (2, NULL)) AS t(g, a)
GROUP BY g
ORDER BY g;
----
1 [10]
2 [20]

# collect_set output order is not deterministic, so we wrap it in array_sort
# to get consistent results
query ?
SELECT array_sort(collect_set(a)) FROM (VALUES (1), (2), (3)) AS t(a);
----
[1, 2, 3]

query ?
SELECT array_sort(collect_set(a)) FROM (VALUES (1), (2), (2), (3), (1)) AS t(a);
----
[1, 2, 3]

query ?
SELECT array_sort(collect_set(a)) FROM (VALUES (1), (NULL), (3)) AS t(a);
----
[1, 3]

query ?
SELECT array_sort(collect_set(a)) FROM (VALUES (CAST(NULL AS INT)), (NULL), (NULL)) AS t(a);
----
[]

query I?
SELECT g, array_sort(collect_set(a))
FROM (VALUES (1, 10), (1, 20), (2, 30), (2, 30), (1, 10)) AS t(g, a)
GROUP BY g
ORDER BY g;
----
1 [10, 20]
2 [30]

query I?
SELECT g, array_sort(collect_set(a))
FROM (VALUES (1, 10), (1, NULL), (1, NULL), (2, 20), (2, NULL)) AS t(g, a)
GROUP BY g
ORDER BY g;
----
1 [10]
2 [20]