From 9ab30a66a5c725feb901e819b900a0c92004f307 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Wed, 10 Jul 2024 17:05:55 +0300 Subject: [PATCH] Add customizable equality and hash functions to UDFs --- datafusion/expr/src/udaf.rs | 45 +++++++++++++++++++++++++++++++------ datafusion/expr/src/udf.rs | 29 +++++++++++++++++++++++- datafusion/expr/src/udwf.rs | 43 ++++++++++++++++++++++++++++++----- 3 files changed, 103 insertions(+), 14 deletions(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 7a054abea75b3..ba018c3ac8c0f 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,6 +17,17 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::sync::Arc; +use std::vec; + +use arrow::datatypes::{DataType, Field}; +use sqlparser::ast::NullTreatment; + +use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; + use crate::expr::AggregateFunction; use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, @@ -26,13 +37,6 @@ use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; use crate::{Accumulator, Expr}; use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; -use sqlparser::ast::NullTreatment; -use std::any::Any; -use std::fmt::{self, Debug, Formatter}; -use std::sync::Arc; -use std::vec; /// Logical representation of a user-defined [aggregate function] (UDAF). /// @@ -507,6 +511,21 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { not_impl_err!("Function {} does not implement coerce_types", self.name()) } + + /// Dynamic equality. Allows customizing the equality of aggregate UDFs. + /// By default, compares the UDF name and signature. + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Dynamic hashing. Allows customizing the hash code of aggregate UDFs. + /// By default, hashes the UDF name and signature. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } pub enum ReversedUDAF { @@ -562,6 +581,18 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) + } else { + self.inner.equals(other) + } + } + + fn hash_value(&self) -> u64 { + self.inner.hash_value() + } } /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 68d3af6ace3c0..1292330c38e73 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::fmt::{self, Debug, Formatter}; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use crate::expr::create_name; @@ -540,6 +541,21 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { not_impl_err!("Function {} does not implement coerce_types", self.name()) } + + /// Dynamic equality. Allows customizing the equality of scalar UDFs. + /// By default, compares the UDF name and signature. + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Dynamic hashing. Allows customizing the hash code of scalar UDFs. + /// By default, hashes the UDF name and signature. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } /// ScalarUDF that adds an alias to the underlying function. It is better to @@ -557,7 +573,6 @@ impl AliasedScalarUDFImpl { ) -> Self { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } } } @@ -575,6 +590,18 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.signature() } + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) + } else { + self.inner.equals(other) + } + } + + fn hash_value(&self) -> u64 { + self.inner.hash_value() + } + fn return_type(&self, arg_types: &[DataType]) -> Result { self.inner.return_type(arg_types) } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 70b44e5e307a4..aaf457fcd1eaa 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -17,18 +17,22 @@ //! [`WindowUDF`]: User Defined Window Functions -use crate::{ - function::WindowFunctionSimplification, Expr, PartitionEvaluator, - PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, -}; -use arrow::datatypes::DataType; -use datafusion_common::Result; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ any::Any, fmt::{self, Debug, Display, Formatter}, sync::Arc, }; +use arrow::datatypes::DataType; + +use datafusion_common::Result; + +use crate::{ + function::WindowFunctionSimplification, Expr, PartitionEvaluator, + PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, +}; + /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. /// @@ -296,6 +300,21 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn simplify(&self) -> Option { None } + + /// Dynamic equality. Allows customizing the equality of window UDFs. + /// By default, compares the UDF name and signature. + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Dynamic hashing. Allows customizing the hash code of window UDFs. + /// By default, hashes the UDF name and signature. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } /// WindowUDF that adds an alias to the underlying function. It is better to @@ -342,6 +361,18 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) + } else { + self.inner.equals(other) + } + } + + fn hash_value(&self) -> u64 { + self.inner.hash_value() + } } /// Implementation of [`WindowUDFImpl`] that wraps the function style pointers