-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Replace distinct with aggregate for where-in/exists subquery #5430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,7 +24,7 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; | |
use crate::logical_plan::extension::UserDefinedLogicalNode; | ||
use crate::logical_plan::plan; | ||
use crate::utils::{ | ||
self, exprlist_to_fields, from_plan, grouping_set_expr_count, | ||
self, expand_wildcard, exprlist_to_fields, from_plan, grouping_set_expr_count, | ||
grouping_set_to_exprlist, | ||
}; | ||
use crate::{ | ||
|
@@ -1776,6 +1776,40 @@ impl Aggregate { | |
_ => plan_err!("Could not coerce into Aggregate!"), | ||
} | ||
} | ||
|
||
/// Reports whether this aggregate acts as a plain `DISTINCT`. | ||
/// | ||
/// Returns `Ok(false)` right away when there is any aggregate expression, or | ||
/// when the number of group-by expressions differs from the number of fields | ||
/// in `self.schema`. Otherwise the group-by expressions are compared, as a | ||
/// set, with the expressions that `expand_wildcard` produces for | ||
/// `self.schema`; the result is `true` only when every group-by expression | ||
/// is found in that set. | ||
/// | ||
/// Group-by expressions are de-duplicated before the final count is compared | ||
/// with `self.group_expr.len()`, so a repeated group-by expression makes this | ||
/// return `Ok(false)`. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Propagates any error returned by `expand_wildcard`. | ||
pub fn is_distinct(&self) -> datafusion_common::Result<bool> { | ||
let group_expr_size = self.group_expr.len(); | ||
// Cheap rejections first: any aggregate expression, or a group-by list whose | ||
// length differs from the schema's field count, cannot be a distinct. | ||
if !self.aggr_expr.is_empty() || group_expr_size != self.schema.fields().len() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wouldn't |
||
return Ok(false); | ||
} | ||
|
||
// NOTE(review): presumably one column expression per field of the output | ||
// schema -- confirm against `expand_wildcard`. | ||
let expected_group_exprs = expand_wildcard(&self.schema, self.input.as_ref())?; | ||
let expected_expr_set = expected_group_exprs.iter().collect::<HashSet<&Expr>>(); | ||
|
||
// Literals are allowed in group by. Each literal is replaced by a column | ||
// named after the literal's display string before the comparison | ||
// (presumably the name it gets in the schema -- TODO confirm). | ||
let mut group_expr_set: HashSet<&Expr> = HashSet::default(); | ||
let mut literal_expr_set: HashSet<Expr> = HashSet::default(); | ||
for group_expr in self.group_expr.iter() { | ||
if let Expr::Literal(_) = group_expr { | ||
let lit_column = Expr::Column(Column::from_name(format!("{group_expr}"))); | ||
literal_expr_set.insert(lit_column); | ||
} else { | ||
group_expr_set.insert(group_expr); | ||
} | ||
} | ||
|
||
// `group_expr_set` stores references, so the owned literal stand-ins live in | ||
// `literal_expr_set` and are borrowed from there. | ||
literal_expr_set.iter().for_each(|expr| { | ||
group_expr_set.insert(expr); | ||
}); | ||
|
||
// Every group-by expression must appear in the expected set; duplicates | ||
// collapse in the set and therefore fail the size comparison. | ||
Ok(group_expr_set | ||
.intersection(&expected_expr_set) | ||
.collect::<HashSet<_>>() | ||
.len() | ||
== group_expr_size) | ||
} | ||
} | ||
|
||
/// Sorts its input according to a list of sort expressions. | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -149,6 +149,14 @@ fn optimize_exists( | |||||
return Ok(None); | ||||||
} | ||||||
}, | ||||||
LogicalPlan::Aggregate(aggregate) if aggregate.is_distinct()? => { | ||||||
match aggregate.input.as_ref() { | ||||||
LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, | ||||||
_ => { | ||||||
return Ok(None); | ||||||
} | ||||||
} | ||||||
} | ||||||
LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, | ||||||
_ => { | ||||||
// Subquery currently only supports distinct or projection | ||||||
|
@@ -224,7 +232,7 @@ mod tests { | |||||
use crate::test::*; | ||||||
use datafusion_common::Result; | ||||||
use datafusion_expr::{ | ||||||
col, exists, lit, logical_plan::LogicalPlanBuilder, not_exists, | ||||||
col, count, exists, lit, logical_plan::LogicalPlanBuilder, not_exists, | ||||||
}; | ||||||
use std::ops::Add; | ||||||
|
||||||
|
@@ -670,4 +678,56 @@ mod tests { | |||||
|
||||||
assert_plan_eq(&plan, expected) | ||||||
} | ||||||
|
||||||
// An EXISTS subquery topped by an aggregate that carries a real aggregate | ||||||
// expression (COUNT) must be left alone: the expected plan below still holds | ||||||
// the Filter/Subquery nodes instead of a join. | ||||||
#[test] | ||||||
fn exists_subquery_non_distinct_aggregate() -> Result<()> { | ||||||
let table_scan = test_table_scan()?; | ||||||
let subquery_scan = test_table_scan_with_name("sq")?; | ||||||
|
||||||
// Correlated filter -> projection -> GROUP BY a, c with COUNT(a). | ||||||
let subquery = LogicalPlanBuilder::from(subquery_scan) | ||||||
.filter(col("sq.a").gt(col("test.b")))? | ||||||
.project(vec![col("sq.a"), col("sq.c")])? | ||||||
.aggregate(vec![col("a"), col("c")], vec![count(col("a"))])? | ||||||
.build()?; | ||||||
let plan = LogicalPlanBuilder::from(table_scan) | ||||||
.filter(exists(Arc::new(subquery)))? | ||||||
.project(vec![col("test.b")])? | ||||||
.build()?; | ||||||
|
||||||
// Should not be optimized to a join: the aggregate has an aggregate expression. | ||||||
let expected = "Projection: test.b [b:UInt32]\ | ||||||
\n Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\ | ||||||
\n Subquery: [a:UInt32, c:UInt32, COUNT(sq.a):Int64;N]\ | ||||||
\n Aggregate: groupBy=[[sq.a, sq.c]], aggr=[[COUNT(sq.a)]] [a:UInt32, c:UInt32, COUNT(sq.a):Int64;N]\ | ||||||
\n Projection: sq.a, sq.c [a:UInt32, c:UInt32]\ | ||||||
\n Filter: sq.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ | ||||||
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ | ||||||
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; | ||||||
|
||||||
assert_plan_eq(&plan, expected) | ||||||
} | ||||||
|
||||||
// An EXISTS subquery topped by a group-by-only aggregate (no aggregate | ||||||
// expressions) over a projection is expected to be rewritten into a | ||||||
// LeftSemi join, as the expected plan below shows. | ||||||
// NOTE(review): "aggragte" in the test name is a typo for "aggregate". | ||||||
#[test] | ||||||
fn exists_subquery_aggragte_distinct() -> Result<()> { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
let table_scan = test_table_scan()?; | ||||||
let subquery_scan = test_table_scan_with_name("sq")?; | ||||||
|
||||||
let subquery = LogicalPlanBuilder::from(subquery_scan) | ||||||
.filter(col("sq.a").gt(col("test.b")))? | ||||||
.project(vec![col("sq.a"), col("sq.c")])? | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test for:
|
||||||
.aggregate(vec![col("a"), col("c")], Vec::<Expr>::new())? | ||||||
.build()?; | ||||||
let plan = LogicalPlanBuilder::from(table_scan) | ||||||
.filter(exists(Arc::new(subquery)))? | ||||||
.project(vec![col("test.b")])? | ||||||
.build()?; | ||||||
|
||||||
// The aggregate and its projection are gone; only the correlated filter | ||||||
// survives, as the join condition. | ||||||
let expected = "Projection: test.b [b:UInt32]\ | ||||||
\n LeftSemi Join: Filter: sq.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ | ||||||
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ | ||||||
\n Projection: sq.a [a:UInt32]\ | ||||||
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; | ||||||
|
||||||
assert_plan_eq(&plan, expected) | ||||||
} | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not quite sure what this means -- is this check designed to check the input expressions to the GroupBy or the output expressions?
If the output expressions, here would be an alternate description
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this correct? Or is this check meant to cover the input expressions only?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also find the naming and the implementation of this method confusing. Maybe we should call it `is_group_by_only()`?
And, as alamb mentioned, `SELECT a ... GROUP BY a, a` should also be treated as a group-by-only aggregation.