Skip to content

Commit

Permalink
Allow customizing UDF equality and hash (#251)
Browse files Browse the repository at this point in the history
* Add customizable equality and hash functions to UDFs

* Improve equals and hash_value documentation

* Add tests for parameterized UDFs

* Cargo clippy
  • Loading branch information
joroKr21 authored Jul 12, 2024
1 parent 1d9a64b commit 932a62f
Show file tree
Hide file tree
Showing 6 changed files with 371 additions and 46 deletions.
79 changes: 75 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@
//! This module contains end to end demonstrations of creating
//! user defined aggregate functions
use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
use arrow_schema::Schema;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{
types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray,
};
use arrow_schema::Schema;

use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;
use datafusion::test_util::plan_and_collect;
use datafusion::{
Expand All @@ -45,7 +50,8 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
col, create_udaf, AggregateUDFImpl, GroupsAccumulator, LogicalPlanBuilder,
SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;

Expand Down Expand Up @@ -376,6 +382,56 @@ async fn test_groups_accumulator() -> Result<()> {
Ok(())
}

/// End-to-end check that two aggregate UDFs built from the same
/// `TestGroupsAccumulator` type, sharing one signature but carrying different
/// `result` parameters (1 and 2), can be used side by side in a single
/// aggregation.
///
/// The plan is assembled directly with `LogicalPlanBuilder`; its debug
/// rendering is asserted first, then the collected batches are compared with
/// the expected table (column `a` from `udf1`, column `b` from `udf2`).
// NOTE(review): the test is marked `#[ignore]`; the reason is not visible here.
#[ignore]
#[tokio::test]
async fn test_parameterized_aggregate_udf() -> Result<()> {
    let batch = RecordBatch::try_from_iter([(
        "text",
        Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
    )])?;

    let ctx = SessionContext::new();
    ctx.register_batch("t", batch)?;
    let t = ctx.table("t").await?;

    // Both UDFs share the signature and differ only in `result`.
    let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
    let udf1 = AggregateUDF::from(TestGroupsAccumulator {
        signature: signature.clone(),
        result: 1,
    });
    let udf2 = AggregateUDF::from(TestGroupsAccumulator {
        signature,
        result: 2,
    });

    let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
        .aggregate(
            [col("text")],
            [
                udf1.call(vec![col("text")]).alias("a"),
                udf2.call(vec![col("text")]).alias("b"),
            ],
        )?
        .build()?;

    // Child plan nodes are rendered with two spaces of indentation per level.
    assert_eq!(
        format!("{plan:?}"),
        "Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n  TableScan: t projection=[text]"
    );

    let actual = DataFrame::new(ctx.state(), plan).collect().await?;
    // Cells are padded to the column width, hence the two spaces after `foo`.
    let expected = [
        "+------+---+---+",
        "| text | a | b |",
        "+------+---+---+",
        "| foo  | 1 | 2 |",
        "+------+---+---+",
    ];
    assert_batches_eq!(expected, &actual);

    ctx.deregister_table("t")?;
    Ok(())
}

/// Returns an context with a table "t" and the "first" and "time_sum"
/// aggregate functions registered.
///
Expand Down Expand Up @@ -733,6 +789,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
/// Hands out a boxed clone of this value to serve as the groups accumulator.
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
    let accumulator: Box<dyn GroupsAccumulator> = Box::new(self.clone());
    Ok(accumulator)
}

/// Two `TestGroupsAccumulator` UDFs are equal when both their `result` and
/// their `signature` match; any other implementation type is never equal.
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
    let Some(peer) = other.as_any().downcast_ref::<TestGroupsAccumulator>() else {
        return false;
    };
    self.result == peer.result && self.signature == peer.signature
}

/// Hashes the same fields that `equals` compares: `signature`, then `result`.
fn hash_value(&self) -> u64 {
    let mut state = DefaultHasher::new();
    self.signature.hash(&mut state);
    self.result.hash(&mut state);
    state.finish()
}
}

impl Accumulator for TestGroupsAccumulator {
Expand Down
130 changes: 126 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,23 @@
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::iter;
use std::sync::Arc;

use arrow::compute::kernels::numeric::add;
use arrow_array::builder::BooleanBuilder;
use arrow_array::cast::AsArray;
use arrow_array::StringArray;
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array,
};
use arrow_schema::DataType::Float64;
use arrow_schema::{DataType, Field, Schema};
use rand::{thread_rng, Rng};
use regex::Regex;

