Skip to content

Commit

Permalink
A data fusion inspired traversal API for expressions (#1828)
Browse files Browse the repository at this point in the history
This could later be used for other hierarchies
  • Loading branch information
joseph-isaacs authored Jan 7, 2025
1 parent 61290ad commit 6437b41
Show file tree
Hide file tree
Showing 16 changed files with 516 additions and 47 deletions.
2 changes: 1 addition & 1 deletion vortex-datafusion/src/memory/plans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use vortex_array::compute::take;
use vortex_array::{ArrayData, IntoArrayVariant, IntoCanonical};
use vortex_dtype::Field;
use vortex_error::{vortex_err, vortex_panic, VortexError};
use vortex_expr::ExprRef;
use vortex_expr::{ExprRef, VortexExprExt};

/// Physical plan operator that applies a set of [filters][Expr] against the input, producing a
/// row mask that can be used downstream to force a take against the corresponding struct array
Expand Down
12 changes: 7 additions & 5 deletions vortex-expr/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ use std::any::Any;
use std::fmt::Display;
use std::sync::Arc;

use vortex_array::aliases::hash_set::HashSet;
use vortex_array::compute::{and_kleene, compare, or_kleene, Operator as ArrayOperator};
use vortex_array::ArrayData;
use vortex_dtype::Field;
use vortex_error::VortexResult;

use crate::{ExprRef, Operator, VortexExpr};
Expand Down Expand Up @@ -62,9 +60,13 @@ impl VortexExpr for BinaryExpr {
}
}

fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) {
self.lhs.collect_references(references);
self.rhs.collect_references(references);
fn children(&self) -> Vec<&ExprRef> {
vec![&self.lhs, &self.rhs]
}

fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
assert_eq!(children.len(), 2);
BinaryExpr::new_expr(children[0].clone(), self.operator, children[1].clone())
}
}

Expand Down
16 changes: 11 additions & 5 deletions vortex-expr/src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::any::Any;
use std::fmt::Display;
use std::sync::Arc;

use vortex_array::aliases::hash_set::HashSet;
use vortex_array::array::StructArray;
use vortex_array::variants::StructArrayTrait;
use vortex_array::ArrayData;
Expand All @@ -17,8 +16,10 @@ pub struct Column {
}

impl Column {
pub fn new_expr(field: Field) -> ExprRef {
Arc::new(Self { field })
pub fn new_expr(field: impl Into<Field>) -> ExprRef {
Arc::new(Self {
field: field.into(),
})
}

pub fn field(&self) -> &Field {
Expand Down Expand Up @@ -69,7 +70,12 @@ impl VortexExpr for Column {
.ok_or_else(|| vortex_err!("Array doesn't contain child array {}", self.field))
}

fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) {
references.insert(self.field());
fn children(&self) -> Vec<&ExprRef> {
vec![]
}

fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
assert_eq!(children.len(), 0);
self
}
}
9 changes: 9 additions & 0 deletions vortex-expr/src/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ impl VortexExpr for Identity {
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
Ok(batch.clone())
}

fn children(&self) -> Vec<&ExprRef> {
vec![]
}

fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
assert_eq!(children.len(), 0);
self
}
}

// Return a global pointer to the identity token.
Expand Down
29 changes: 20 additions & 9 deletions vortex-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use std::any::Any;
use std::fmt::{Debug, Display};
use std::sync::Arc;

use vortex_array::aliases::hash_set::HashSet;

mod binary;
mod column;
pub mod datafusion;
Expand All @@ -16,6 +14,8 @@ mod project;
pub mod pruning;
mod row_filter;
mod select;
#[allow(dead_code)]
mod traversal;

pub use binary::*;
pub use column::*;
Expand All @@ -27,9 +27,12 @@ pub use operators::*;
pub use project::*;
pub use row_filter::*;
pub use select::*;
use vortex_array::aliases::hash_set::HashSet;
use vortex_array::ArrayData;
use vortex_dtype::Field;
use vortex_error::VortexResult;
use vortex_error::{VortexResult, VortexUnwrap};

use crate::traversal::{Node, ReferenceCollector};

pub type ExprRef = Arc<dyn VortexExpr>;

Expand All @@ -41,14 +44,22 @@ pub trait VortexExpr: Debug + Send + Sync + DynEq + Display {
/// Compute result of expression on given batch producing a new batch
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData>;

/// Accumulate all field references from this expression and its children in the provided set
fn collect_references<'a>(&'a self, _references: &mut HashSet<&'a Field>) {}
fn children(&self) -> Vec<&ExprRef>;

fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
}

pub trait VortexExprExt {
/// Accumulate all field references from this expression and its children in a set
fn references(&self) -> HashSet<&Field>;
}

/// Accumulate all field references from this expression and its children in a new set
impl VortexExprExt for ExprRef {
fn references(&self) -> HashSet<&Field> {
let mut refs = HashSet::new();
self.collect_references(&mut refs);
refs
let mut collector = ReferenceCollector::new();
// The collector is infallible, so we can unwrap the result
self.accept(&mut collector).vortex_unwrap();
collector.into_fields()
}
}

Expand Down
17 changes: 12 additions & 5 deletions vortex-expr/src/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ use std::any::Any;
use std::fmt::Display;
use std::sync::Arc;

use vortex_array::aliases::hash_set::HashSet;
use vortex_array::compute::{like, LikeOptions};
use vortex_array::ArrayData;
use vortex_dtype::Field;
use vortex_error::VortexResult;

