diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs
index 7107b0a9004d3..ce99c4a80513e 100644
--- a/datafusion/physical-expr-common/src/physical_expr.rs
+++ b/datafusion/physical-expr-common/src/physical_expr.rs
@@ -416,6 +416,28 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
         0
     }
 
+    /// Bind runtime-specific data into this expression, if needed.
+    ///
+    /// This hook lets an expression replace itself with a runtime-bound version using the given
+    /// `context` (e.g. binding a per-partition view).
+    ///
+    /// Binding is single-pass over the existing tree. If this method returns a replacement
+    /// expression that itself contains additional bindable nodes, those newly introduced nodes are
+    /// not rebound in the same call.
+    ///
+    /// You should not call this method directly as it does not handle recursion. Instead use
+    /// [`bind_runtime_physical_expr`] to handle recursion and bind the full expression tree.
+    ///
+    /// Note for implementers: this method should not handle recursion.
+    /// Recursion is handled in [`bind_runtime_physical_expr`].
+    fn bind_runtime(
+        &self,
+        _context: &(dyn Any + Send + Sync),
+    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
+        // By default, this expression does not need runtime binding.
+        Ok(None)
+    }
+
     /// Returns true if the expression node is volatile, i.e. whether it can return
     /// different results when evaluated multiple times with the same input.
     ///
@@ -618,6 +640,46 @@ pub fn snapshot_physical_expr_opt(
     })
 }
 
+/// Bind runtime-specific data into the given `PhysicalExpr`.
+///
+/// See the documentation of [`PhysicalExpr::bind_runtime`] for more details.
+///
+/// Runtime binding is applied once over the current expression tree.
+///
+/// # Returns
+///
+/// Returns a runtime-bound expression if any node required binding,
+/// otherwise returns the original expression.
+pub fn bind_runtime_physical_expr(
+    expr: Arc<dyn PhysicalExpr>,
+    context: &(dyn Any + Send + Sync),
+) -> Result<Arc<dyn PhysicalExpr>> {
+    bind_runtime_physical_expr_opt(expr, context).data()
+}
+
+/// Bind runtime-specific data into the given `PhysicalExpr`.
+///
+/// See the documentation of [`PhysicalExpr::bind_runtime`] for more details.
+///
+/// Runtime binding is applied once over the current expression tree.
+///
+/// # Returns
+///
+/// Returns a [`Transformed`] indicating whether any runtime binding happened,
+/// along with the resulting expression.
+pub fn bind_runtime_physical_expr_opt(
+    expr: Arc<dyn PhysicalExpr>,
+    context: &(dyn Any + Send + Sync),
+) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
+    expr.transform_up(|e| {
+        if let Some(bound) = e.bind_runtime(context)? {
+            Ok(Transformed::yes(bound))
+        } else {
+            Ok(Transformed::no(Arc::clone(&e)))
+        }
+    })
+}
+
 /// Check the generation of this `PhysicalExpr`.
 /// Dynamic `PhysicalExpr`s may have a generation that is incremented
 /// every time the state of the `PhysicalExpr` changes.
@@ -677,6 +739,7 @@ mod test {
     use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch};
     use arrow::datatypes::{DataType, Schema};
     use datafusion_expr_common::columnar_value::ColumnarValue;
+    use std::any::Any;
     use std::fmt::{Display, Formatter};
     use std::sync::Arc;
 