use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
Expand All @@ -36,10 +47,6 @@ use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use rand::{thread_rng, Rng};
use std::any::Any;
use std::iter;
use std::sync::Arc;

/// test that casting happens on udfs.
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
Expand Down Expand Up @@ -961,6 +968,121 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {
Ok(())
}

/// Test scalar UDF parameterized by a regular expression.
///
/// Every instance reports the same name (`regex_udf`), so instances are told
/// apart only by the pattern they carry.
#[derive(Debug)]
struct MyRegexUdf {
// Signature returned by `ScalarUDFImpl::signature`.
signature: Signature,
// Compiled pattern applied to each input string.
regex: Regex,
}

impl MyRegexUdf {
    /// Creates a UDF that accepts exactly one `Utf8` argument and tests it
    /// against `pattern`.
    ///
    /// # Panics
    ///
    /// Panics (with the message `regex`) when `pattern` does not compile.
    fn new(pattern: &str) -> Self {
        let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
        let regex = Regex::new(pattern).expect("regex");
        Self { signature, regex }
    }

    /// Returns whether the regex matches `value`; a `None` input stays `None`.
    fn matches(&self, value: Option<&str>) -> Option<bool> {
        value.map(|text| self.regex.is_match(text))
    }
}

impl ScalarUDFImpl for MyRegexUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// All instances share this name regardless of their pattern.
    fn name(&self) -> &str {
        "regex_udf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Returns `Boolean` for a single `Utf8` argument and a planning error for
    /// any other argument list.
    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
        if matches!(args, [DataType::Utf8]) {
            Ok(DataType::Boolean)
        } else {
            plan_err!("regex_udf only accepts a Utf8 argument")
        }
    }

    /// Applies the regex to a single `Utf8` scalar or array argument.
    ///
    /// Null inputs produce null outputs. Any other argument shape, including
    /// an array that is not a `Utf8` string array, yields an execution error
    /// rather than a panic.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        match args {
            [ColumnarValue::Scalar(ScalarValue::Utf8(value))] => {
                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
                    self.matches(value.as_deref()),
                )))
            }
            [ColumnarValue::Array(values)] => {
                // `as_string` would panic on a non-Utf8 array; use the
                // fallible downcast so the caller gets an error instead.
                let Some(strings) = values.as_string_opt::<i32>() else {
                    return exec_err!("regex_udf only accepts a Utf8 argument");
                };
                let mut builder = BooleanBuilder::with_capacity(strings.len());
                for value in strings {
                    builder.append_option(self.matches(value));
                }
                Ok(ColumnarValue::Array(Arc::new(builder.finish())))
            }
            _ => exec_err!("regex_udf only accepts a Utf8 argument"),
        }
    }

    /// Two `MyRegexUdf`s are equal when their pattern strings are identical;
    /// any other implementation type is never equal.
    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
        if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
            self.regex.as_str() == other.regex.as_str()
        } else {
            false
        }
    }

    /// Hashes the pattern string, the same data `equals` compares.
    fn hash_value(&self) -> u64 {
        let hasher = &mut DefaultHasher::new();
        self.regex.as_str().hash(hasher);
        hasher.finish()
    }
}

/// End-to-end check that two `MyRegexUdf` instances with different patterns
/// (`fo{2}` and `[Bb]ar`) can be combined in one filter expression even though
/// both render as `regex_udf(t.text)`.
///
/// The plan's debug rendering is asserted first; the collected output must
/// then contain only the rows matched by both patterns.
#[tokio::test]
async fn test_parameterized_scalar_udf() -> Result<()> {
    let batch = RecordBatch::try_from_iter([(
        "text",
        Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef,
    )])?;

    let ctx = SessionContext::new();
    ctx.register_batch("t", batch)?;
    let t = ctx.table("t").await?;
    let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}"));
    let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar"));

    let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
        .filter(
            foo_udf
                .call(vec![col("text")])
                .and(bar_udf.call(vec![col("text")])),
        )?
        .filter(col("text").is_not_null())?
        .build()?;

    // Each nesting level of the plan is indented by two further spaces.
    assert_eq!(
        format!("{plan:?}"),
        "Filter: t.text IS NOT NULL\n  Filter: regex_udf(t.text) AND regex_udf(t.text)\n    TableScan: t projection=[text]"
    );

    let actual = DataFrame::new(ctx.state(), plan).collect().await?;
    // The header cell is padded to the width of the widest value.
    let expected = [
        "+--------+",
        "| text   |",
        "+--------+",
        "| foobar |",
        "| barfoo |",
        "+--------+",
    ];
    assert_batches_eq!(expected, &actual);

    ctx.deregister_table("t")?;
    Ok(())
}

fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
Expand Down
67 changes: 56 additions & 11 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use arrow::datatypes::DataType;

use datafusion_common::{not_impl_err, Result};

use crate::groups_accumulator::GroupsAccumulator;
use crate::{Accumulator, Expr};
use crate::{
AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction,
};
use arrow::datatypes::DataType;
use datafusion_common::{not_impl_err, Result};
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;

/// Logical representation of a user-defined [aggregate function] (UDAF).
///
Expand Down Expand Up @@ -66,16 +70,15 @@ pub struct AggregateUDF {

/// Equality for [`AggregateUDF`].
// NOTE(review): the body below holds two tail expressions. It looks like the
// diff view is showing the pre-change comparison (name + signature) next to its
// replacement (delegation to the wrapped implementation's `equals`); confirm
// against the real file before treating this span as compilable code.
impl PartialEq for AggregateUDF {
fn eq(&self, other: &Self) -> bool {
// Compares the two UDFs by name and signature.
self.name() == other.name() && self.signature() == other.signature()
// Delegates the decision to the wrapped implementation.
self.inner.equals(other.inner.as_ref())
}
}

/// Marker impl: declares [`AggregateUDF`] equality to be a full equivalence
/// relation ([`Eq`] adds no methods of its own).
impl Eq for AggregateUDF {}

// NOTE(review): this span appears to interleave two versions of the same impl
// as rendered by the diff view: a header and body that hash name + signature,
// followed by a header and body that hash the wrapped implementation's
// `hash_value()`. Confirm against the real file; as shown, the first `impl`
// and `fn` are never closed.
impl std::hash::Hash for AggregateUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
// Feeds the UDF name and then its signature into `state`.
self.name().hash(state);
self.signature().hash(state);
impl Hash for AggregateUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
// Feeds the u64 returned by the wrapped implementation's `hash_value` into `state`.
self.inner.hash_value().hash(state)
}
}

Expand Down Expand Up @@ -218,7 +221,7 @@ where
/// #[derive(Debug, Clone)]
/// struct GeoMeanUdf {
/// signature: Signature
/// };
/// }
///
/// impl GeoMeanUdf {
/// fn new() -> Self {
Expand Down Expand Up @@ -298,6 +301,33 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn aliases(&self) -> &[String] {
&[]
}

/// Reports whether this aggregate UDF should be treated as equal to `other`.
///
/// Override this to customize the equality of aggregate UDFs. The result must
/// agree with [`Self::hash_value`] and obey the same laws as [`Eq`]:
///
/// - reflexive: `a.equals(a)`;
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
///
/// The default implementation compares [`Self::name`] and [`Self::signature`].
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
    if self.name() != other.name() {
        return false;
    }
    self.signature() == other.signature()
}

/// Computes the hash value of this aggregate UDF.
///
/// Override this to customize the hash code of aggregate UDFs. As with [`Hash`]
/// and [`Eq`], whenever [`Self::equals`] returns true for two UDFs their
/// `hash_value`s must be identical.
///
/// The default implementation hashes [`Self::name`] and [`Self::signature`].
fn hash_value(&self) -> u64 {
    let mut state = DefaultHasher::new();
    self.name().hash(&mut state);
    self.signature().hash(&mut state);
    state.finish()
}
}

/// AggregateUDF that adds an alias to the underlying function. It is better to
Expand Down Expand Up @@ -348,6 +378,21 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
fn aliases(&self) -> &[String] {
&self.aliases
}

/// Equal only to another `AliasedAggregateUDFImpl` whose wrapped function is
/// equal (per its own `equals`) and whose alias list is identical.
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
    match other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
        Some(aliased) => {
            self.inner.equals(aliased.inner.as_ref()) && self.aliases == aliased.aliases
        }
        None => false,
    }
}

/// Combines the wrapped function's `hash_value` with the alias list, the same
/// data that `equals` compares.
fn hash_value(&self) -> u64 {
    let mut state = DefaultHasher::new();
    self.inner.hash_value().hash(&mut state);
    self.aliases.hash(&mut state);
    state.finish()
}
}

/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers
Expand Down
Loading

0 comments on commit 932a62f

Please sign in to comment.