diff --git a/vortex-datafusion/src/memory/plans.rs b/vortex-datafusion/src/memory/plans.rs index 09d382ebac..8b45fadfbf 100644 --- a/vortex-datafusion/src/memory/plans.rs +++ b/vortex-datafusion/src/memory/plans.rs @@ -23,7 +23,7 @@ use vortex_array::compute::take; use vortex_array::{ArrayData, IntoArrayVariant, IntoCanonical}; use vortex_dtype::Field; use vortex_error::{vortex_err, vortex_panic, VortexError}; -use vortex_expr::ExprRef; +use vortex_expr::{ExprRef, VortexExprExt}; /// Physical plan operator that applies a set of [filters][Expr] against the input, producing a /// row mask that can be used downstream to force a take against the corresponding struct array diff --git a/vortex-expr/src/binary.rs b/vortex-expr/src/binary.rs index ffb45dc373..07fbc8998d 100644 --- a/vortex-expr/src/binary.rs +++ b/vortex-expr/src/binary.rs @@ -2,10 +2,8 @@ use std::any::Any; use std::fmt::Display; use std::sync::Arc; -use vortex_array::aliases::hash_set::HashSet; use vortex_array::compute::{and_kleene, compare, or_kleene, Operator as ArrayOperator}; use vortex_array::ArrayData; -use vortex_dtype::Field; use vortex_error::VortexResult; use crate::{ExprRef, Operator, VortexExpr}; @@ -62,9 +60,13 @@ impl VortexExpr for BinaryExpr { } } - fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) { - self.lhs.collect_references(references); - self.rhs.collect_references(references); + fn children(&self) -> Vec<&ExprRef> { + vec![&self.lhs, &self.rhs] + } + + fn replacing_children(self: Arc, children: Vec) -> ExprRef { + assert_eq!(children.len(), 2); + BinaryExpr::new_expr(children[0].clone(), self.operator, children[1].clone()) } } diff --git a/vortex-expr/src/column.rs b/vortex-expr/src/column.rs index 642dad09a0..cf29fe8ca4 100644 --- a/vortex-expr/src/column.rs +++ b/vortex-expr/src/column.rs @@ -2,7 +2,6 @@ use std::any::Any; use std::fmt::Display; use std::sync::Arc; -use vortex_array::aliases::hash_set::HashSet; use vortex_array::array::StructArray; use 
vortex_array::variants::StructArrayTrait; use vortex_array::ArrayData; @@ -17,8 +16,10 @@ pub struct Column { } impl Column { - pub fn new_expr(field: Field) -> ExprRef { - Arc::new(Self { field }) + pub fn new_expr(field: impl Into) -> ExprRef { + Arc::new(Self { + field: field.into(), + }) } pub fn field(&self) -> &Field { @@ -69,7 +70,12 @@ impl VortexExpr for Column { .ok_or_else(|| vortex_err!("Array doesn't contain child array {}", self.field)) } - fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) { - references.insert(self.field()); + fn children(&self) -> Vec<&ExprRef> { + vec![] + } + + fn replacing_children(self: Arc, children: Vec) -> ExprRef { + assert_eq!(children.len(), 0); + self } } diff --git a/vortex-expr/src/identity.rs b/vortex-expr/src/identity.rs index 7363a7a4d7..56b6a65dd0 100644 --- a/vortex-expr/src/identity.rs +++ b/vortex-expr/src/identity.rs @@ -30,6 +30,15 @@ impl VortexExpr for Identity { fn evaluate(&self, batch: &ArrayData) -> VortexResult { Ok(batch.clone()) } + + fn children(&self) -> Vec<&ExprRef> { + vec![] + } + + fn replacing_children(self: Arc, children: Vec) -> ExprRef { + assert_eq!(children.len(), 0); + self + } } // Return a global pointer to the identity token. 
diff --git a/vortex-expr/src/lib.rs b/vortex-expr/src/lib.rs index bbc03a7f1a..c4c5bf5d47 100644 --- a/vortex-expr/src/lib.rs +++ b/vortex-expr/src/lib.rs @@ -2,8 +2,6 @@ use std::any::Any; use std::fmt::{Debug, Display}; use std::sync::Arc; -use vortex_array::aliases::hash_set::HashSet; - mod binary; mod column; pub mod datafusion; @@ -16,6 +14,8 @@ mod project; pub mod pruning; mod row_filter; mod select; +#[allow(dead_code)] +mod traversal; pub use binary::*; pub use column::*; @@ -27,9 +27,12 @@ pub use operators::*; pub use project::*; pub use row_filter::*; pub use select::*; +use vortex_array::aliases::hash_set::HashSet; use vortex_array::ArrayData; use vortex_dtype::Field; -use vortex_error::VortexResult; +use vortex_error::{VortexResult, VortexUnwrap}; + +use crate::traversal::{Node, ReferenceCollector}; pub type ExprRef = Arc; @@ -41,14 +44,22 @@ pub trait VortexExpr: Debug + Send + Sync + DynEq + Display { /// Compute result of expression on given batch producing a new batch fn evaluate(&self, batch: &ArrayData) -> VortexResult; - /// Accumulate all field references from this expression and its children in the provided set - fn collect_references<'a>(&'a self, _references: &mut HashSet<&'a Field>) {} + fn children(&self) -> Vec<&ExprRef>; + + fn replacing_children(self: Arc, children: Vec) -> ExprRef; +} + +pub trait VortexExprExt { + /// Accumulate all field references from this expression and its children in a set + fn references(&self) -> HashSet<&Field>; +} - /// Accumulate all field references from this expression and its children in a new set +impl VortexExprExt for ExprRef { fn references(&self) -> HashSet<&Field> { - let mut refs = HashSet::new(); - self.collect_references(&mut refs); - refs + let mut collector = ReferenceCollector::new(); + // The collector is infallible, so we can unwrap the result + self.accept(&mut collector).vortex_unwrap(); + collector.into_fields() } } diff --git a/vortex-expr/src/like.rs b/vortex-expr/src/like.rs index 
c0159e3733..ce4c043528 100644 --- a/vortex-expr/src/like.rs +++ b/vortex-expr/src/like.rs @@ -2,10 +2,8 @@ use std::any::Any; use std::fmt::Display; use std::sync::Arc; -use vortex_array::aliases::hash_set::HashSet; use vortex_array::compute::{like, LikeOptions}; use vortex_array::ArrayData; -use vortex_dtype::Field; use vortex_error::VortexResult; use crate::{ExprRef, VortexExpr}; @@ -74,9 +72,18 @@ impl VortexExpr for Like { ) } - fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) { - self.child().collect_references(references); - self.pattern().collect_references(references); + fn children(&self) -> Vec<&ExprRef> { + vec![&self.child, &self.pattern] + } + + fn replacing_children(self: Arc, children: Vec) -> ExprRef { + assert_eq!(children.len(), 2); + Like::new_expr( + children[0].clone(), + children[1].clone(), + self.negated, + self.case_insensitive, + ) } } diff --git a/vortex-expr/src/literal.rs b/vortex-expr/src/literal.rs index 0430adb4a6..e9e7f78a1b 100644 --- a/vortex-expr/src/literal.rs +++ b/vortex-expr/src/literal.rs @@ -15,8 +15,10 @@ pub struct Literal { } impl Literal { - pub fn new_expr(value: Scalar) -> ExprRef { - Arc::new(Self { value }) + pub fn new_expr(value: impl Into) -> ExprRef { + Arc::new(Self { + value: value.into(), + }) } pub fn value(&self) -> &Scalar { @@ -38,6 +40,15 @@ impl VortexExpr for Literal { fn evaluate(&self, batch: &ArrayData) -> VortexResult { Ok(ConstantArray::new(self.value.clone(), batch.len()).into_array()) } + + fn children(&self) -> Vec<&ExprRef> { + vec![] + } + + fn replacing_children(self: Arc, children: Vec) -> ExprRef { + assert_eq!(children.len(), 0); + self + } } /// Create a new `Literal` expression from a type that coerces to `Scalar`. 
diff --git a/vortex-expr/src/not.rs b/vortex-expr/src/not.rs index d5fe7b91cb..4c2b63ff9f 100644 --- a/vortex-expr/src/not.rs +++ b/vortex-expr/src/not.rs @@ -2,10 +2,8 @@ use std::any::Any; use std::fmt::Display; use std::sync::Arc; -use vortex_array::aliases::hash_set::HashSet; use vortex_array::compute::invert; use vortex_array::ArrayData; -use vortex_dtype::Field; use vortex_error::VortexResult; use crate::{ExprRef, VortexExpr}; @@ -42,8 +40,13 @@ impl VortexExpr for Not { invert(&child_result) } - fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) { - self.child.collect_references(references) + fn children(&self) -> Vec<&ExprRef> { + vec![&self.child] + } + + fn replacing_children(self: Arc, mut children: Vec) -> ExprRef { + assert_eq!(children.len(), 1); + Self::new_expr(children.remove(0)) } } diff --git a/vortex-expr/src/project.rs b/vortex-expr/src/project.rs index f55b5e1e85..db11038284 100644 --- a/vortex-expr/src/project.rs +++ b/vortex-expr/src/project.rs @@ -5,7 +5,7 @@ use vortex_dtype::Field; use crate::{ col, lit, BinaryExpr, Column, ExprRef, Identity, Like, Literal, Not, Operator, RowFilter, - Select, VortexExpr, + Select, VortexExpr, VortexExprExt, }; /// Restrict expression to only the fields that appear in projection @@ -52,7 +52,7 @@ pub fn expr_project(expr: &ExprRef, projection: &[Field]) -> Option { } }) } else if let Some(n) = expr.as_any().downcast_ref::() { - let own_refs = n.references(); + let own_refs = expr.references(); if own_refs.iter().all(|p| projection.contains(p)) { expr_project(n.child(), projection).map(Not::new_expr) } else { diff --git a/vortex-expr/src/pruning.rs b/vortex-expr/src/pruning.rs index 913d26b719..9d3268b4c9 100644 --- a/vortex-expr/src/pruning.rs +++ b/vortex-expr/src/pruning.rs @@ -15,7 +15,7 @@ use vortex_scalar::Scalar; use crate::{ and, col, eq, gt, gt_eq, lit, lt_eq, or, BinaryExpr, Column, ExprRef, Identity, Literal, Not, - Operator, RowFilter, + Operator, RowFilter, VortexExprExt, 
}; #[derive(Debug, Clone)] diff --git a/vortex-expr/src/row_filter.rs b/vortex-expr/src/row_filter.rs index 72052a644f..0ceb5f9bf9 100644 --- a/vortex-expr/src/row_filter.rs +++ b/vortex-expr/src/row_filter.rs @@ -3,7 +3,6 @@ use std::fmt::{Debug, Display}; use std::sync::Arc; use itertools::Itertools; -use vortex_array::aliases::hash_set::HashSet; use vortex_array::array::ConstantArray; use vortex_array::compute::{and_kleene, fill_null}; use vortex_array::stats::ArrayStatistics; @@ -89,9 +88,12 @@ impl VortexExpr for RowFilter { fill_null(mask, false.into()) } - fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) { - for expr in self.conjunction.iter() { - expr.collect_references(references); - } + fn children(&self) -> Vec<&ExprRef> { + self.conjunction.iter().collect() + } + + fn replacing_children(self: Arc, children: Vec) -> ExprRef { + assert_eq!(self.conjunction.len(), children.len()); + Self::from_conjunction_expr(children) } } diff --git a/vortex-expr/src/select.rs b/vortex-expr/src/select.rs index b8f3039465..99af2e4711 100644 --- a/vortex-expr/src/select.rs +++ b/vortex-expr/src/select.rs @@ -8,7 +8,7 @@ use vortex_array::ArrayData; use vortex_dtype::Field; use vortex_error::{vortex_err, VortexResult}; -use crate::VortexExpr; +use crate::{ExprRef, VortexExpr}; #[derive(Debug, Clone, PartialEq, Eq)] pub enum Select { @@ -32,6 +32,13 @@ impl Select { pub fn exclude_expr(columns: Vec) -> Arc { Arc::new(Self::exclude(columns)) } + + pub fn fields(&self) -> &[Field] { + match self { + Select::Include(fields) => fields, + Select::Exclude(fields) => fields, + } + } } impl Display for Select { @@ -77,12 +84,13 @@ impl VortexExpr for Select { } } - fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) { - match self { - Select::Include(f) => references.extend(f.iter()), - // It's weird that we treat the references of exclusions and inclusions the same, we need to have a wrapper around Field in the return - 
Select::Exclude(e) => references.extend(e.iter()), - } + fn children(&self) -> Vec<&ExprRef> { + vec![] + } + + fn replacing_children(self: Arc, children: Vec) -> ExprRef { + assert_eq!(children.len(), 0); + self } } diff --git a/vortex-expr/src/traversal/mod.rs b/vortex-expr/src/traversal/mod.rs new file mode 100644 index 0000000000..083cc70913 --- /dev/null +++ b/vortex-expr/src/traversal/mod.rs @@ -0,0 +1,312 @@ +mod references; +mod visitor; + +use itertools::Itertools; +pub use references::ReferenceCollector; +use vortex_error::VortexResult; + +use crate::ExprRef; + +/// Defines a DataFusion-inspired traversal pattern for visiting nodes in a `Node`, +/// for now only `VortexExpr`. +/// +/// This traversal is a pre-order traversal. +/// The traversal is steered by the `TraversalOrder` values: +/// - `Skip`: Skip visiting the children of the current node. +/// - `Stop`: Stop visiting any more nodes in the traversal. +/// - `Continue`: Continue with the traversal as expected. + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TraversalOrder { + // In a top-down traversal, skip visiting the children of the current node. + // In the bottom-up phase of the traversal this does nothing (for now). + Skip, + + // Stop visiting any more nodes in the traversal. + Stop, + + // Continue with the traversal as expected. 
+ Continue, +} + +#[derive(Debug, Clone)] +pub struct TransformResult { + result: T, + order: TraversalOrder, + changed: bool, +} + +impl TransformResult { + pub fn yes(result: T) -> Self { + Self { + result, + order: TraversalOrder::Continue, + changed: true, + } + } + + pub fn no(result: T) -> Self { + Self { + result, + order: TraversalOrder::Continue, + changed: false, + } + } +} + +pub trait NodeVisitor<'a> { + type NodeTy: Node; + + fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult { + Ok(TraversalOrder::Continue) + } + + fn visit_up(&mut self, _node: &'a Self::NodeTy) -> VortexResult { + Ok(TraversalOrder::Continue) + } +} + +pub trait MutNodeVisitor { + type NodeTy: Node; + + fn visit_down(&mut self, _node: &Self::NodeTy) -> VortexResult { + Ok(TraversalOrder::Continue) + } + + fn visit_up(&mut self, _node: Self::NodeTy) -> VortexResult>; +} + +pub trait Node: Sized { + fn accept<'a, V: NodeVisitor<'a, NodeTy = Self>>( + &'a self, + _visitor: &mut V, + ) -> VortexResult; + + fn transform>( + self, + _visitor: &mut V, + ) -> VortexResult>; +} + +impl Node for ExprRef { + // A pre-order traversal. + fn accept<'a, V: NodeVisitor<'a, NodeTy = ExprRef>>( + &'a self, + visitor: &mut V, + ) -> VortexResult { + let mut ord = visitor.visit_down(self)?; + if ord == TraversalOrder::Stop { + return Ok(TraversalOrder::Stop); + } + if ord == TraversalOrder::Skip { + return Ok(TraversalOrder::Continue); + } + for child in self.children() { + if ord != TraversalOrder::Continue { + return Ok(ord); + } + ord = child.accept(visitor)?; + } + if ord == TraversalOrder::Stop { + return Ok(TraversalOrder::Stop); + } + visitor.visit_up(self) + } + + // A pre-order transform, with an option to ignore sub-trees (using visit_down). 
+ fn transform>( + self, + visitor: &mut V, + ) -> VortexResult> { + let mut ord = visitor.visit_down(&self)?; + if ord == TraversalOrder::Stop { + return Ok(TransformResult { + result: self, + order: TraversalOrder::Stop, + changed: false, + }); + } + let (children, ord, changed) = if ord == TraversalOrder::Continue { + let mut new_children = Vec::with_capacity(self.children().len()); + let mut changed = false; + for child in self.children() { + match ord { + TraversalOrder::Continue | TraversalOrder::Skip => { + let TransformResult { + result: new_child, + order: child_order, + changed: child_changed, + } = child.clone().transform(visitor)?; + new_children.push(new_child); + ord = child_order; + changed |= child_changed; + } + TraversalOrder::Stop => new_children.push(child.clone()), + } + } + (new_children, ord, changed) + } else { + ( + self.children().into_iter().cloned().collect_vec(), + ord, + false, + ) + }; + + let up = if ord == TraversalOrder::Continue { + visitor.visit_up(self)? + } else { + TransformResult::no(self) + }; + Ok(TransformResult { + result: up.result.replacing_children(children), + order: ord, + changed: changed || up.changed, // true only if a child or this node changed + }) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use vortex_array::aliases::hash_set::HashSet; + use vortex_dtype::Field; + use vortex_error::VortexResult; + + use crate::traversal::visitor::pre_order_visit_down; + use crate::traversal::{MutNodeVisitor, Node, NodeVisitor, TransformResult, TraversalOrder}; + use crate::{BinaryExpr, Column, ExprRef, Literal, Operator, VortexExpr, VortexExprExt}; + + #[derive(Default)] + pub struct ExprLitCollector<'a>(pub Vec<&'a ExprRef>); + + impl<'a> NodeVisitor<'a> for ExprLitCollector<'a> { + type NodeTy = ExprRef; + + fn visit_down(&mut self, node: &'a ExprRef) -> VortexResult { + if node.as_any().downcast_ref::().is_some() { + self.0.push(node) + } + Ok(TraversalOrder::Continue) + } + + fn visit_up(&mut self, _node: &'a ExprRef) -> VortexResult { + 
Ok(TraversalOrder::Continue) + } + } + + #[derive(Default)] + pub struct ExprColToLit(i32); + + impl MutNodeVisitor for ExprColToLit { + type NodeTy = ExprRef; + + fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult> { + let col = node.as_any().downcast_ref::(); + if col.is_some() { + let id = self.0; + self.0 += 1; + Ok(TransformResult::yes(Literal::new_expr(id))) + } else { + Ok(TransformResult::no(node)) + } + } + } + + #[test] + fn expr_deep_visitor_test() { + let col1: Arc = Column::new_expr("col1"); + let lit1 = Literal::new_expr(1); + let expr = BinaryExpr::new_expr(col1.clone(), Operator::Eq, lit1.clone()); + let lit2 = Literal::new_expr(2); + let expr = BinaryExpr::new_expr(expr, Operator::And, lit2); + let mut printer = ExprLitCollector::default(); + expr.accept(&mut printer).unwrap(); + assert_eq!(printer.0.len(), 2); + } + + #[test] + fn expr_deep_mut_visitor_test() { + let col1: Arc = Column::new_expr("col1"); + let col2: Arc = Column::new_expr("col2"); + let expr = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone()); + let lit2 = Literal::new_expr(2); + let expr = BinaryExpr::new_expr(expr, Operator::And, lit2); + let mut printer = ExprColToLit::default(); + let new = expr.transform(&mut printer).unwrap(); + assert!(new.changed); + + let expr = new.result; + + let mut printer = ExprLitCollector::default(); + expr.accept(&mut printer).unwrap(); + assert_eq!(printer.0.len(), 3); + } + + #[test] + fn expr_skip_test() { + let col1: Arc = Column::new_expr("col1"); + let col2: Arc = Column::new_expr("col2"); + let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone()); + let col3: Arc = Column::new_expr("col3"); + let col4: Arc = Column::new_expr("col4"); + let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone()); + let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2); + + let mut nodes = Vec::new(); + expr.accept(&mut pre_order_visit_down(|node: &ExprRef| { + if 
node.as_any().downcast_ref::().is_some() { + nodes.push(node) + } + if let Some(bin) = node.as_any().downcast_ref::() { + if bin.op() == Operator::Eq { + return Ok(TraversalOrder::Skip); + } + } + Ok(TraversalOrder::Continue) + })) + .unwrap(); + + assert_eq!( + nodes + .into_iter() + .map(|x| x.references()) + .fold(HashSet::new(), |acc, x| acc.union(&x).cloned().collect()), + HashSet::from_iter(vec![&Field::from("col3"), &Field::from("col4")]) + ); + } + + #[test] + fn expr_stop_test() { + let col1: Arc = Column::new_expr("col1"); + let col2: Arc = Column::new_expr("col2"); + let expr1 = BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone()); + let col3: Arc = Column::new_expr("col3"); + let col4: Arc = Column::new_expr("col4"); + let expr2 = BinaryExpr::new_expr(col3.clone(), Operator::NotEq, col4.clone()); + let expr = BinaryExpr::new_expr(expr1, Operator::And, expr2); + + let mut nodes = Vec::new(); + expr.accept(&mut pre_order_visit_down(|node: &ExprRef| { + if node.as_any().downcast_ref::().is_some() { + nodes.push(node) + } + if let Some(bin) = node.as_any().downcast_ref::() { + if bin.op() == Operator::Eq { + return Ok(TraversalOrder::Stop); + } + } + Ok(TraversalOrder::Continue) + })) + .unwrap(); + + assert_eq!( + nodes + .into_iter() + .map(|x| x.references()) + .fold(HashSet::new(), |acc, x| acc.union(&x).cloned().collect()), + HashSet::from_iter(vec![]) + ); + } +} diff --git a/vortex-expr/src/traversal/references.rs b/vortex-expr/src/traversal/references.rs new file mode 100644 index 0000000000..496ea8c99e --- /dev/null +++ b/vortex-expr/src/traversal/references.rs @@ -0,0 +1,40 @@ +use vortex_array::aliases::hash_set::HashSet; +use vortex_dtype::Field; +use vortex_error::VortexResult; + +use crate::traversal::{NodeVisitor, TraversalOrder}; +use crate::{Column, ExprRef, Select}; + +pub struct ReferenceCollector<'a> { + fields: HashSet<&'a Field>, +} + +impl<'a> ReferenceCollector<'a> { + pub fn new() -> Self { + Self { + fields: 
HashSet::new(), + } + } + + pub fn with_set(set: HashSet<&'a Field>) -> Self { + Self { fields: set } + } + + pub fn into_fields(self) -> HashSet<&'a Field> { + self.fields + } +} + +impl<'a> NodeVisitor<'a> for ReferenceCollector<'a> { + type NodeTy = ExprRef; + + fn visit_up(&mut self, node: &'a ExprRef) -> VortexResult { + if let Some(col) = node.as_any().downcast_ref::() { + self.fields.insert(col.field()); + } + if let Some(sel) = node.as_any().downcast_ref::