@@ -684,7 +747,7 @@ mod test {
     struct TestExpr {}
 
     impl PhysicalExpr for TestExpr {
-        fn as_any(&self) -> &dyn std::any::Any {
+        fn as_any(&self) -> &dyn Any {
             self
         }
 
@@ -726,6 +789,161 @@ mod test {
         }
     }
 
+    #[derive(Debug, PartialEq, Eq, Hash)]
+    struct RuntimeBindableExpr {
+        name: &'static str,
+        // Selector used to decide if this node should bind for a given context.
+        bind_key: Option<&'static str>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    }
+
+    impl RuntimeBindableExpr {
+        fn new(
+            name: &'static str,
+            bind_key: Option<&'static str>,
+            children: Vec<Arc<dyn PhysicalExpr>>,
+        ) -> Self {
+            Self {
+                name,
+                bind_key,
+                children,
+            }
+        }
+    }
+
+    impl PhysicalExpr for RuntimeBindableExpr {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn data_type(&self, _schema: &Schema) -> datafusion_common::Result<DataType> {
+            Ok(DataType::Int64)
+        }
+
+        fn nullable(&self, _schema: &Schema) -> datafusion_common::Result<bool> {
+            Ok(false)
+        }
+
+        fn evaluate(
+            &self,
+            batch: &RecordBatch,
+        ) -> datafusion_common::Result<ColumnarValue> {
+            let data = vec![1; batch.num_rows()];
+            Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data))))
+        }
+
+        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+            self.children.iter().collect()
+        }
+
+        fn with_new_children(
+            self: Arc<Self>,
+            children: Vec<Arc<dyn PhysicalExpr>>,
+        ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+            Ok(Arc::new(Self {
+                name: self.name,
+                bind_key: self.bind_key,
+                children,
+            }))
+        }
+
+        fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            f.write_str(self.name)
+        }
+
+        fn bind_runtime(
+            &self,
+            context: &(dyn Any + Send + Sync),
+        ) -> datafusion_common::Result<Option<Arc<dyn PhysicalExpr>>> {
+            let Some(bind_key) = self.bind_key else {
+                return Ok(None);
+            };
+            let Some(ctx) = context.downcast_ref::<RuntimeBindContext>() else {
+                return Ok(None);
+            };
+            // Bind only when selector in context matches this node's key.
+            if ctx.target_key != bind_key {
+                return Ok(None);
+            }
+
+            Ok(Some(Arc::new(Self {
+                // Simulate replacing runtime placeholder with bound payload.
+                name: ctx.bound_name,
+                bind_key: None,
+                children: self.children.clone(),
+            })))
+        }
+    }
+
+    impl Display for RuntimeBindableExpr {
+        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            self.fmt_sql(f)
+        }
+    }
+
+    #[derive(Debug, PartialEq, Eq, Hash)]
+    struct ErrorOnBindExpr;
+
+    impl PhysicalExpr for ErrorOnBindExpr {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn data_type(&self, _schema: &Schema) -> datafusion_common::Result<DataType> {
+            Ok(DataType::Int64)
+        }
+
+        fn nullable(&self, _schema: &Schema) -> datafusion_common::Result<bool> {
+            Ok(false)
+        }
+
+        fn evaluate(
+            &self,
+            batch: &RecordBatch,
+        ) -> datafusion_common::Result<ColumnarValue> {
+            let data = vec![1; batch.num_rows()];
+            Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data))))
+        }
+
+        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+            vec![]
+        }
+
+        fn with_new_children(
+            self: Arc<Self>,
+            _children: Vec<Arc<dyn PhysicalExpr>>,
+        ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+            Ok(self)
+        }
+
+        fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            f.write_str("ErrorOnBindExpr")
+        }
+
+        fn bind_runtime(
+            &self,
+            _context: &(dyn Any + Send + Sync),
+        ) -> datafusion_common::Result<Option<Arc<dyn PhysicalExpr>>> {
+            // Used to verify traversal propagates bind errors.
+            Err(datafusion_common::DataFusionError::Internal(
+                "forced bind_runtime error".to_string(),
+            ))
+        }
+    }
+
+    impl Display for ErrorOnBindExpr {
+        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            self.fmt_sql(f)
+        }
+    }
+
+    struct RuntimeBindContext {
+        // Which bindable nodes should be replaced for this call.
+        target_key: &'static str,
+        // Replacement used by matching nodes.
+        bound_name: &'static str,
+    }
+
     macro_rules! assert_arrays_eq {
         ($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => {
             let expected = $EXPECTED.to_array(1).unwrap();
@@ -856,4 +1074,109 @@ mod test {
             &BooleanArray::from(vec![true; 5]),
         );
     }
+
+    #[test]
+    fn test_bind_runtime_physical_expr_default_noop() {
+        // TestExpr does not override bind_runtime, so traversal is a no-op.
+        let expr: Arc<dyn PhysicalExpr> = Arc::new(TestExpr {});
+        let ctx = RuntimeBindContext {
+            target_key: "right",
+            bound_name: "bound",
+        };
+
+        let transformed =
+            super::bind_runtime_physical_expr_opt(Arc::clone(&expr), &ctx).unwrap();
+
+        assert!(!transformed.transformed);
+        assert!(Arc::ptr_eq(&expr, &transformed.data));
+    }
+
+    #[test]
+    fn test_bind_runtime_physical_expr_recurses() {
+        // Only the right child matches target_key and should be rewritten.
+        let left: Arc<dyn PhysicalExpr> =
+            Arc::new(RuntimeBindableExpr::new("left", Some("left"), vec![]));
+        let right: Arc<dyn PhysicalExpr> =
+            Arc::new(RuntimeBindableExpr::new("right", Some("right"), vec![]));
+        let root: Arc<dyn PhysicalExpr> = Arc::new(RuntimeBindableExpr::new(
+            "root",
+            None,
+            vec![Arc::clone(&left), Arc::clone(&right)],
+        ));
+        let ctx = RuntimeBindContext {
+            target_key: "right",
+            bound_name: "right_bound",
+        };
+
+        let transformed = super::bind_runtime_physical_expr_opt(root, &ctx).unwrap();
+        assert!(transformed.transformed);
+
+        let root = transformed
+            .data
+            .as_any()
+            .downcast_ref::<RuntimeBindableExpr>()
+            .expect("root should be RuntimeBindableExpr");
+        let left = root.children[0]
+            .as_any()
+            .downcast_ref::<RuntimeBindableExpr>()
+            .expect("left should be RuntimeBindableExpr");
+        let right = root.children[1]
+            .as_any()
+            .downcast_ref::<RuntimeBindableExpr>()
+            .expect("right should be RuntimeBindableExpr");
+
+        assert_eq!(left.name, "left");
+        assert_eq!(right.name, "right_bound");
+        assert_eq!(right.bind_key, None);
+    }
+
+    #[test]
+    fn test_bind_runtime_physical_expr_returns_data() {
+        // The non-_opt helper should return the rewritten tree directly.
+        let expr: Arc<dyn PhysicalExpr> =
+            Arc::new(RuntimeBindableExpr::new("right", Some("right"), vec![]));
+        let ctx = RuntimeBindContext {
+            target_key: "right",
+            bound_name: "right_bound",
+        };
+
+        let bound = super::bind_runtime_physical_expr(expr, &ctx).unwrap();
+        let bound = bound
+            .as_any()
+            .downcast_ref::<RuntimeBindableExpr>()
+            .expect("bound should be RuntimeBindableExpr");
+
+        assert_eq!(bound.name, "right_bound");
+        assert_eq!(bound.bind_key, None);
+    }
+
+    #[test]
+    fn test_bind_runtime_physical_expr_context_mismatch_no_transform() {
+        // Context mismatch returns no transform even for bindable nodes.
+        let expr: Arc<dyn PhysicalExpr> =
+            Arc::new(RuntimeBindableExpr::new("left", Some("left"), vec![]));
+        let ctx = RuntimeBindContext {
+            target_key: "right",
+            bound_name: "right_bound",
+        };
+
+        let transformed =
+            super::bind_runtime_physical_expr_opt(Arc::clone(&expr), &ctx).unwrap();
+
+        assert!(!transformed.transformed);
+        assert!(Arc::ptr_eq(&expr, &transformed.data));
+    }
+
+    #[test]
+    fn test_bind_runtime_physical_expr_propagates_error() {
+        // A bind_runtime error from any node should fail the traversal.
+        let expr: Arc<dyn PhysicalExpr> = Arc::new(ErrorOnBindExpr);
+        let ctx = RuntimeBindContext {
+            target_key: "right",
+            bound_name: "right_bound",
+        };
+
+        let err = super::bind_runtime_physical_expr_opt(expr, &ctx).unwrap_err();
+        assert!(err.to_string().contains("forced bind_runtime error"));
+    }
 }