Skip to content

Commit

Permalink
Add customizable equality and hash functions to UDFs (#252)
Browse files Browse the repository at this point in the history
* Add customizable equality and hash functions to UDFs

* Improve equals and hash_value documentation

* Add tests for parameterized UDFs

* CI issues
  • Loading branch information
joroKr21 authored Jul 12, 2024
1 parent 79c606a commit 55cdadf
Show file tree
Hide file tree
Showing 7 changed files with 372 additions and 48 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async fn test_parameter_invalid_types() -> Result<()> {
.await;
assert_eq!(
results.unwrap_err().strip_backtrace(),
"Arrow error: Invalid argument error: Invalid comparison operation: List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })"
"Arrow error: Invalid argument error: Nested comparison: List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) (hint: use make_comparator instead)"
);
Ok(())
}
80 changes: 75 additions & 5 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@
//! This module contains end to end demonstrations of creating
//! user defined aggregate functions
use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
use arrow_schema::Schema;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{
types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray,
};
use arrow_schema::Schema;

use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;
use datafusion::test_util::plan_and_collect;
use datafusion::{
Expand All @@ -45,8 +50,8 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
LogicalPlanBuilder, SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;

Expand Down Expand Up @@ -377,6 +382,56 @@ async fn test_groups_accumulator() -> Result<()> {
Ok(())
}

#[ignore]
#[tokio::test]
async fn test_parameterized_aggregate_udf() -> Result<()> {
let batch = RecordBatch::try_from_iter([(
"text",
Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
)])?;

let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;
let t = ctx.table("t").await?;
let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
let udf1 = AggregateUDF::from(TestGroupsAccumulator {
signature: signature.clone(),
result: 1,
});
let udf2 = AggregateUDF::from(TestGroupsAccumulator {
signature: signature.clone(),
result: 2,
});

let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
.aggregate(
[col("text")],
[
udf1.call(vec![col("text")]).alias("a"),
udf2.call(vec![col("text")]).alias("b"),
],
)?
.build()?;

assert_eq!(
format!("{plan:?}"),
"Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n TableScan: t projection=[text]"
);

let actual = DataFrame::new(ctx.state(), plan).collect().await?;
let expected = [
"+------+---+---+",
"| text | a | b |",
"+------+---+---+",
"| foo | 1 | 2 |",
"+------+---+---+",
];
assert_batches_eq!(expected, &actual);

ctx.deregister_table("t")?;
Ok(())
}

/// Returns an context with a table "t" and the "first" and "time_sum"
/// aggregate functions registered.
///
Expand Down Expand Up @@ -735,6 +790,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(self.clone()))
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<TestGroupsAccumulator>() {
self.result == other.result && self.signature == other.signature
} else {
false
}
}

fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.signature.hash(hasher);
self.result.hash(hasher);
hasher.finish()
}
}

impl Accumulator for TestGroupsAccumulator {
Expand Down
128 changes: 125 additions & 3 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,20 @@
// under the License.

use std::any::Any;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use arrow::compute::kernels::numeric::add;
use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch};
use arrow_array::builder::BooleanBuilder;
use arrow_array::cast::AsArray;
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use parking_lot::Mutex;
use regex::Regex;
use sqlparser::ast::Ident;

use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
Expand All @@ -37,8 +46,6 @@ use datafusion_expr::{
Volatility,
};
use datafusion_functions_array::range::range_udf;
use parking_lot::Mutex;
use sqlparser::ast::Ident;

/// test that casting happens on udfs.
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
Expand Down Expand Up @@ -1010,6 +1017,121 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<(
Ok(())
}

#[derive(Debug)]
struct MyRegexUdf {
signature: Signature,
regex: Regex,
}

impl MyRegexUdf {
fn new(pattern: &str) -> Self {
Self {
signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable),
regex: Regex::new(pattern).expect("regex"),
}
}

fn matches(&self, value: Option<&str>) -> Option<bool> {
Some(self.regex.is_match(value?))
}
}

