diff --git a/optd-datafusion-repr/src/lib.rs b/optd-datafusion-repr/src/lib.rs index d2869af5..3e86bfde 100644 --- a/optd-datafusion-repr/src/lib.rs +++ b/optd-datafusion-repr/src/lib.rs @@ -22,12 +22,7 @@ use properties::{ schema::{Catalog, SchemaPropertyBuilder}, }; use rules::{ - EliminateDuplicatedAggExprRule, EliminateDuplicatedSortExprRule, EliminateFilterRule, - EliminateJoinRule, EliminateLimitRule, FilterAggTransposeRule, FilterCrossJoinTransposeRule, - FilterInnerJoinTransposeRule, FilterMergeRule, FilterProjectTransposeRule, - FilterSortTransposeRule, HashJoinRule, JoinAssocRule, JoinCommuteRule, PhysicalConversionRule, - ProjectFilterTransposeRule, ProjectMergeRule, ProjectionPullUpJoin, SimplifyFilterRule, - SimplifyJoinCondRule, + EliminateDuplicatedAggExprRule, EliminateDuplicatedSortExprRule, EliminateFilterRule, EliminateJoinRule, EliminateLimitRule, FilterAggTransposeRule, FilterCrossJoinTransposeRule, FilterInnerJoinTransposeRule, FilterMergeRule, FilterProjectTransposeRule, FilterSortTransposeRule, HashJoinRule, JoinAssocRule, JoinCommuteRule, PhysicalConversionRule, ProjectFilterTransposeRule, ProjectMergeRule, ProjectRemoveRule, ProjectionPullUpJoin, ProjectionPushDownJoin, SimplifyFilterRule, SimplifyJoinCondRule }; pub use optd_core::rel_node::Value; @@ -90,6 +85,7 @@ impl DatafusionOptimizer { Arc::new(EliminateDuplicatedAggExprRule::new()), Arc::new(ProjectMergeRule::new()), Arc::new(FilterMergeRule::new()), + Arc::new(ProjectRemoveRule::new()), ] } @@ -104,6 +100,9 @@ impl DatafusionOptimizer { rule_wrappers.push(RuleWrapper::new_cascades(Arc::new( ProjectFilterTransposeRule::new(), ))); + // rule_wrappers.push(RuleWrapper::new_cascades(Arc::new( + // ProjectionPushDownJoin::new(), + // ))); // add all filter pushdown rules as heuristic rules rule_wrappers.push(RuleWrapper::new_heuristic(Arc::new( FilterProjectTransposeRule::new(), diff --git a/optd-datafusion-repr/src/rules.rs b/optd-datafusion-repr/src/rules.rs index aad6a908..4e3aee48 100644 --- a/optd-datafusion-repr/src/rules.rs +++ b/optd-datafusion-repr/src/rules.rs @@ -22,6 +22,7 @@ pub use joins::{EliminateJoinRule, HashJoinRule, JoinAssocRule, JoinCommuteRule} pub use physical::PhysicalConversionRule; pub use project_transpose::{ project_filter_transpose::{FilterProjectTransposeRule, ProjectFilterTransposeRule}, - project_join_transpose::ProjectionPullUpJoin, + project_join_transpose::{ProjectionPullUpJoin, ProjectionPushDownJoin}, project_merge::ProjectMergeRule, + project_remove::ProjectRemoveRule, }; diff --git a/optd-datafusion-repr/src/rules/project_transpose.rs b/optd-datafusion-repr/src/rules/project_transpose.rs index 5c4f45bb..27c251c1 100644 --- a/optd-datafusion-repr/src/rules/project_transpose.rs +++ b/optd-datafusion-repr/src/rules/project_transpose.rs @@ -1,4 +1,5 @@ pub mod project_filter_transpose; pub mod project_join_transpose; pub mod project_merge; +pub mod project_remove; pub mod project_transpose_common; diff --git a/optd-datafusion-repr/src/rules/project_transpose/project_filter_transpose.rs b/optd-datafusion-repr/src/rules/project_transpose/project_filter_transpose.rs index 004ce0cf..58851ffa 100644 --- a/optd-datafusion-repr/src/rules/project_transpose/project_filter_transpose.rs +++ b/optd-datafusion-repr/src/rules/project_transpose/project_filter_transpose.rs @@ -4,18 +4,12 @@ use std::vec; use optd_core::rules::{Rule, RuleMatcher}; use optd_core::{optimizer::Optimizer, rel_node::RelNode}; -use super::project_transpose_common::ProjectionMapping; +use super::project_transpose_common::{ProjectionMapping, merge_exprs}; use crate::plan_nodes::{ Expr, ExprList, LogicalFilter, LogicalProjection, OptRelNode, OptRelNodeTyp, PlanNode, }; use crate::rules::macros::define_rule; -fn merge_exprs(first: ExprList, second: ExprList) -> ExprList { - let mut res_vec = first.to_vec(); - res_vec.extend(second.to_vec()); - ExprList::new(res_vec) -} - define_rule!( ProjectFilterTransposeRule, apply_projection_filter_transpose, diff --git a/optd-datafusion-repr/src/rules/project_transpose/project_join_transpose.rs b/optd-datafusion-repr/src/rules/project_transpose/project_join_transpose.rs index 52465292..55ee479f 100644 --- a/optd-datafusion-repr/src/rules/project_transpose/project_join_transpose.rs +++ b/optd-datafusion-repr/src/rules/project_transpose/project_join_transpose.rs @@ -9,14 +9,18 @@ use crate::rules::macros::define_rule; use optd_core::optimizer::Optimizer; use optd_core::rel_node::RelNode; -use super::project_transpose_common::ProjectionMapping; +use super::project_transpose_common::{ProjectionMapping, merge_exprs, split_exprs}; use crate::plan_nodes::{ - ColumnRefExpr, Expr, ExprList, JoinType, LogicalJoin, LogicalProjection, OptRelNode, - OptRelNodeTyp, PlanNode, + ColumnRefExpr, Expr, ExprList, JoinType, LogicalJoin, LogicalProjection, + OptRelNode, OptRelNodeTyp, PlanNode, }; use crate::properties::schema::SchemaPropertyBuilder; // (Proj A) join B -> (Proj (A join B)) +// TODO: rule currently doesn't work under current +// cascades + heuristic rule wrapper infrastructure +// correctness verified using unit tests in current file +// and via comparisons against CockroachDB & Calcite define_rule!( ProjectionPullUpJoin, apply_projection_pull_up_join, @@ -63,11 +67,402 @@ fn apply_projection_pull_up_join( mapping.rewrite_join_cond( Expr::from_rel_node(Arc::new(cond)).unwrap(), left_schema.len(), + false, + true, + 0 ), JoinType::Inner, ) .into_plan_node(), ExprList::new(new_projection_exprs), ); - vec![node.into_rel_node().as_ref().clone()] + vec![node.into_rel_node().as_ref().clone()] +} + +// most general: (Proj (A join B) -> Proj ((Proj A) join (Proj B)) +// ideal: (Proj (A join B) -> (Proj A) join (Proj B) +define_rule!( + ProjectionPushDownJoin, + apply_projection_push_down_join, + ( + Projection, + (Join(JoinType::Inner), left, right, [cond]), + [exprs] + ) +); + +fn apply_projection_push_down_join( + optimizer: &impl Optimizer, + ProjectionPushDownJoinPicks { + left, + right, + cond, + exprs, + }: ProjectionPushDownJoinPicks, +) -> Vec> { + let left = Arc::new(left.clone()); + let right = Arc::new(right.clone()); + + let exprs = ExprList::from_rel_node(Arc::new(exprs)).unwrap(); + let exprs_vec = exprs.clone().to_vec(); + let cond_as_expr = Expr::from_rel_node(cond.into()).unwrap(); + let cond_col_refs = cond_as_expr.get_column_refs(); + let mut dedup_cond_col_refs = Vec::new(); + + for col_ref in &cond_col_refs { + if !exprs_vec.contains(col_ref) { + dedup_cond_col_refs.push(col_ref.clone()); + }; + } + + let dedup_cond_col_refs = ExprList::new(dedup_cond_col_refs); + let tot_exprs = merge_exprs(exprs.clone(), dedup_cond_col_refs.clone()); + + // split exprs into exprs based on left + right children + let left_schema = optimizer.get_property::(left.clone(), 0); + let left_schema_len = left_schema.len(); + + let (left_exprs, right_exprs, is_left_right_ordered) = split_exprs(tot_exprs, left_schema_len); + let new_left_schema_len = left_exprs.len(); + + let Some(left_exprs_mapping) = ProjectionMapping::build(&left_exprs) else { + return vec![]; + }; + + let Some(right_exprs_mapping) = ProjectionMapping::build(&right_exprs) else { + return vec![]; + }; + + // update join cond based on new left + right child projection nodes + let new_join_cond: Expr = left_exprs_mapping.rewrite_join_cond(cond_as_expr.clone(), left_schema_len, true, true, 0); + let new_join_cond: Expr = right_exprs_mapping.rewrite_join_cond(new_join_cond.clone(), left_schema_len, true, false, new_left_schema_len); + + let new_left_child = LogicalProjection::new( + PlanNode::from_group(left), + left_exprs + ) + .into_plan_node(); + + let new_right_child = LogicalProjection::new( + PlanNode::from_group(right), + right_exprs + ) + .into_plan_node(); + + let new_join_node = LogicalJoin::new( + new_left_child, + new_right_child, + new_join_cond, + JoinType::Inner, + ) + .into_plan_node(); + + if dedup_cond_col_refs.is_empty() && is_left_right_ordered { + // don't need top projection node + return vec![new_join_node.into_rel_node().as_ref().clone()]; + } + + // update top projection node based on new left + right child projection nodes + let mut top_proj_exprs = vec![]; + let mut left_col_idx = 0; + let mut right_col_idx = new_left_schema_len; + for i in 0..exprs.len() { + let old_col_ref = ColumnRefExpr::from_rel_node(exprs_vec[i].clone().into_rel_node()).unwrap(); + if old_col_ref.index() < left_schema_len { + top_proj_exprs.push(ColumnRefExpr::new(left_col_idx).into_expr()); + left_col_idx += 1; + } else { + top_proj_exprs.push(ColumnRefExpr::new(right_col_idx).into_expr()); + right_col_idx += 1; + } + } + let top_proj_exprs = ExprList::new(top_proj_exprs); + + let new_top_node = LogicalProjection::new( + new_join_node, + top_proj_exprs, + ); + vec![new_top_node.into_rel_node().as_ref().clone()] +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use optd_core::optimizer::Optimizer; + + use crate::{ + plan_nodes::{ + BinOpExpr, BinOpType, ColumnRefExpr, ExprList, JoinType, LogicalJoin, LogicalProjection, LogicalScan, OptRelNode, OptRelNodeTyp + }, + rules::ProjectionPushDownJoin, + testing::new_test_optimizer, + }; + + #[test] + fn proj_pushdown_join_adv_1() { + // convert proj -> join -> 2xscan to join -> 2xproj -> 2xscan + let mut test_optimizer = new_test_optimizer(Arc::new(ProjectionPushDownJoin::new())); + + let left_scan = LogicalScan::new("region".into()); + let right_scan = LogicalScan::new("customer".into()); + + let join_cond = BinOpExpr::new( + ColumnRefExpr::new(2).into_expr(), + ColumnRefExpr::new(7).into_expr(), + BinOpType::Eq, + ) + .into_expr(); + + let proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(2).into_expr(), + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(4).into_expr(), + ColumnRefExpr::new(7).into_expr(), + ColumnRefExpr::new(3).into_expr(), + ]); + + let join_node = LogicalJoin::new(left_scan.into_plan_node(), right_scan.into_plan_node(), join_cond, JoinType::Inner); + let proj = LogicalProjection::new(join_node.into_plan_node(), proj_exprs); + + let plan = test_optimizer.optimize(proj.into_rel_node()).unwrap(); + + let left_proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(2).into_expr(), + ColumnRefExpr::new(0).into_expr(), + ]) + .into_rel_node(); + + let right_proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(1).into_expr(), + ColumnRefExpr::new(4).into_expr(), + ColumnRefExpr::new(0).into_expr(), + ]) + .into_rel_node(); + + let new_join_cond = BinOpExpr::new( + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(3).into_expr(), + BinOpType::Eq, + ) + .into_expr().into_rel_node(); + + assert_eq!(plan.typ, OptRelNodeTyp::Join(JoinType::Inner)); + assert_eq!(plan.child(2), new_join_cond); + + assert_eq!(plan.child(0).typ, OptRelNodeTyp::Projection); + assert_eq!(plan.child(1).typ, OptRelNodeTyp::Projection); + assert_eq!(plan.child(0).child(1), left_proj_exprs); + assert_eq!(plan.child(1).child(1), right_proj_exprs); + + assert_eq!(plan.child(0).child(0).typ, OptRelNodeTyp::Scan); + assert_eq!(plan.child(1).child(0).typ, OptRelNodeTyp::Scan); + } + + #[test] + fn proj_pushdown_join_adv_2() { + // convert proj -> join -> 2xscan to proj -> join -> 2xproj -> 2xscan (out of order left + right cols) + let mut test_optimizer = new_test_optimizer(Arc::new(ProjectionPushDownJoin::new())); + + let left_scan = LogicalScan::new("region".into()); + let right_scan = LogicalScan::new("customer".into()); + + let join_cond = BinOpExpr::new( + ColumnRefExpr::new(2).into_expr(), + ColumnRefExpr::new(7).into_expr(), + BinOpType::Eq, + ) + .into_expr(); + + let proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(2).into_expr(), + ColumnRefExpr::new(4).into_expr(), + ColumnRefExpr::new(7).into_expr(), + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(3).into_expr(), + ]); + + let join_node = LogicalJoin::new(left_scan.into_plan_node(), right_scan.into_plan_node(), join_cond, JoinType::Inner); + let proj = LogicalProjection::new(join_node.into_plan_node(), proj_exprs); + + let plan = test_optimizer.optimize(proj.into_rel_node()).unwrap(); + + let left_proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(2).into_expr(), + ColumnRefExpr::new(0).into_expr(), + ]) + .into_rel_node(); + + let right_proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(1).into_expr(), + ColumnRefExpr::new(4).into_expr(), + ColumnRefExpr::new(0).into_expr(), + ]) + .into_rel_node(); + + let new_join_cond = BinOpExpr::new( + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(3).into_expr(), + BinOpType::Eq, + ) + .into_expr().into_rel_node(); + + let top_proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(2).into_expr(), + ColumnRefExpr::new(3).into_expr(), + ColumnRefExpr::new(1).into_expr(), + ColumnRefExpr::new(4).into_expr(), + ]) + .into_rel_node(); + + assert_eq!(plan.typ, OptRelNodeTyp::Projection); + assert_eq!(plan.child(1), top_proj_exprs); + + assert_eq!(plan.child(0).typ, OptRelNodeTyp::Join(JoinType::Inner)); + assert_eq!(plan.child(0).child(2), new_join_cond); + + assert_eq!(plan.child(0).child(0).typ, OptRelNodeTyp::Projection); + assert_eq!(plan.child(0).child(1).typ, OptRelNodeTyp::Projection); + assert_eq!(plan.child(0).child(0).child(1), left_proj_exprs); + assert_eq!(plan.child(0).child(1).child(1), right_proj_exprs); + + assert_eq!(plan.child(0).child(0).child(0).typ, OptRelNodeTyp::Scan); + assert_eq!(plan.child(0).child(1).child(0).typ, OptRelNodeTyp::Scan); + } + + #[test] + fn proj_pushdown_join_adv_3() { + // convert proj -> join -> 2xscan to proj -> join -> 2xproj -> 2xscan (join cols not in proj cols) + let mut test_optimizer = new_test_optimizer(Arc::new(ProjectionPushDownJoin::new())); + + let left_scan = LogicalScan::new("region".into()); + let right_scan = LogicalScan::new("customer".into()); + + let join_cond = BinOpExpr::new( + ColumnRefExpr::new(2).into_expr(), + ColumnRefExpr::new(7).into_expr(), + BinOpType::Eq, + ) + .into_expr(); + + let proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(4).into_expr(), + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(3).into_expr(), + ]); + + let join_node = LogicalJoin::new(left_scan.into_plan_node(), right_scan.into_plan_node(), join_cond, JoinType::Inner); + let proj = LogicalProjection::new(join_node.into_plan_node(), proj_exprs); + + let plan = test_optimizer.optimize(proj.into_rel_node()).unwrap(); + + let left_proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(2).into_expr(), + ]) + .into_rel_node(); + + let right_proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(1).into_expr(), + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(4).into_expr(), + ]) + .into_rel_node(); + + let new_join_cond = BinOpExpr::new( + ColumnRefExpr::new(1).into_expr(), + ColumnRefExpr::new(4).into_expr(), + BinOpType::Eq, + ) + .into_expr().into_rel_node(); + + let top_proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(2).into_expr(), + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(3).into_expr(), + ]) + .into_rel_node(); + + assert_eq!(plan.typ, OptRelNodeTyp::Projection); + assert_eq!(plan.child(1), top_proj_exprs); + + assert_eq!(plan.child(0).typ, OptRelNodeTyp::Join(JoinType::Inner)); + assert_eq!(plan.child(0).child(2), new_join_cond); + + assert_eq!(plan.child(0).child(0).typ, OptRelNodeTyp::Projection); + assert_eq!(plan.child(0).child(1).typ, OptRelNodeTyp::Projection); + assert_eq!(plan.child(0).child(0).child(1), left_proj_exprs); + assert_eq!(plan.child(0).child(1).child(1), right_proj_exprs); + + assert_eq!(plan.child(0).child(0).child(0).typ, OptRelNodeTyp::Scan); + assert_eq!(plan.child(0).child(1).child(0).typ, OptRelNodeTyp::Scan); + } + + #[test] + fn proj_pushdown_join_adv_4() { + // convert proj -> join -> 2xscan to proj -> join -> 2xproj -> 2xscan (partial join cols not in proj cols) + let mut test_optimizer = new_test_optimizer(Arc::new(ProjectionPushDownJoin::new())); + + let left_scan = LogicalScan::new("region".into()); + let right_scan = LogicalScan::new("customer".into()); + + let join_cond = BinOpExpr::new( + ColumnRefExpr::new(3).into_expr(), + ColumnRefExpr::new(1).into_expr(), + BinOpType::Eq, + ) + .into_expr(); + + let proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(4).into_expr(), + ColumnRefExpr::new(3).into_expr(), + ]); + + let join_node = LogicalJoin::new(left_scan.into_plan_node(), right_scan.into_plan_node(), join_cond, JoinType::Inner); + let proj = LogicalProjection::new(join_node.into_plan_node(), proj_exprs); + + let plan = test_optimizer.optimize(proj.into_rel_node()).unwrap(); + + let left_proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(1).into_expr(), + ]) + .into_rel_node(); + + let right_proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(1).into_expr(), + ColumnRefExpr::new(0).into_expr(), + ]) + .into_rel_node(); + + let new_join_cond = BinOpExpr::new( + ColumnRefExpr::new(3).into_expr(), + ColumnRefExpr::new(1).into_expr(), + BinOpType::Eq, + ) + .into_expr().into_rel_node(); + + let top_proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(2).into_expr(), + ColumnRefExpr::new(3).into_expr(), + ]) + .into_rel_node(); + + assert_eq!(plan.typ, OptRelNodeTyp::Projection); + assert_eq!(plan.child(1), top_proj_exprs); + + assert_eq!(plan.child(0).typ, OptRelNodeTyp::Join(JoinType::Inner)); + assert_eq!(plan.child(0).child(2), new_join_cond); + + assert_eq!(plan.child(0).child(0).typ, OptRelNodeTyp::Projection); + assert_eq!(plan.child(0).child(1).typ, OptRelNodeTyp::Projection); + assert_eq!(plan.child(0).child(0).child(1), left_proj_exprs); + assert_eq!(plan.child(0).child(1).child(1), right_proj_exprs); + + assert_eq!(plan.child(0).child(0).child(0).typ, OptRelNodeTyp::Scan); + assert_eq!(plan.child(0).child(1).child(0).typ, OptRelNodeTyp::Scan); + } } diff --git a/optd-datafusion-repr/src/rules/project_transpose/project_remove.rs b/optd-datafusion-repr/src/rules/project_transpose/project_remove.rs index 5490307d..16231ff9 100644 --- a/optd-datafusion-repr/src/rules/project_transpose/project_remove.rs +++ b/optd-datafusion-repr/src/rules/project_transpose/project_remove.rs @@ -1,2 +1,83 @@ // intended to remove a projection that outputs the same num of cols -// that are in scan node \ No newline at end of file +// that are in scan node +use std::collections::HashMap; + +use optd_core::rules::{Rule, RuleMatcher}; +use optd_core::{optimizer::Optimizer, rel_node::RelNode}; + +use crate::plan_nodes::{ColumnRefExpr, ExprList, OptRelNode, OptRelNodeTyp, PlanNode}; +use crate::properties::schema::SchemaPropertyBuilder; +use crate::rules::macros::define_rule; + +// Proj (Scan A) -> Scan A +// removes projections +// TODO: need to somehow match on just scan node instead +// only works in hueristic optimizer (which may be ok) +// ideally include a pass after for physical proj -> physical scan +define_rule!( + ProjectRemoveRule, + apply_projection_remove, + (Projection, child, [exprs]) +); + +fn apply_projection_remove( + optimizer: &impl Optimizer, + ProjectRemoveRulePicks { + child, + exprs + }: ProjectRemoveRulePicks, +) -> Vec> { + let child_schema = optimizer.get_property::(child.clone().into(), 0); + let child = PlanNode::from_group(child.into()); + if child.typ() != OptRelNodeTyp::Scan { + return vec![]; + } + let exprs = ExprList::from_rel_node(exprs.into()).unwrap().to_vec(); + if exprs.len() != child_schema.len() { + return vec![]; + } + let mut exp_col_idx: usize = 0; + for expr in exprs { + let col_ref = ColumnRefExpr::from_rel_node(expr.into_rel_node()).unwrap(); + let col_idx = col_ref.index(); + if exp_col_idx != col_idx { + return vec![]; + } + exp_col_idx += 1; + } + vec![child.into_rel_node().as_ref().clone()] +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use optd_core::optimizer::Optimizer; + + use crate::{ + plan_nodes::{ + ColumnRefExpr, ExprList, LogicalProjection, LogicalScan, OptRelNode, OptRelNodeTyp, + }, + rules::ProjectRemoveRule, + testing::new_test_optimizer, + }; + + #[test] + fn proj_scan_basic() { + // convert proj -> scan to scan + let mut test_optimizer = new_test_optimizer(Arc::new(ProjectRemoveRule::new())); + + let scan = LogicalScan::new("region".into()); + + let proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(1).into_expr(), + ColumnRefExpr::new(2).into_expr(), + ]); + + let proj_node: LogicalProjection = LogicalProjection::new(scan.into_plan_node(), proj_exprs); + let plan = test_optimizer.optimize(proj_node.into_rel_node()).unwrap(); + + assert_eq!(plan.typ, OptRelNodeTyp::Scan); + } +} diff --git a/optd-datafusion-repr/src/rules/project_transpose/project_transpose_common.rs b/optd-datafusion-repr/src/rules/project_transpose/project_transpose_common.rs index 8baca0f0..a5167f3c 100644 --- a/optd-datafusion-repr/src/rules/project_transpose/project_transpose_common.rs +++ b/optd-datafusion-repr/src/rules/project_transpose/project_transpose_common.rs @@ -1,5 +1,34 @@ use crate::plan_nodes::{ColumnRefExpr, Expr, ExprList, OptRelNode}; +pub fn merge_exprs(first: ExprList, second: ExprList) -> ExprList { + let mut res_vec = first.to_vec(); + res_vec.extend(second.to_vec()); + ExprList::new(res_vec) +} + +pub fn split_exprs(exprs: ExprList, left_schema_len: usize) -> (ExprList, ExprList, bool) { + let mut left_vec = vec![]; + let mut right_vec = vec![]; + let mut reached_right = false; + let mut is_left_right_ordered = true; + for expr in exprs.to_vec() { + let col_ref = ColumnRefExpr::from_rel_node(expr.into_rel_node()).unwrap(); + if col_ref.index() < left_schema_len { + // left expr + left_vec.push(col_ref.into_expr()); + if reached_right { + is_left_right_ordered = false; + } + } else { + // right expr + let right_col_ref = ColumnRefExpr::new(col_ref.index() - left_schema_len); + right_vec.push(right_col_ref.into_expr()); + reached_right = true; + } + } + (ExprList::new(left_vec), ExprList::new(right_vec), is_left_right_ordered) +} + /// This struct holds the mapping from original columns to projected columns. /// /// # Example @@ -45,23 +74,40 @@ impl ProjectionMapping { } /// Remaps all column refs in the join condition based on a - /// removed bottom projection node + /// removed bottom projection node on the left child /// /// removed node: /// Join { cond: #0=#5 } /// Projection { exprs: [#1, #0, #3, #5, #4] } --> has mapping + /// Scan + /// Scan /// ----> /// Join { cond: #1=#4 } - pub fn rewrite_join_cond(&self, cond: Expr, child_schema_len: usize) -> Expr { - let schema_size = self.forward.len(); - cond.rewrite_column_refs(&|col_idx| { - if col_idx < schema_size { - self.projection_col_maps_to(col_idx) - } else { - Some(col_idx - schema_size + child_schema_len) - } - }) - .unwrap() + /// Scan + /// Scan + pub fn rewrite_join_cond(&self, cond: Expr, left_schema_len: usize, is_added: bool, is_left_child: bool, new_left_schema_len: usize) -> Expr { + if is_added { + cond.rewrite_column_refs(&|col_idx| { + if is_left_child && col_idx < left_schema_len { + self.original_col_maps_to(col_idx) + } else if !is_left_child && col_idx >= left_schema_len { + Some(self.original_col_maps_to(col_idx - left_schema_len).unwrap() + new_left_schema_len) + } else { + Some(col_idx) + } + }) + .unwrap() + } else { + let schema_size = self.forward.len(); + cond.rewrite_column_refs(&|col_idx| { + if col_idx < schema_size { + self.projection_col_maps_to(col_idx) + } else { + Some(col_idx - schema_size + left_schema_len) + } + }) + .unwrap() + } } /// Remaps all column refs in the filter condition based on an added or diff --git a/optd-sqlplannertest/tests/basic_nodes.planner.sql b/optd-sqlplannertest/tests/basic_nodes.planner.sql index 301f300e..d9d880c5 100644 --- a/optd-sqlplannertest/tests/basic_nodes.planner.sql +++ b/optd-sqlplannertest/tests/basic_nodes.planner.sql @@ -19,8 +19,7 @@ LogicalLimit { skip: 0(u64), fetch: 1(u64) } └── LogicalProjection { exprs: [ #0, #1 ] } └── LogicalScan { table: t1 } PhysicalLimit { skip: 0(u64), fetch: 1(u64) } -└── PhysicalProjection { exprs: [ #0, #1 ] } - └── PhysicalScan { table: t1 } +└── PhysicalScan { table: t1 } 0 0 0 0 1 1 diff --git a/optd-sqlplannertest/tests/eliminate_duplicated_expr.planner.sql b/optd-sqlplannertest/tests/eliminate_duplicated_expr.planner.sql index b31e774f..8c595a05 100644 --- a/optd-sqlplannertest/tests/eliminate_duplicated_expr.planner.sql +++ b/optd-sqlplannertest/tests/eliminate_duplicated_expr.planner.sql @@ -12,8 +12,7 @@ select * from t1; /* LogicalProjection { exprs: [ #0, #1 ] } └── LogicalScan { table: t1 } -PhysicalProjection { exprs: [ #0, #1 ] } -└── PhysicalScan { table: t1 } +PhysicalScan { table: t1 } 0 0 1 1 5 2 @@ -45,8 +44,7 @@ PhysicalSort │ │ └── #0 │ └── SortOrder { order: Asc } │ └── #1 -└── PhysicalProjection { exprs: [ #0, #1 ] } - └── PhysicalScan { table: t1 } +└── PhysicalScan { table: t1 } 0 0 0 2 1 1 diff --git a/optd-sqlplannertest/tests/filter.planner.sql b/optd-sqlplannertest/tests/filter.planner.sql index 8ba252fa..f2ad8484 100644 --- a/optd-sqlplannertest/tests/filter.planner.sql +++ b/optd-sqlplannertest/tests/filter.planner.sql @@ -27,8 +27,7 @@ select * from t1 where true; LogicalProjection { exprs: [ #0, #1 ] } └── LogicalFilter { cond: true } └── LogicalScan { table: t1 } -PhysicalProjection { exprs: [ #0, #1 ] } -└── PhysicalScan { table: t1 } +PhysicalScan { table: t1 } 0 0 1 1 2 2 diff --git a/optd-sqlplannertest/tests/tpch.planner.sql b/optd-sqlplannertest/tests/tpch.planner.sql index 8bf88051..710e4a9a 100644 --- a/optd-sqlplannertest/tests/tpch.planner.sql +++ b/optd-sqlplannertest/tests/tpch.planner.sql @@ -384,8 +384,7 @@ PhysicalLimit { skip: 0(u64), fetch: 100(u64) } │ │ │ │ │ └── PhysicalScan { table: part } │ │ │ │ └── PhysicalProjection { exprs: [ #0, #1, #3 ] } │ │ │ │ └── PhysicalScan { table: partsupp } - │ │ │ └── PhysicalProjection { exprs: [ #0, #1, #2, #3, #4, #5, #6 ] } - │ │ │ └── PhysicalScan { table: supplier } + │ │ │ └── PhysicalScan { table: supplier } │ │ └── PhysicalProjection { exprs: [ #0, #1, #2 ] } │ │ └── PhysicalScan { table: nation } │ └── PhysicalProjection { exprs: [ #0 ] } diff --git a/optd-sqlplannertest/tests/verbose.planner.sql b/optd-sqlplannertest/tests/verbose.planner.sql index 910a1663..391306be 100644 --- a/optd-sqlplannertest/tests/verbose.planner.sql +++ b/optd-sqlplannertest/tests/verbose.planner.sql @@ -10,16 +10,14 @@ insert into t1 values (0), (1), (2), (3); select * from t1; /* -PhysicalProjection { exprs: [ #0 ] } -└── PhysicalScan { table: t1 } +PhysicalScan { table: t1 } */ -- Test verbose explain select * from t1; /* -PhysicalProjection { exprs: [ #0 ], cost: weighted=1.06,row_cnt=1.00,compute=0.06,io=1.00 } -└── PhysicalScan { table: t1, cost: weighted=1.00,row_cnt=1.00,compute=0.00,io=1.00 } +PhysicalScan { table: t1, cost: weighted=1.00,row_cnt=1.00,compute=0.00,io=1.00 } */ -- Test verbose explain with aggregation