Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 324 additions & 1 deletion datafusion/physical-expr-common/src/physical_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,28 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
0
}

/// Produce a runtime-bound replacement for this expression node, if it needs one.
///
/// Implementations may use the given `context` (e.g. a per-partition view) to
/// build a replacement for themselves. Returning `Ok(None)` signals that this
/// node does not need runtime binding.
///
/// Binding is a single pass over the existing tree. If the replacement returned
/// from this method itself contains additional bindable nodes, those newly
/// introduced nodes are not rebound in the same call.
///
/// Do not call this method directly: it only looks at the current node and does
/// not handle recursion. Use [`bind_runtime_physical_expr`] to bind the full
/// expression tree.
///
/// Note for implementers: this method should not recurse into children.
/// Recursion is handled in [`bind_runtime_physical_expr`].
fn bind_runtime(
    &self,
    _context: &(dyn Any + Send + Sync),
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
    // Default behaviour: nothing to bind for this expression.
    Ok(None)
}

/// Returns true if the expression node is volatile, i.e. whether it can return
/// different results when evaluated multiple times with the same input.
///
Expand Down Expand Up @@ -618,6 +640,46 @@ pub fn snapshot_physical_expr_opt(
})
}

/// Bind runtime-specific data into every node of the given `PhysicalExpr` tree.
///
/// See the documentation of [`PhysicalExpr::bind_runtime`] for more details.
///
/// Runtime binding is applied once over the current expression tree.
///
/// # Returns
///
/// The runtime-bound expression if any node required binding, otherwise the
/// original expression.
///
/// # Errors
///
/// Propagates any error returned by [`bind_runtime_physical_expr_opt`].
pub fn bind_runtime_physical_expr(
    expr: Arc<dyn PhysicalExpr>,
    context: &(dyn Any + Send + Sync),
) -> Result<Arc<dyn PhysicalExpr>> {
    // Callers of this variant only want the expression; the "was anything
    // bound" flag is discarded.
    let outcome = bind_runtime_physical_expr_opt(expr, context)?;
    Ok(outcome.data)
}

/// Bind runtime-specific data into the given `PhysicalExpr`, reporting whether
/// anything was bound.
///
/// See the documentation of [`PhysicalExpr::bind_runtime`] for more details.
///
/// Runtime binding is applied once over the current expression tree, using a
/// `transform_up` traversal that calls `bind_runtime` on each visited node.
///
/// # Returns
///
/// Returns a [`Transformed`] indicating whether any runtime binding happened,
/// along with the resulting expression.
///
/// # Errors
///
/// Returns the error produced by a node's `bind_runtime`, if any.
pub fn bind_runtime_physical_expr_opt(
    expr: Arc<dyn PhysicalExpr>,
    context: &(dyn Any + Send + Sync),
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
    expr.transform_up(|node| match node.bind_runtime(context)? {
        Some(bound) => Ok(Transformed::yes(bound)),
        // The closure owns `node`, so hand the same `Arc` back instead of
        // cloning it and immediately dropping the original.
        None => Ok(Transformed::no(node)),
    })
}

/// Check the generation of this `PhysicalExpr`.
/// Dynamic `PhysicalExpr`s may have a generation that is incremented
/// every time the state of the `PhysicalExpr` changes.
Expand Down Expand Up @@ -677,14 +739,15 @@ mod test {
use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch};
use arrow::datatypes::{DataType, Schema};
use datafusion_expr_common::columnar_value::ColumnarValue;
use std::any::Any;
use std::fmt::{Display, Formatter};
use std::sync::Arc;

/// Field-less expression type used as a minimal `PhysicalExpr` by the tests in
/// this module.
#[derive(Debug, PartialEq, Eq, Hash)]
struct TestExpr {}

impl PhysicalExpr for TestExpr {
fn as_any(&self) -> &dyn std::any::Any {
fn as_any(&self) -> &dyn Any {
self
}

Expand Down Expand Up @@ -726,6 +789,161 @@ mod test {
}
}

/// Test expression that can opt in to runtime binding.
///
/// Its `bind_runtime` implementation rewrites the node only when `bind_key`
/// equals the `target_key` of the supplied `RuntimeBindContext`.
#[derive(Debug, PartialEq, Eq, Hash)]
struct RuntimeBindableExpr {
// Label of this node; this is what `fmt_sql` (and therefore `Display`) writes.
name: &'static str,
// Selector used to decide if this node should bind for a given context.
// `None` means the node never binds.
bind_key: Option<&'static str>,
// Child expressions, in order.
children: Vec<Arc<dyn PhysicalExpr>>,
}

impl RuntimeBindableExpr {
fn new(
name: &'static str,
bind_key: Option<&'static str>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Self {
Self {
name,
bind_key,
children,
}
}
}

// `PhysicalExpr` implementation for the bindable test node.
impl PhysicalExpr for RuntimeBindableExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Always reports `Int64`; the schema is ignored.
    fn data_type(&self, _schema: &Schema) -> datafusion_common::Result<DataType> {
        Ok(DataType::Int64)
    }

