Skip to content

Replace distinct with aggregate for where-in/exists subquery #5430

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
use crate::logical_plan::extension::UserDefinedLogicalNode;
use crate::logical_plan::plan;
use crate::utils::{
self, exprlist_to_fields, from_plan, grouping_set_expr_count,
self, expand_wildcard, exprlist_to_fields, from_plan, grouping_set_expr_count,
grouping_set_to_exprlist,
};
use crate::{
Expand Down Expand Up @@ -1776,6 +1776,40 @@ impl Aggregate {
_ => plan_err!("Could not coerce into Aggregate!"),
}
}

/// Check whether it is a Distinct.
/// A Distinct means all fields of the schema are the expressions of group by.
Comment on lines +1780 to +1781
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not quite sure what this means -- is this check designed to check the input expressions to the GroupBy or the output expressions?

If the output expressions, here would be an alternate description

Suggested change
/// Check whether it is a Distinct.
/// A Distinct means all fields of the schema are the expressions of group by.
/// Return true if the output values are distinct (have no duplicates)
///
/// In order for this to return true, all fields of the output schema must be expressions of the group by

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this correct? Or is this check designed to check the input expressions only?

Copy link
Contributor

@mingmwang mingmwang Mar 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also find the naming and the implementation of this method confusing. Maybe we should call it is_group_by_only()?
And as alamb mentioned, SELECT a ... GROUP BY a, a should also be treated as a group-by-only aggregation.

pub fn is_distinct(&self) -> datafusion_common::Result<bool> {
let group_expr_size = self.group_expr.len();
if !self.aggr_expr.is_empty() || group_expr_size != self.schema.fields().len() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't SELECT a ... GROUP BY a, a be distinct even though the number of group exprs didn't match? Maybe this case isn't important to handle now

return Ok(false);
}

let expected_group_exprs = expand_wildcard(&self.schema, self.input.as_ref())?;
let expected_expr_set = expected_group_exprs.iter().collect::<HashSet<&Expr>>();

// Literals are allowed in group by.
let mut group_expr_set: HashSet<&Expr> = HashSet::default();
let mut literal_expr_set: HashSet<Expr> = HashSet::default();
for group_expr in self.group_expr.iter() {
if let Expr::Literal(_) = group_expr {
let lit_column = Expr::Column(Column::from_name(format!("{group_expr}")));
literal_expr_set.insert(lit_column);
} else {
group_expr_set.insert(group_expr);
}
}

literal_expr_set.iter().for_each(|expr| {
group_expr_set.insert(expr);
});

Ok(group_expr_set
.intersection(&expected_expr_set)
.collect::<HashSet<_>>()
.len()
== group_expr_size)
}
}

/// Sorts its input according to a list of sort expressions.
Expand Down
62 changes: 61 additions & 1 deletion datafusion/optimizer/src/decorrelate_where_exists.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ fn optimize_exists(
return Ok(None);
}
},
LogicalPlan::Aggregate(aggregate) if aggregate.is_distinct()? => {
match aggregate.input.as_ref() {
LogicalPlan::Projection(subqry_proj) => &subqry_proj.input,
_ => {
return Ok(None);
}
}
}
LogicalPlan::Projection(subqry_proj) => &subqry_proj.input,
_ => {
// Subquery currently only supports distinct or projection
Expand Down Expand Up @@ -224,7 +232,7 @@ mod tests {
use crate::test::*;
use datafusion_common::Result;
use datafusion_expr::{
col, exists, lit, logical_plan::LogicalPlanBuilder, not_exists,
col, count, exists, lit, logical_plan::LogicalPlanBuilder, not_exists,
};
use std::ops::Add;

Expand Down Expand Up @@ -670,4 +678,56 @@ mod tests {

assert_plan_eq(&plan, expected)
}

#[test]
fn exists_subquery_non_distinct_aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
let subquery_scan = test_table_scan_with_name("sq")?;

let subquery = LogicalPlanBuilder::from(subquery_scan)
.filter(col("sq.a").gt(col("test.b")))?
.project(vec![col("sq.a"), col("sq.c")])?
.aggregate(vec![col("a"), col("c")], vec![count(col("a"))])?
.build()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter(exists(Arc::new(subquery)))?
.project(vec![col("test.b")])?
.build()?;

// Should not be optimized to join.
let expected = "Projection: test.b [b:UInt32]\
\n Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
\n Subquery: [a:UInt32, c:UInt32, COUNT(sq.a):Int64;N]\
\n Aggregate: groupBy=[[sq.a, sq.c]], aggr=[[COUNT(sq.a)]] [a:UInt32, c:UInt32, COUNT(sq.a):Int64;N]\
\n Projection: sq.a, sq.c [a:UInt32, c:UInt32]\
\n Filter: sq.a > test.b [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_plan_eq(&plan, expected)
}

#[test]
fn exists_subquery_aggragte_distinct() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fn exists_subquery_aggragte_distinct() -> Result<()> {
fn exists_subquery_aggregate_distinct() -> Result<()> {

let table_scan = test_table_scan()?;
let subquery_scan = test_table_scan_with_name("sq")?;

let subquery = LogicalPlanBuilder::from(subquery_scan)
.filter(col("sq.a").gt(col("test.b")))?
.project(vec![col("sq.a"), col("sq.c")])?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for:

  1. When there is no projection in the subquery?
  2. When the projection in the subquery is an expression (like sq.a + sq.b)?

.aggregate(vec![col("a"), col("c")], Vec::<Expr>::new())?
.build()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter(exists(Arc::new(subquery)))?
.project(vec![col("test.b")])?
.build()?;

let expected = "Projection: test.b [b:UInt32]\
\n LeftSemi Join: Filter: sq.a > test.b [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n Projection: sq.a [a:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";

assert_plan_eq(&plan, expected)
}
}
Loading