impl ScalarUDFImpl for MyRegexUdf {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"regex_udf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, args: &[DataType]) -> Result<DataType> {
if matches!(args, [DataType::Utf8]) {
Ok(DataType::Boolean)
} else {
plan_err!("regex_udf only accepts a Utf8 argument")
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args {
[ColumnarValue::Scalar(ScalarValue::Utf8(value))] => {
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
self.matches(value.as_deref()),
)))
}
[ColumnarValue::Array(values)] => {
let mut builder = BooleanBuilder::with_capacity(values.len());
for value in values.as_string::<i32>() {
builder.append_option(self.matches(value))
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
_ => exec_err!("regex_udf only accepts a Utf8 arguments"),
}
}

fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
self.regex.as_str() == other.regex.as_str()
} else {
false
}
}

fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.regex.as_str().hash(hasher);
hasher.finish()
}
}

#[tokio::test]
async fn test_parameterized_scalar_udf() -> Result<()> {
let batch = RecordBatch::try_from_iter([(
"text",
Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef,
)])?;

let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;
let t = ctx.table("t").await?;
let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}"));
let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar"));

let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
.filter(
foo_udf
.call(vec![col("text")])
.and(bar_udf.call(vec![col("text")])),
)?
.filter(col("text").is_not_null())?
.build()?;

assert_eq!(
format!("{plan:?}"),
"Filter: t.text IS NOT NULL\n Filter: regex_udf(t.text) AND regex_udf(t.text)\n TableScan: t projection=[text]"
);

let actual = DataFrame::new(ctx.state(), plan).collect().await?;
let expected = [
"+--------+",
"| text |",
"+--------+",
"| foobar |",
"| barfoo |",
"+--------+",
];
assert_batches_eq!(expected, &actual);

ctx.deregister_table("t")?;
Ok(())
}

fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
Expand Down
71 changes: 58 additions & 13 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::vec;

use arrow::datatypes::{DataType, Field};

use datafusion_common::{exec_err, not_impl_err, Result};

use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
Expand All @@ -25,12 +35,6 @@ use crate::utils::format_state_name;
use crate::utils::AggregateOrderSensitivity;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{exec_err, not_impl_err, Result};
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use std::vec;

/// Logical representation of a user-defined [aggregate function] (UDAF).
///
Expand Down Expand Up @@ -70,20 +74,19 @@ pub struct AggregateUDF {

impl PartialEq for AggregateUDF {
fn eq(&self, other: &Self) -> bool {
self.name() == other.name() && self.signature() == other.signature()
self.inner.equals(other.inner.as_ref())
}
}

impl Eq for AggregateUDF {}

impl std::hash::Hash for AggregateUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name().hash(state);
self.signature().hash(state);
impl Hash for AggregateUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
self.inner.hash_value().hash(state)
}
}

impl std::fmt::Display for AggregateUDF {
impl fmt::Display for AggregateUDF {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self.name())
}
Expand Down Expand Up @@ -276,7 +279,7 @@ where
/// #[derive(Debug, Clone)]
/// struct GeoMeanUdf {
/// signature: Signature
/// };
/// }
///
/// impl GeoMeanUdf {
/// fn new() -> Self {
Expand Down Expand Up @@ -503,6 +506,33 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
not_impl_err!("Function {} does not implement coerce_types", self.name())
}

/// Return true if this aggregate UDF is equal to the other.
///
/// Allows customizing the equality of aggregate UDFs.
/// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
///
/// - reflexive: `a.equals(a)`;
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
///
/// By default, compares [`Self::name`] and [`Self::signature`].
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
self.name() == other.name() && self.signature() == other.signature()
}

/// Returns a hash value for this aggregate UDF.
///
/// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`],
/// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
///
/// By default, hashes [`Self::name`] and [`Self::signature`].
fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.name().hash(hasher);
self.signature().hash(hasher);
hasher.finish()
}
}

pub enum ReversedUDAF {
Expand Down Expand Up @@ -558,6 +588,21 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
fn aliases(&self) -> &[String] {
&self.aliases
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
} else {
false
}
}

fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.inner.hash_value().hash(hasher);
self.aliases.hash(hasher);
hasher.finish()
}
}

/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers
Expand Down
Loading

0 comments on commit 55cdadf

Please sign in to comment.