    /// Always reports non-nullable; the schema is ignored.
    fn nullable(&self, _schema: &Schema) -> datafusion_common::Result<bool> {
        Ok(false)
    }

    /// Produces an `Int64` array of ones with one entry per row of `batch`.
    fn evaluate(
        &self,
        batch: &RecordBatch,
    ) -> datafusion_common::Result<ColumnarValue> {
        let ones: Vec<i64> = vec![1; batch.num_rows()];
        let column = Int64Array::from(ones);
        Ok(ColumnarValue::Array(Arc::new(column)))
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        self.children.iter().collect()
    }

    /// Rebuilds this node with the same label and selector but new children.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(Self::new(self.name, self.bind_key, children)))
    }

    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name)
    }

    /// Returns a bound copy of this node when it has a `bind_key`, the context
    /// is a `RuntimeBindContext`, and the context's `target_key` equals that
    /// key. In every other case the node is left alone (`Ok(None)`).
    fn bind_runtime(
        &self,
        context: &(dyn Any + Send + Sync),
    ) -> datafusion_common::Result<Option<Arc<dyn PhysicalExpr>>> {
        let runtime_ctx = context.downcast_ref::<RuntimeBindContext>();
        match (self.bind_key, runtime_ctx) {
            // Bind only when the selector in the context matches this node's key.
            (Some(key), Some(ctx)) if ctx.target_key == key => {
                // Simulate replacing a runtime placeholder with the bound payload.
                let bound: Arc<dyn PhysicalExpr> = Arc::new(Self {
                    name: ctx.bound_name,
                    bind_key: None,
                    children: self.children.clone(),
                });
                Ok(Some(bound))
            }
            // No key, a foreign context type, or a non-matching key.
            _ => Ok(None),
        }
    }
}

impl Display for RuntimeBindableExpr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.fmt_sql(f)
}
}

/// Childless test expression whose `bind_runtime` always returns an error.
#[derive(Debug, PartialEq, Eq, Hash)]
struct ErrorOnBindExpr;

// `PhysicalExpr` implementation for the always-failing test node.
impl PhysicalExpr for ErrorOnBindExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Always reports `Int64`; the schema is ignored.
    fn data_type(&self, _schema: &Schema) -> datafusion_common::Result<DataType> {
        Ok(DataType::Int64)
    }

    /// Always reports non-nullable; the schema is ignored.
    fn nullable(&self, _schema: &Schema) -> datafusion_common::Result<bool> {
        Ok(false)
    }

    /// Produces an `Int64` array of ones with one entry per row of `batch`.
    fn evaluate(
        &self,
        batch: &RecordBatch,
    ) -> datafusion_common::Result<ColumnarValue> {
        let ones: Vec<i64> = vec![1; batch.num_rows()];
        let column = Int64Array::from(ones);
        Ok(ColumnarValue::Array(Arc::new(column)))
    }

    /// This node has no children.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        Vec::new()
    }

    /// Ignores the supplied children and returns this node unchanged.
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
        Ok(self)
    }

    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("ErrorOnBindExpr")
    }

    /// Always fails with an internal error, whatever the context is.
    fn bind_runtime(
        &self,
        _context: &(dyn Any + Send + Sync),
    ) -> datafusion_common::Result<Option<Arc<dyn PhysicalExpr>>> {
        // Used to verify traversal propagates bind errors.
        let message = String::from("forced bind_runtime error");
        Err(datafusion_common::DataFusionError::Internal(message))
    }
}

impl Display for ErrorOnBindExpr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.fmt_sql(f)
}
}

/// Context value handed to `bind_runtime` by the tests in this module.
struct RuntimeBindContext {
// Which bindable nodes should be replaced for this call: a node binds only
// when its `bind_key` equals this value.
target_key: &'static str,
// Replacement name used by matching nodes.
bound_name: &'static str,
}

