fix(optimizer): can't apply pull_up_correlated_predicate_agg_rule with non-null-propagating expr #20012

Merged
merged 6 commits on Jan 3, 2025
60 changes: 60 additions & 0 deletions e2e_test/batch/aggregate/issue_19835.slt.part
@@ -0,0 +1,60 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
CREATE TABLE foo (
id int
);

statement ok
CREATE TABLE bar (
id int,
foo_id int
);

statement ok
insert into foo values (1),(2);

query I rowsort
select
foo.id, array(
select
1
from bar
where bar.foo_id = foo.id
)
from foo;
----
1 {}
2 {}


query I rowsort
select foo.id, (select count(id) from bar where bar.foo_id = foo.id) from foo;
----
1 0
2 0

query I rowsort
select foo.id, (select avg(id) from bar where bar.foo_id = foo.id) from foo;
----
1 NULL
2 NULL

query I rowsort
select foo.id, (select max(id) from bar where bar.foo_id = foo.id) from foo;
----
1 NULL
2 NULL

query I rowsort
select foo.id, (select coalesce( max(id), 114514) from bar where bar.foo_id = foo.id) from foo;
----
1 114514
2 114514

statement ok
drop table foo;

statement ok
drop table bar;
@@ -495,6 +495,7 @@
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- count is not null-propagating
WHERE T.A > (SELECT COUNT(*) FROM T2 WHERE B = D);
expected_outputs:
- batch_plan
@@ -504,10 +505,21 @@
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- avg is null-propagating
WHERE T.A > (SELECT avg(c) FROM T2 WHERE B = D);
expected_outputs:
- batch_plan
- stream_plan
- name: a case can't be optimized by PullUpCorrelatedPredicateAggRule
sql: |
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- Coalesce is not null-propagating
WHERE T.A > (SELECT coalesce(avg(c), 114514) FROM T2 WHERE B = D);
expected_outputs:
- batch_plan
- stream_plan
- name: improve multi scalar subqueries optimization time. issue 16952. case 1.
sql: |
create table t1(a int, b int);
@@ -992,23 +992,36 @@
select Array(select c from t2 where b = d) arr from t1;
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchHashJoin { type: LeftOuter, predicate: t1.b = t2.d, output: [$expr1] }
└─BatchHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t1.b, output: [$expr1] }
├─BatchExchange { order: [], dist: HashShard(t1.b) }
│ └─BatchScan { table: t1, columns: [t1.b], distribution: SomeShard }
└─BatchProject { exprs: [Coalesce(array_agg(t2.c), ARRAY[]:List(Int32)) as $expr1, t2.d] }
└─BatchHashAgg { group_key: [t2.d], aggs: [array_agg(t2.c)] }
└─BatchExchange { order: [], dist: HashShard(t2.d) }
└─BatchScan { table: t2, columns: [t2.c, t2.d], distribution: SomeShard }
└─BatchProject { exprs: [t1.b, Coalesce(array_agg(t2.c) filter(IsNotNull(1:Int32)), ARRAY[]:List(Int32)) as $expr1] }
└─BatchHashAgg { group_key: [t1.b], aggs: [array_agg(t2.c) filter(IsNotNull(1:Int32))] }
└─BatchHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t2.d, output: [t1.b, t2.c, 1:Int32] }
├─BatchHashAgg { group_key: [t1.b], aggs: [] }
│ └─BatchExchange { order: [], dist: HashShard(t1.b) }
│ └─BatchScan { table: t1, columns: [t1.b], distribution: SomeShard }
└─BatchExchange { order: [], dist: HashShard(t2.d) }
└─BatchProject { exprs: [t2.d, t2.c, 1:Int32] }
└─BatchFilter { predicate: IsNotNull(t2.d) }
└─BatchScan { table: t2, columns: [t2.c, t2.d], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [arr, t1._row_id(hidden), t1.b(hidden), t2.d(hidden)], stream_key: [t1._row_id, t1.b], pk_columns: [t1._row_id, t1.b], pk_conflict: NoCheck }
StreamMaterialize { columns: [arr, t1._row_id(hidden), t1.b(hidden), t1.b#1(hidden)], stream_key: [t1._row_id, t1.b], pk_columns: [t1._row_id, t1.b], pk_conflict: NoCheck }
└─StreamExchange { dist: HashShard(t1._row_id, t1.b) }
└─StreamHashJoin { type: LeftOuter, predicate: t1.b = t2.d, output: [$expr1, t1._row_id, t1.b, t2.d] }
└─StreamHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t1.b, output: [$expr1, t1._row_id, t1.b, t1.b] }
├─StreamExchange { dist: HashShard(t1.b) }
│ └─StreamTableScan { table: t1, columns: [t1.b, t1._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t1._row_id], pk: [_row_id], dist: UpstreamHashShard(t1._row_id) }
└─StreamProject { exprs: [Coalesce(array_agg(t2.c), ARRAY[]:List(Int32)) as $expr1, t2.d] }
└─StreamHashAgg { group_key: [t2.d], aggs: [array_agg(t2.c), count] }
└─StreamExchange { dist: HashShard(t2.d) }
└─StreamTableScan { table: t2, columns: [t2.c, t2.d, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) }
└─StreamProject { exprs: [t1.b, Coalesce(array_agg(t2.c) filter(IsNotNull(1:Int32)), ARRAY[]:List(Int32)) as $expr1] }
└─StreamHashAgg { group_key: [t1.b], aggs: [array_agg(t2.c) filter(IsNotNull(1:Int32)), count] }
└─StreamHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t2.d, output: [t1.b, t2.c, 1:Int32, t2._row_id] }
├─StreamProject { exprs: [t1.b], noop_update_hint: true }
│ └─StreamHashAgg { group_key: [t1.b], aggs: [count] }
│ └─StreamExchange { dist: HashShard(t1.b) }
│ └─StreamTableScan { table: t1, columns: [t1.b, t1._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t1._row_id], pk: [_row_id], dist: UpstreamHashShard(t1._row_id) }
└─StreamExchange { dist: HashShard(t2.d) }
└─StreamProject { exprs: [t2.d, t2.c, 1:Int32, t2._row_id] }
└─StreamFilter { predicate: IsNotNull(t2.d) }
└─StreamTableScan { table: t2, columns: [t2.c, t2.d, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) }
- name: correlated array subquery \du
sql: |
SELECT r.rolname, r.rolsuper, r.rolinherit,
@@ -1027,18 +1040,32 @@
BatchExchange { order: [rw_users.name ASC], dist: Single }
└─BatchProject { exprs: [rw_users.name, rw_users.is_super, true:Boolean, rw_users.create_user, rw_users.create_db, rw_users.can_login, -1:Int32, null:Timestamptz, $expr1, true:Boolean, true:Boolean] }
└─BatchSort { order: [rw_users.name ASC] }
└─BatchHashJoin { type: LeftOuter, predicate: rw_users.id = null:Int32, output: all }
└─BatchHashJoin { type: LeftOuter, predicate: rw_users.id IS NOT DISTINCT FROM rw_users.id, output: all }
├─BatchExchange { order: [], dist: HashShard(rw_users.id) }
│ └─BatchFilter { predicate: Not(RegexpEq(rw_users.name, '^pg_':Varchar)) }
│ └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name, rw_users.is_super, rw_users.create_db, rw_users.create_user, rw_users.can_login], distribution: Single }
└─BatchProject { exprs: [Coalesce(array_agg(rw_users.name), ARRAY[]:List(Varchar)) as $expr1, null:Int32] }
└─BatchHashAgg { group_key: [null:Int32], aggs: [array_agg(rw_users.name)] }
└─BatchExchange { order: [], dist: HashShard(null:Int32) }
└─BatchHashJoin { type: Inner, predicate: null:Int32 = rw_users.id, output: [rw_users.name, null:Int32] }
├─BatchExchange { order: [], dist: HashShard(null:Int32) }
│ └─BatchValues { rows: [] }
└─BatchExchange { order: [], dist: HashShard(rw_users.id) }
└─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single }
└─BatchProject { exprs: [rw_users.id, Coalesce(array_agg(rw_users.name) filter(IsNotNull(1:Int32)), ARRAY[]:List(Varchar)) as $expr1] }
└─BatchHashAgg { group_key: [rw_users.id], aggs: [array_agg(rw_users.name) filter(IsNotNull(1:Int32))] }
└─BatchHashJoin { type: LeftOuter, predicate: rw_users.id IS NOT DISTINCT FROM rw_users.id, output: [rw_users.id, rw_users.name, 1:Int32] }
├─BatchHashAgg { group_key: [rw_users.id], aggs: [] }
│ └─BatchExchange { order: [], dist: HashShard(rw_users.id) }
│ └─BatchProject { exprs: [rw_users.id] }
│ └─BatchFilter { predicate: Not(RegexpEq(rw_users.name, '^pg_':Varchar)) }
│ └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single }
└─BatchExchange { order: [], dist: HashShard(rw_users.id) }
└─BatchProject { exprs: [rw_users.id, rw_users.name, 1:Int32] }
└─BatchHashJoin { type: Inner, predicate: null:Int32 = rw_users.id, output: [rw_users.id, rw_users.name] }
├─BatchExchange { order: [], dist: HashShard(null:Int32) }
│ └─BatchProject { exprs: [rw_users.id, null:Int32] }
│ └─BatchNestedLoopJoin { type: Inner, predicate: true, output: all }
│ ├─BatchExchange { order: [], dist: Single }
│ │ └─BatchHashAgg { group_key: [rw_users.id], aggs: [] }
│ │ └─BatchExchange { order: [], dist: HashShard(rw_users.id) }
│ │ └─BatchValues { rows: [] }
│ └─BatchFilter { predicate: false:Boolean }
│ └─BatchValues { rows: [] }
└─BatchExchange { order: [], dist: HashShard(rw_users.id) }
└─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single }
- name: correlated array subquery (issue 14423)
sql: |
CREATE TABLE array_types ( x BIGINT[] );
@@ -1069,6 +1096,7 @@
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- count is not null-propagating
WHERE T.A > (SELECT COUNT(*) FROM T2 WHERE B = D);
batch_plan: |-
BatchExchange { order: [], dist: Single }
@@ -1109,6 +1137,7 @@
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- avg is null-propagating
WHERE T.A > (SELECT avg(c) FROM T2 WHERE B = D);
batch_plan: |-
BatchExchange { order: [], dist: Single }
@@ -1132,6 +1161,48 @@
└─StreamHashAgg { group_key: [t2.d], aggs: [sum(t2.c), count(t2.c), count] }
└─StreamExchange { dist: HashShard(t2.d) }
└─StreamTableScan { table: t2, columns: [t2.c, t2.d, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) }
- name: a case can't be optimized by PullUpCorrelatedPredicateAggRule
sql: |
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- Coalesce is not null-propagating
WHERE T.A > (SELECT coalesce(avg(c), 114514) FROM T2 WHERE B = D);
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchHashJoin { type: Inner, predicate: t.b IS NOT DISTINCT FROM t.b AND ($expr1 > $expr2), output: [t.a, t.b] }
├─BatchExchange { order: [], dist: HashShard(t.b) }
│ └─BatchProject { exprs: [t.a, t.b, t.a::Decimal as $expr1] }
│ └─BatchScan { table: t, columns: [t.a, t.b], distribution: SomeShard }
└─BatchProject { exprs: [t.b, Coalesce((sum(t2.c)::Decimal / count(t2.c)::Decimal), 114514:Decimal) as $expr2] }
└─BatchHashAgg { group_key: [t.b], aggs: [sum(t2.c), count(t2.c)] }
└─BatchHashJoin { type: LeftOuter, predicate: t.b IS NOT DISTINCT FROM t2.d, output: [t.b, t2.c] }
├─BatchHashAgg { group_key: [t.b], aggs: [] }
│ └─BatchExchange { order: [], dist: HashShard(t.b) }
│ └─BatchScan { table: t, columns: [t.b], distribution: SomeShard }
└─BatchExchange { order: [], dist: HashShard(t2.d) }
└─BatchProject { exprs: [t2.d, t2.c] }
└─BatchFilter { predicate: IsNotNull(t2.d) }
└─BatchScan { table: t2, columns: [t2.c, t2.d], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [a, b, t._row_id(hidden), t.b(hidden)], stream_key: [t._row_id, b], pk_columns: [t._row_id, b], pk_conflict: NoCheck }
└─StreamProject { exprs: [t.a, t.b, t._row_id, t.b] }
└─StreamFilter { predicate: ($expr1 > $expr2) }
└─StreamHashJoin { type: Inner, predicate: t.b IS NOT DISTINCT FROM t.b, output: all }
├─StreamExchange { dist: HashShard(t.b) }
│ └─StreamProject { exprs: [t.a, t.b, t.a::Decimal as $expr1, t._row_id] }
│ └─StreamTableScan { table: t, columns: [t.a, t.b, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamProject { exprs: [t.b, Coalesce((sum(t2.c)::Decimal / count(t2.c)::Decimal), 114514:Decimal) as $expr2] }
└─StreamHashAgg { group_key: [t.b], aggs: [sum(t2.c), count(t2.c), count] }
└─StreamHashJoin { type: LeftOuter, predicate: t.b IS NOT DISTINCT FROM t2.d, output: [t.b, t2.c, t2._row_id] }
├─StreamProject { exprs: [t.b], noop_update_hint: true }
│ └─StreamHashAgg { group_key: [t.b], aggs: [count] }
│ └─StreamExchange { dist: HashShard(t.b) }
│ └─StreamTableScan { table: t, columns: [t.b, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamExchange { dist: HashShard(t2.d) }
└─StreamProject { exprs: [t2.d, t2.c, t2._row_id] }
└─StreamFilter { predicate: IsNotNull(t2.d) }
└─StreamTableScan { table: t2, columns: [t2.c, t2.d, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) }
- name: improve multi scalar subqueries optimization time. issue 16952. case 1.
sql: |
create table t1(a int, b int);
@@ -153,26 +153,19 @@ impl Rule for PullUpCorrelatedPredicateAggRule {

let new_bottom_proj: PlanRef = LogicalProject::new(filter, bottom_proj_exprs).into();

// If there is a count aggregate, bail out and leave it to general subquery unnesting to handle:
// when the group-by input is empty, count returns 0 instead of null.
// The exception is when we can prove that the row generated by the empty group would be eliminated by the top filter.
//
// A typical example is a count generated by avg, because avg = sum / count. In this case, when the input is empty,
// sum is null, so avg is null; a null-rejecting expression then evaluates to false, so we can still apply this rule without generating a 0 value for count.
let count_exists = agg_calls
.iter()
.any(|agg_call| matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Count)));

if count_exists {
// When the group input is empty, any aggregate other than count returns null.
// We can apply this rule only if:
// 1. the `group by + proj` output is null for empty input, OR
// 2. the top filter evaluates to null (and thus drops the row) for empty input.
{
// When the group input is empty, an aggregate other than `count` returns null.
let null_agg_pos = agg_calls
.iter()
.positions(|agg_call| {
!matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Count))
})
.collect_vec();

xxchan marked this conversation as resolved.
// If no null agg, bail out.
// There are no aggregates that return null on empty input, so the output can never be null. Bail out.
if null_agg_pos.is_empty() {
return None;
}
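
For context on the count special case above: over an empty group, count is the outlier in that it yields a non-NULL value (0), while sum, avg, max, and array_agg yield NULL. Below is a minimal, self-contained Rust sketch of that classification; the AggKind enum and null_on_empty_input helper are illustrative stand-ins, not RisingWave's AggType/PbAggKind API.

// Toy stand-in for the real AggType/PbAggKind classification (an assumption for
// illustration, not RisingWave's types): only `count` yields a non-NULL value
// over an empty group, so it is the case the rule has to worry about.
#[derive(Debug, Clone, Copy)]
enum AggKind {
    Count,
    Avg,
    Max,
}

/// True if the aggregate evaluates to NULL when its input group is empty.
fn null_on_empty_input(kind: AggKind) -> bool {
    !matches!(kind, AggKind::Count)
}

fn main() {
    let agg_calls = [AggKind::Count, AggKind::Avg, AggKind::Max];

    // Positions of aggregates that are NULL on empty input, mirroring the
    // `positions(..)` collection in the rule.
    let null_agg_pos: Vec<usize> = agg_calls
        .iter()
        .enumerate()
        .filter(|(_, kind)| null_on_empty_input(**kind))
        .map(|(i, _)| i)
        .collect();

    if null_agg_pos.is_empty() {
        // Every aggregate is a plain count: the subquery output is never NULL
        // for an empty group, so the rule cannot rely on a null-rejecting filter.
        println!("bail out");
    } else {
        println!("NULL-producing aggregates at positions {null_agg_pos:?}");
    }
}
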
@@ -185,21 +178,26 @@

// Shift the top project expressions to the right by the length of apply_left's schema, because they are used to check whether the top filter is null-rejecting.
let apply_left_schema = apply_left.schema().len();
let mut top_proj_all_null = true;
let mut top_proj_null_bitset =
FixedBitSet::with_capacity(top_project.base.schema().len() + apply_left_schema);
for (i, expr) in top_proj_exprs.iter().enumerate() {
if Strong::is_null(expr, agg_null_bitset.clone()) {
top_proj_null_bitset.insert(i + apply_left_schema);
} else {
top_proj_all_null = false;
@chenzl25 (Contributor) commented on Jan 3, 2025:
When Strong::is_null returns false, it means we can't get any useful information, so we shouldn't allow applying the rule based on it.

}
}

// Only if all exprs in the conjunctions are null can we apply this rule; otherwise, bail out.
if top_filter
let top_filter_all_null = top_filter
.predicate()
.conjunctions
.iter()
.any(|expr| !Strong::is_null(expr, top_proj_null_bitset.clone()))
{
.all(|expr| Strong::is_null(expr, top_proj_null_bitset.clone()));
Contributor comment:
It should be any.

Suggested change
.all(|expr| Strong::is_null(expr, top_proj_null_bitset.clone()));
.any(|expr| Strong::is_null(expr, top_proj_null_bitset.clone()));

let can_apply =
top_proj_all_null || (!top_filter.predicate().always_true() && top_filter_all_null);

if !can_apply {
return None;
}
}
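
To make the null-rejection check (and the all-versus-any point raised in the review) concrete, here is a simplified, self-contained sketch. The Expr enum and is_certainly_null function below are illustrative stand-ins for the planner's expression types and Strong::is_null, not the actual implementation.

use std::collections::HashSet;

// Toy expression language standing in for the planner's expressions; the real
// analysis works on a FixedBitSet of column indices inside `Strong::is_null`.
enum Expr {
    Col(usize),               // input column reference
    Const,                    // a non-NULL literal
    Gt(Box<Expr>, Box<Expr>), // null-propagating comparison
    Coalesce(Vec<Expr>),      // NOT null-propagating: NULL only if all args are NULL
}

/// Conservatively returns true only when `expr` is certainly NULL, assuming every
/// column in `null_cols` is NULL. Returning false only means "unknown".
fn is_certainly_null(expr: &Expr, null_cols: &HashSet<usize>) -> bool {
    match expr {
        Expr::Col(i) => null_cols.contains(i),
        Expr::Const => false,
        Expr::Gt(l, r) => is_certainly_null(l, null_cols) || is_certainly_null(r, null_cols),
        Expr::Coalesce(args) => args.iter().all(|a| is_certainly_null(a, null_cols)),
    }
}

/// A WHERE clause `c1 AND c2 AND ...` drops a row as soon as one conjunct is NULL
/// (NULL acts like false in a filter), so proving that any single conjunct is NULL
/// is enough to show the empty-group row is rejected.
fn filter_rejects_empty_group_row(conjuncts: &[Expr], null_cols: &HashSet<usize>) -> bool {
    conjuncts.iter().any(|c| is_certainly_null(c, null_cols))
}

fn main() {
    // Column 0 is the aggregate result, which is NULL for an empty group.
    let nulls: HashSet<usize> = HashSet::from([0]);

    // t.a > avg(c): null-propagating, so the filter rejects the empty-group row.
    let pred_avg = Expr::Gt(Box::new(Expr::Col(1)), Box::new(Expr::Col(0)));
    // t.a > coalesce(avg(c), 114514): not provably NULL, so the rule must bail out.
    let pred_coalesce = Expr::Gt(
        Box::new(Expr::Col(1)),
        Box::new(Expr::Coalesce(vec![Expr::Col(0), Expr::Const])),
    );

    assert!(filter_rejects_empty_group_row(&[pred_avg], &nulls));
    assert!(!filter_rejects_empty_group_row(&[pred_coalesce], &nulls));
}
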