From 0322f8c3c09d44dc9e16aea89c46235bba80dd49 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Mon, 27 Feb 2023 07:46:37 -0500 Subject: [PATCH] Replace distinct with aggregate for subquery --- datafusion/expr/src/logical_plan/plan.rs | 36 ++- .../optimizer/src/decorrelate_where_exists.rs | 62 ++++- .../src/replace_distinct_aggregate.rs | 241 +++++++++++++++++- 3 files changed, 333 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index c3ef861eb3b4..436ab2790d3d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -24,7 +24,7 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::plan; use crate::utils::{ - self, exprlist_to_fields, from_plan, grouping_set_expr_count, + self, expand_wildcard, exprlist_to_fields, from_plan, grouping_set_expr_count, grouping_set_to_exprlist, }; use crate::{ @@ -1776,6 +1776,40 @@ impl Aggregate { _ => plan_err!("Could not coerce into Aggregate!"), } } + + /// Check whether it is a Distinct. + /// A Distinct means all fields of the schema are the expressions of group by. + pub fn is_distinct(&self) -> datafusion_common::Result { + let group_expr_size = self.group_expr.len(); + if !self.aggr_expr.is_empty() || group_expr_size != self.schema.fields().len() { + return Ok(false); + } + + let expected_group_exprs = expand_wildcard(&self.schema, self.input.as_ref())?; + let expected_expr_set = expected_group_exprs.iter().collect::>(); + + // Literals are allowed in group by. + let mut group_expr_set: HashSet<&Expr> = HashSet::default(); + let mut literal_expr_set: HashSet = HashSet::default(); + for group_expr in self.group_expr.iter() { + if let Expr::Literal(_) = group_expr { + let lit_column = Expr::Column(Column::from_name(format!("{group_expr}"))); + literal_expr_set.insert(lit_column); + } else { + group_expr_set.insert(group_expr); + } + } + + literal_expr_set.iter().for_each(|expr| { + group_expr_set.insert(expr); + }); + + Ok(group_expr_set + .intersection(&expected_expr_set) + .collect::>() + .len() + == group_expr_size) + } } /// Sorts its input according to a list of sort expressions. diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index 72a68b3123b0..2bfb014b3a8f 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -149,6 +149,14 @@ fn optimize_exists( return Ok(None); } }, + LogicalPlan::Aggregate(aggregate) if aggregate.is_distinct()? => { + match aggregate.input.as_ref() { + LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, + _ => { + return Ok(None); + } + } + } LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, _ => { // Subquery currently only supports distinct or projection @@ -224,7 +232,7 @@ mod tests { use crate::test::*; use datafusion_common::Result; use datafusion_expr::{ - col, exists, lit, logical_plan::LogicalPlanBuilder, not_exists, + col, count, exists, lit, logical_plan::LogicalPlanBuilder, not_exists, }; use std::ops::Add; @@ -670,4 +678,56 @@ mod tests { assert_plan_eq(&plan, expected) } + + #[test] + fn exists_subquery_non_distinct_aggregate() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter(col("sq.a").gt(col("test.b")))? + .project(vec![col("sq.a"), col("sq.c")])? + .aggregate(vec![col("a"), col("c")], vec![count(col("a"))])? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + // Should not be optimized to join. + let expected = "Projection: test.b [b:UInt32]\ + \n Filter: EXISTS () [a:UInt32, b:UInt32, c:UInt32]\ + \n Subquery: [a:UInt32, c:UInt32, COUNT(sq.a):Int64;N]\ + \n Aggregate: groupBy=[[sq.a, sq.c]], aggr=[[COUNT(sq.a)]] [a:UInt32, c:UInt32, COUNT(sq.a):Int64;N]\ + \n Projection: sq.a, sq.c [a:UInt32, c:UInt32]\ + \n Filter: sq.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) + } + + #[test] + fn exists_subquery_aggragte_distinct() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter(col("sq.a").gt(col("test.b")))? + .project(vec![col("sq.a"), col("sq.c")])? + .aggregate(vec![col("a"), col("c")], Vec::::new())? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: sq.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) + } } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 2a604bc3ffea..b9f53eede96b 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -16,13 +16,15 @@ // under the License. use crate::optimizer::ApplyOrder; +use crate::utils::split_conjunction; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::Distinct; -use datafusion_expr::{Aggregate, LogicalPlan}; +use datafusion_expr::{Aggregate, Distinct, Expr, Filter, LogicalPlan, Subquery}; use ApplyOrder::BottomUp; +use std::sync::Arc; + /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// /// ```text @@ -43,6 +45,39 @@ impl ReplaceDistinctWithAggregate { pub fn new() -> Self { Self {} } + + fn optimize_expr( + &self, + expr: &Expr, + _config: &dyn OptimizerConfig, + ) -> Result> { + let optimized_expr = match expr { + Expr::Exists { subquery, negated } => self + .try_optimize(subquery.subquery.as_ref(), _config)? + .map(|plan| Expr::Exists { + subquery: Subquery { + subquery: Arc::new(plan), + }, + negated: *negated, + }), + Expr::InSubquery { + expr, + subquery, + negated, + } => self + .try_optimize(subquery.subquery.as_ref(), _config)? + .map(|plan| Expr::InSubquery { + expr: expr.clone(), + subquery: Subquery { + subquery: Arc::new(plan), + }, + negated: *negated, + }), + _ => None, + }; + + Ok(optimized_expr) + } } impl OptimizerRule for ReplaceDistinctWithAggregate { @@ -62,6 +97,44 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { )?); Ok(Some(aggregate)) } + LogicalPlan::Filter(filter) => { + let expr_list = split_conjunction(&filter.predicate); + if expr_list.is_empty() { + return Ok(None); + } + + let mut optimized_expr_list = split_conjunction(&filter.predicate) + .iter() + .map(|expr| self.optimize_expr(expr, _config)) + .collect::>>()?; + + assert!(!optimized_expr_list.is_empty()); + + if optimized_expr_list.iter().any(|expr| expr.is_some()) { + // Conjunction the optimized predicates + let filter_expr = std::mem::replace( + &mut optimized_expr_list[0], + Option::::None, + ) + .unwrap_or_else(|| expr_list[0].clone()); + + let new_filter = optimized_expr_list + .into_iter() + .zip(expr_list) + .skip(1) + .fold(filter_expr, |mut filter_expr, (optimized_expr, expr)| { + let next_expr = + optimized_expr.unwrap_or_else(|| expr.clone()); + filter_expr = filter_expr.and(next_expr); + filter_expr + }); + + let new_filter = Filter::try_new(new_filter, filter.input.clone())?; + Ok(Some(LogicalPlan::Filter(new_filter))) + } else { + Ok(None) + } + } _ => Ok(None), } } @@ -78,8 +151,8 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { #[cfg(test)] mod tests { use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; - use crate::test::{assert_optimized_plan_eq, test_table_scan}; - use datafusion_expr::{col, LogicalPlanBuilder}; + use crate::test::*; + use datafusion_expr::{col, exists, in_subquery, LogicalPlanBuilder}; use std::sync::Arc; #[test] @@ -100,4 +173,164 @@ mod tests { expected, ) } + + #[test] + fn replace_single_distinct_where_in() -> datafusion_common::Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + // distinct in where-in subquery + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter(col("test.a").eq(col("sq.a")))? + .project(vec![col("sq.b")])? + .distinct()? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery(col("test.c"), Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b\ + \n Filter: test.c IN ()\ + \n Subquery:\ + \n Aggregate: groupBy=[[sq.b]], aggr=[[]]\ + \n Projection: sq.b\ + \n Filter: test.a = sq.a\ + \n TableScan: sq\ + \n TableScan: test"; + + assert_optimized_plan_eq( + Arc::new(ReplaceDistinctWithAggregate::new()), + &plan, + expected, + ) + } + + #[test] + fn replace_distinct_in_where_in() -> datafusion_common::Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + // distinct in where-in subquery + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter(col("test.a").eq(col("sq.a")))? + .project(vec![col("sq.b"), col("sq.c")])? + .distinct()? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery(col("test.c"), Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b\ + \n Filter: test.c IN ()\ + \n Subquery:\ + \n Aggregate: groupBy=[[sq.b, sq.c]], aggr=[[]]\ + \n Projection: sq.b, sq.c\ + \n Filter: test.a = sq.a\ + \n TableScan: sq\ + \n TableScan: test"; + + assert_optimized_plan_eq( + Arc::new(ReplaceDistinctWithAggregate::new()), + &plan, + expected, + ) + } + + #[test] + fn replace_distinct_in_where_exists() -> datafusion_common::Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + // distinct in where-exists subquery + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter(col("test.a").eq(col("sq.a")))? + .project(vec![col("sq.b"), col("sq.c")])? + .distinct()? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b\ + \n Filter: EXISTS ()\ + \n Subquery:\ + \n Aggregate: groupBy=[[sq.b, sq.c]], aggr=[[]]\ + \n Projection: sq.b, sq.c\ + \n Filter: test.a = sq.a\ + \n TableScan: sq\ + \n TableScan: test"; + + assert_optimized_plan_eq( + Arc::new(ReplaceDistinctWithAggregate::new()), + &plan, + expected, + ) + } + + #[test] + fn replace_distinct_multi_predicate() -> datafusion_common::Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan1 = test_table_scan_with_name("sq1")?; + let subquery_scan2 = test_table_scan_with_name("sq2")?; + let subquery_scan3 = test_table_scan_with_name("sq3")?; + let subquery_scan4 = test_table_scan_with_name("sq4")?; + + let subquery1 = LogicalPlanBuilder::from(subquery_scan1) + .project(vec![col("sq1.a")])? + .build()?; + let subquery2 = LogicalPlanBuilder::from(subquery_scan2) + .project(vec![col("sq2.a")])? + .distinct()? + .build()?; + let subquery3 = LogicalPlanBuilder::from(subquery_scan3) + .project(vec![col("sq3.a"), col("sq3.b")])? + .build()?; + + let subquery4 = LogicalPlanBuilder::from(subquery_scan4) + .project(vec![col("sq4.a"), col("sq4.b")])? + .distinct()? + .build()?; + + // filter: `test.a` IN subquery1 and `test.a` IN subquery2 and EXISTS subquery3 and EXISTS subquery4 + // subquery2 and subquery4 have distinct operator, and subquery1 and subquery3 do not have. + let filter = in_subquery(col("test.a"), Arc::new(subquery1)) + .and(in_subquery(col("test.a"), Arc::new(subquery2))) + .and(exists(Arc::new(subquery3))) + .and(exists(Arc::new(subquery4))); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(filter)? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b\ + \n Filter: test.a IN () AND test.a IN () AND EXISTS () AND EXISTS ()\ + \n Subquery:\ + \n Projection: sq1.a\ + \n TableScan: sq1\ + \n Subquery:\ + \n Aggregate: groupBy=[[sq2.a]], aggr=[[]]\ + \n Projection: sq2.a\ + \n TableScan: sq2\ + \n Subquery:\ + \n Projection: sq3.a, sq3.b\ + \n TableScan: sq3\ + \n Subquery:\ + \n Aggregate: groupBy=[[sq4.a, sq4.b]], aggr=[[]]\ + \n Projection: sq4.a, sq4.b\ + \n TableScan: sq4\ + \n TableScan: test"; + + assert_optimized_plan_eq( + Arc::new(ReplaceDistinctWithAggregate::new()), + &plan, + expected, + ) + } }