macro_rules! assert_arrays_eq {
($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => {
let expected = $EXPECTED.to_array(1).unwrap();
Expand Down Expand Up @@ -856,4 +1074,109 @@ mod test {
&BooleanArray::from(vec![true; 5]),
);
}

#[test]
fn test_bind_runtime_physical_expr_default_noop() {
    // `TestExpr` keeps the default `bind_runtime`, so the traversal must not
    // rewrite anything.
    let ctx = RuntimeBindContext {
        target_key: "right",
        bound_name: "bound",
    };
    let original: Arc<dyn PhysicalExpr> = Arc::new(TestExpr {});

    let outcome =
        super::bind_runtime_physical_expr_opt(Arc::clone(&original), &ctx).unwrap();

    // The very same allocation comes back, flagged as untransformed.
    assert!(!outcome.transformed);
    assert!(Arc::ptr_eq(&original, &outcome.data));
}

#[test]
fn test_bind_runtime_physical_expr_recurses() {
    // Tree: root(left, right). Only `right` carries the key targeted by the
    // context, so it is the only node that should be rewritten.
    let ctx = RuntimeBindContext {
        target_key: "right",
        bound_name: "right_bound",
    };
    let left_child: Arc<dyn PhysicalExpr> =
        Arc::new(RuntimeBindableExpr::new("left", Some("left"), vec![]));
    let right_child: Arc<dyn PhysicalExpr> =
        Arc::new(RuntimeBindableExpr::new("right", Some("right"), vec![]));
    let tree: Arc<dyn PhysicalExpr> = Arc::new(RuntimeBindableExpr::new(
        "root",
        None,
        vec![Arc::clone(&left_child), Arc::clone(&right_child)],
    ));

    let outcome = super::bind_runtime_physical_expr_opt(tree, &ctx).unwrap();
    assert!(outcome.transformed);

    // Downcast the rewritten tree back to the concrete test type for inspection.
    let bound_root = outcome
        .data
        .as_any()
        .downcast_ref::<RuntimeBindableExpr>()
        .expect("root should be RuntimeBindableExpr");
    let bound_left = bound_root.children[0]
        .as_any()
        .downcast_ref::<RuntimeBindableExpr>()
        .expect("left should be RuntimeBindableExpr");
    let bound_right = bound_root.children[1]
        .as_any()
        .downcast_ref::<RuntimeBindableExpr>()
        .expect("right should be RuntimeBindableExpr");

    // Left is untouched; right took the bound name and lost its selector.
    assert_eq!(bound_left.name, "left");
    assert_eq!(bound_right.name, "right_bound");
    assert_eq!(bound_right.bind_key, None);
}

#[test]
fn test_bind_runtime_physical_expr_returns_data() {
    // The non-`_opt` helper hands back the rewritten expression itself rather
    // than a `Transformed` wrapper.
    let ctx = RuntimeBindContext {
        target_key: "right",
        bound_name: "right_bound",
    };
    let input: Arc<dyn PhysicalExpr> =
        Arc::new(RuntimeBindableExpr::new("right", Some("right"), vec![]));

    let output = super::bind_runtime_physical_expr(input, &ctx).unwrap();
    let bound_node = output
        .as_any()
        .downcast_ref::<RuntimeBindableExpr>()
        .expect("bound should be RuntimeBindableExpr");

    assert_eq!(bound_node.name, "right_bound");
    assert_eq!(bound_node.bind_key, None);
}

#[test]
fn test_bind_runtime_physical_expr_context_mismatch_no_transform() {
    // The node is bindable, but its key ("left") differs from the key the
    // context targets ("right"), so nothing should be rewritten.
    let ctx = RuntimeBindContext {
        target_key: "right",
        bound_name: "right_bound",
    };
    let original: Arc<dyn PhysicalExpr> =
        Arc::new(RuntimeBindableExpr::new("left", Some("left"), vec![]));

    let outcome =
        super::bind_runtime_physical_expr_opt(Arc::clone(&original), &ctx).unwrap();

    // Same allocation back, flagged as untransformed.
    assert!(!outcome.transformed);
    assert!(Arc::ptr_eq(&original, &outcome.data));
}

#[test]
fn test_bind_runtime_physical_expr_propagates_error() {
    // An error raised by a node's `bind_runtime` must surface from the traversal.
    let ctx = RuntimeBindContext {
        target_key: "right",
        bound_name: "right_bound",
    };
    let failing: Arc<dyn PhysicalExpr> = Arc::new(ErrorOnBindExpr);

    let message = super::bind_runtime_physical_expr_opt(failing, &ctx)
        .unwrap_err()
        .to_string();
    assert!(message.contains("forced bind_runtime error"));
}
}
Loading