use crate::{ExprRef, VortexExpr};
Expand Down Expand Up @@ -74,9 +72,18 @@ impl VortexExpr for Like {
)
}

fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) {
self.child().collect_references(references);
self.pattern().collect_references(references);
fn children(&self) -> Vec<&ExprRef> {
vec![&self.pattern, &self.child]
}

fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
assert_eq!(children.len(), 2);
Like::new_expr(
children[0].clone(),
children[1].clone(),
self.negated,
self.case_insensitive,
)
}
}

Expand Down
15 changes: 13 additions & 2 deletions vortex-expr/src/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ pub struct Literal {
}

impl Literal {
pub fn new_expr(value: Scalar) -> ExprRef {
Arc::new(Self { value })
pub fn new_expr(value: impl Into<Scalar>) -> ExprRef {
Arc::new(Self {
value: value.into(),
})
}

pub fn value(&self) -> &Scalar {
Expand All @@ -38,6 +40,15 @@ impl VortexExpr for Literal {
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
Ok(ConstantArray::new(self.value.clone(), batch.len()).into_array())
}

fn children(&self) -> Vec<&ExprRef> {
vec![]
}

fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
assert_eq!(children.len(), 0);
self
}
}

/// Create a new `Literal` expression from a type that coerces to `Scalar`.
Expand Down
11 changes: 7 additions & 4 deletions vortex-expr/src/not.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ use std::any::Any;
use std::fmt::Display;
use std::sync::Arc;

use vortex_array::aliases::hash_set::HashSet;
use vortex_array::compute::invert;
use vortex_array::ArrayData;
use vortex_dtype::Field;
use vortex_error::VortexResult;

use crate::{ExprRef, VortexExpr};
Expand Down Expand Up @@ -42,8 +40,13 @@ impl VortexExpr for Not {
invert(&child_result)
}

fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) {
self.child.collect_references(references)
fn children(&self) -> Vec<&ExprRef> {
vec![&self.child]
}

fn replacing_children(self: Arc<Self>, mut children: Vec<ExprRef>) -> ExprRef {
assert_eq!(children.len(), 0);
Self::new_expr(children.remove(0))
}
}

Expand Down
4 changes: 2 additions & 2 deletions vortex-expr/src/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use vortex_dtype::Field;

use crate::{
col, lit, BinaryExpr, Column, ExprRef, Identity, Like, Literal, Not, Operator, RowFilter,
Select, VortexExpr,
Select, VortexExpr, VortexExprExt,
};

/// Restrict expression to only the fields that appear in projection
Expand Down Expand Up @@ -52,7 +52,7 @@ pub fn expr_project(expr: &ExprRef, projection: &[Field]) -> Option<ExprRef> {
}
})
} else if let Some(n) = expr.as_any().downcast_ref::<Not>() {
let own_refs = n.references();
let own_refs = expr.references();
if own_refs.iter().all(|p| projection.contains(p)) {
expr_project(n.child(), projection).map(Not::new_expr)
} else {
Expand Down
2 changes: 1 addition & 1 deletion vortex-expr/src/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use vortex_scalar::Scalar;

use crate::{
and, col, eq, gt, gt_eq, lit, lt_eq, or, BinaryExpr, Column, ExprRef, Identity, Literal, Not,
Operator, RowFilter,
Operator, RowFilter, VortexExprExt,
};

#[derive(Debug, Clone)]
Expand Down
12 changes: 7 additions & 5 deletions vortex-expr/src/row_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::fmt::{Debug, Display};
use std::sync::Arc;

use itertools::Itertools;
use vortex_array::aliases::hash_set::HashSet;
use vortex_array::array::ConstantArray;
use vortex_array::compute::{and_kleene, fill_null};
use vortex_array::stats::ArrayStatistics;
Expand Down Expand Up @@ -89,9 +88,12 @@ impl VortexExpr for RowFilter {
fill_null(mask, false.into())
}

fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) {
for expr in self.conjunction.iter() {
expr.collect_references(references);
}
fn children(&self) -> Vec<&ExprRef> {
self.conjunction.iter().collect()
}

fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
assert_eq!(self.conjunction.len(), children.len());
Self::from_conjunction_expr(children)
}
}
22 changes: 15 additions & 7 deletions vortex-expr/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use vortex_array::ArrayData;
use vortex_dtype::Field;
use vortex_error::{vortex_err, VortexResult};

use crate::VortexExpr;
use crate::{ExprRef, VortexExpr};

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Select {
Expand All @@ -32,6 +32,13 @@ impl Select {
pub fn exclude_expr(columns: Vec<Field>) -> Arc<Self> {
Arc::new(Self::exclude(columns))
}

pub fn fields(&self) -> &[Field] {
match self {
Select::Include(fields) => fields,
Select::Exclude(fields) => fields,
}
}
}

impl Display for Select {
Expand Down Expand Up @@ -77,12 +84,13 @@ impl VortexExpr for Select {
}
}

fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) {
match self {
Select::Include(f) => references.extend(f.iter()),
// It's weird that we treat the references of exclusions and inclusions the same, we need to have a wrapper around Field in the return
Select::Exclude(e) => references.extend(e.iter()),
}
fn children(&self) -> Vec<&ExprRef> {
vec![]
}

fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
assert_eq!(children.len(), 0);
self
}
}

Expand Down
Loading

0 comments on commit 6437b41

Please sign in to comment.