Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(optimizer): can't apply pull_up_correlated_predicate_agg_rule with non-null-propagating expr (#20012) #20344

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions e2e_test/batch/aggregate/issue_19835.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
CREATE TABLE foo (
id int
);

statement ok
CREATE TABLE bar (
id int,
foo_id int
);

statement ok
insert into foo values (1),(2);

query I rowsort
select
foo.id, array(
select
1
from bar
where bar.foo_id = foo.id
)
from foo;
----
1 {}
2 {}


query I rowsort
select foo.id, (select count(id) from bar where bar.foo_id = foo.id) from foo;
----
1 0
2 0

query I rowsort
select foo.id, (select avg(id) from bar where bar.foo_id = foo.id) from foo;
----
1 NULL
2 NULL

query I rowsort
select foo.id, (select max(id) from bar where bar.foo_id = foo.id) from foo;
----
1 NULL
2 NULL

query I rowsort
select foo.id, (select coalesce( max(id), 114514) from bar where bar.foo_id = foo.id) from foo;
----
1 114514
2 114514

statement ok
drop table foo;

statement ok
drop table bar;
Original file line number Diff line number Diff line change
@@ -495,6 +495,7 @@
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- count is not null-propagating
WHERE T.A > (SELECT COUNT(*) FROM T2 WHERE B = D);
expected_outputs:
- batch_plan
@@ -504,10 +505,31 @@
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- avg is null-propagating
WHERE T.A > (SELECT avg(c) FROM T2 WHERE B = D);
expected_outputs:
- batch_plan
- stream_plan
- name: a case can't be optimized by PullUpCorrelatedPredicateAggRule
sql: |
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- Coalesce is not null-propagating
WHERE T.A > (SELECT coalesce(avg(c), 114514) FROM T2 WHERE B = D);
expected_outputs:
- batch_plan
- stream_plan
- name: a case can't be optimized by PullUpCorrelatedPredicateAggRule
sql: |
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- Can apply if any of the conjunction is null
WHERE null AND T.A > (SELECT coalesce(avg(c), 114514) FROM T2 WHERE B = D);
expected_outputs:
- batch_plan
- stream_plan
- name: improve multi scalar subqueries optimization time. issue 16952. case 1.
sql: |
create table t1(a int, b int);
Original file line number Diff line number Diff line change
@@ -992,23 +992,36 @@
select Array(select c from t2 where b = d) arr from t1;
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchHashJoin { type: LeftOuter, predicate: t1.b = t2.d, output: [$expr1] }
└─BatchHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t1.b, output: [$expr1] }
├─BatchExchange { order: [], dist: HashShard(t1.b) }
│ └─BatchScan { table: t1, columns: [t1.b], distribution: SomeShard }
└─BatchProject { exprs: [Coalesce(array_agg(t2.c), ARRAY[]:List(Int32)) as $expr1, t2.d] }
└─BatchHashAgg { group_key: [t2.d], aggs: [array_agg(t2.c)] }
└─BatchExchange { order: [], dist: HashShard(t2.d) }
└─BatchScan { table: t2, columns: [t2.c, t2.d], distribution: SomeShard }
└─BatchProject { exprs: [t1.b, Coalesce(array_agg(t2.c) filter(IsNotNull(1:Int32)), ARRAY[]:List(Int32)) as $expr1] }
└─BatchHashAgg { group_key: [t1.b], aggs: [array_agg(t2.c) filter(IsNotNull(1:Int32))] }
└─BatchHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t2.d, output: [t1.b, t2.c, 1:Int32] }
├─BatchHashAgg { group_key: [t1.b], aggs: [] }
│ └─BatchExchange { order: [], dist: HashShard(t1.b) }
│ └─BatchScan { table: t1, columns: [t1.b], distribution: SomeShard }
└─BatchExchange { order: [], dist: HashShard(t2.d) }
└─BatchProject { exprs: [t2.d, t2.c, 1:Int32] }
└─BatchFilter { predicate: IsNotNull(t2.d) }
└─BatchScan { table: t2, columns: [t2.c, t2.d], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [arr, t1._row_id(hidden), t1.b(hidden), t2.d(hidden)], stream_key: [t1._row_id, t1.b], pk_columns: [t1._row_id, t1.b], pk_conflict: NoCheck }
StreamMaterialize { columns: [arr, t1._row_id(hidden), t1.b(hidden), t1.b#1(hidden)], stream_key: [t1._row_id, t1.b], pk_columns: [t1._row_id, t1.b], pk_conflict: NoCheck }
└─StreamExchange { dist: HashShard(t1._row_id, t1.b) }
└─StreamHashJoin { type: LeftOuter, predicate: t1.b = t2.d, output: [$expr1, t1._row_id, t1.b, t2.d] }
└─StreamHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t1.b, output: [$expr1, t1._row_id, t1.b, t1.b] }
├─StreamExchange { dist: HashShard(t1.b) }
│ └─StreamTableScan { table: t1, columns: [t1.b, t1._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t1._row_id], pk: [_row_id], dist: UpstreamHashShard(t1._row_id) }
└─StreamProject { exprs: [Coalesce(array_agg(t2.c), ARRAY[]:List(Int32)) as $expr1, t2.d] }
└─StreamHashAgg { group_key: [t2.d], aggs: [array_agg(t2.c), count] }
└─StreamExchange { dist: HashShard(t2.d) }
└─StreamTableScan { table: t2, columns: [t2.c, t2.d, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) }
└─StreamProject { exprs: [t1.b, Coalesce(array_agg(t2.c) filter(IsNotNull(1:Int32)), ARRAY[]:List(Int32)) as $expr1] }
└─StreamHashAgg { group_key: [t1.b], aggs: [array_agg(t2.c) filter(IsNotNull(1:Int32)), count] }
└─StreamHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t2.d, output: [t1.b, t2.c, 1:Int32, t2._row_id] }
├─StreamProject { exprs: [t1.b], noop_update_hint: true }
│ └─StreamHashAgg { group_key: [t1.b], aggs: [count] }
│ └─StreamExchange { dist: HashShard(t1.b) }
│ └─StreamTableScan { table: t1, columns: [t1.b, t1._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t1._row_id], pk: [_row_id], dist: UpstreamHashShard(t1._row_id) }
└─StreamExchange { dist: HashShard(t2.d) }
└─StreamProject { exprs: [t2.d, t2.c, 1:Int32, t2._row_id] }
└─StreamFilter { predicate: IsNotNull(t2.d) }
└─StreamTableScan { table: t2, columns: [t2.c, t2.d, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) }
- name: correlated array subquery \du
sql: |
SELECT r.rolname, r.rolsuper, r.rolinherit,
@@ -1027,18 +1040,32 @@
BatchExchange { order: [rw_users.name ASC], dist: Single }
└─BatchProject { exprs: [rw_users.name, rw_users.is_super, true:Boolean, rw_users.create_user, rw_users.create_db, rw_users.can_login, -1:Int32, null:Timestamptz, $expr1, true:Boolean, true:Boolean] }
└─BatchSort { order: [rw_users.name ASC] }
└─BatchHashJoin { type: LeftOuter, predicate: rw_users.id = null:Int32, output: all }
└─BatchHashJoin { type: LeftOuter, predicate: rw_users.id IS NOT DISTINCT FROM rw_users.id, output: all }
├─BatchExchange { order: [], dist: HashShard(rw_users.id) }
│ └─BatchFilter { predicate: Not(RegexpEq(rw_users.name, '^pg_':Varchar)) }
│ └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name, rw_users.is_super, rw_users.create_db, rw_users.create_user, rw_users.can_login], distribution: Single }
└─BatchProject { exprs: [Coalesce(array_agg(rw_users.name), ARRAY[]:List(Varchar)) as $expr1, null:Int32] }
└─BatchHashAgg { group_key: [null:Int32], aggs: [array_agg(rw_users.name)] }
└─BatchExchange { order: [], dist: HashShard(null:Int32) }
└─BatchHashJoin { type: Inner, predicate: null:Int32 = rw_users.id, output: [rw_users.name, null:Int32] }
├─BatchExchange { order: [], dist: HashShard(null:Int32) }
│ └─BatchValues { rows: [] }
└─BatchExchange { order: [], dist: HashShard(rw_users.id) }
└─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single }
└─BatchProject { exprs: [rw_users.id, Coalesce(array_agg(rw_users.name) filter(IsNotNull(1:Int32)), ARRAY[]:List(Varchar)) as $expr1] }
└─BatchHashAgg { group_key: [rw_users.id], aggs: [array_agg(rw_users.name) filter(IsNotNull(1:Int32))] }
└─BatchHashJoin { type: LeftOuter, predicate: rw_users.id IS NOT DISTINCT FROM rw_users.id, output: [rw_users.id, rw_users.name, 1:Int32] }
├─BatchHashAgg { group_key: [rw_users.id], aggs: [] }
│ └─BatchExchange { order: [], dist: HashShard(rw_users.id) }
│ └─BatchProject { exprs: [rw_users.id] }
│ └─BatchFilter { predicate: Not(RegexpEq(rw_users.name, '^pg_':Varchar)) }
│ └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single }
└─BatchExchange { order: [], dist: HashShard(rw_users.id) }
└─BatchProject { exprs: [rw_users.id, rw_users.name, 1:Int32] }
└─BatchHashJoin { type: Inner, predicate: null:Int32 = rw_users.id, output: [rw_users.id, rw_users.name] }
├─BatchExchange { order: [], dist: HashShard(null:Int32) }
│ └─BatchProject { exprs: [rw_users.id, null:Int32] }
│ └─BatchNestedLoopJoin { type: Inner, predicate: true, output: all }
│ ├─BatchExchange { order: [], dist: Single }
│ │ └─BatchHashAgg { group_key: [rw_users.id], aggs: [] }
│ │ └─BatchExchange { order: [], dist: HashShard(rw_users.id) }
│ │ └─BatchValues { rows: [] }
│ └─BatchFilter { predicate: false:Boolean }
│ └─BatchValues { rows: [] }
└─BatchExchange { order: [], dist: HashShard(rw_users.id) }
└─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single }
- name: correlated array subquery (issue 14423)
sql: |
CREATE TABLE array_types ( x BIGINT[] );
@@ -1069,6 +1096,7 @@
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- count is not null-propagating
WHERE T.A > (SELECT COUNT(*) FROM T2 WHERE B = D);
batch_plan: |-
BatchExchange { order: [], dist: Single }
@@ -1109,6 +1137,7 @@
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- avg is null-propagating
WHERE T.A > (SELECT avg(c) FROM T2 WHERE B = D);
batch_plan: |-
BatchExchange { order: [], dist: Single }
@@ -1132,6 +1161,81 @@
└─StreamHashAgg { group_key: [t2.d], aggs: [sum(t2.c), count(t2.c), count] }
└─StreamExchange { dist: HashShard(t2.d) }
└─StreamTableScan { table: t2, columns: [t2.c, t2.d, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) }
- name: a case can't be optimized by PullUpCorrelatedPredicateAggRule
sql: |
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- Coalesce is not null-propagating
WHERE T.A > (SELECT coalesce(avg(c), 114514) FROM T2 WHERE B = D);
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchHashJoin { type: Inner, predicate: t.b IS NOT DISTINCT FROM t.b AND ($expr1 > $expr2), output: [t.a, t.b] }
├─BatchExchange { order: [], dist: HashShard(t.b) }
│ └─BatchProject { exprs: [t.a, t.b, t.a::Decimal as $expr1] }
│ └─BatchScan { table: t, columns: [t.a, t.b], distribution: SomeShard }
└─BatchProject { exprs: [t.b, Coalesce((sum(t2.c)::Decimal / count(t2.c)::Decimal), 114514:Decimal) as $expr2] }
└─BatchHashAgg { group_key: [t.b], aggs: [sum(t2.c), count(t2.c)] }
└─BatchHashJoin { type: LeftOuter, predicate: t.b IS NOT DISTINCT FROM t2.d, output: [t.b, t2.c] }
├─BatchHashAgg { group_key: [t.b], aggs: [] }
│ └─BatchExchange { order: [], dist: HashShard(t.b) }
│ └─BatchScan { table: t, columns: [t.b], distribution: SomeShard }
└─BatchExchange { order: [], dist: HashShard(t2.d) }
└─BatchProject { exprs: [t2.d, t2.c] }
└─BatchFilter { predicate: IsNotNull(t2.d) }
└─BatchScan { table: t2, columns: [t2.c, t2.d], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [a, b, t._row_id(hidden), t.b(hidden)], stream_key: [t._row_id, b], pk_columns: [t._row_id, b], pk_conflict: NoCheck }
└─StreamProject { exprs: [t.a, t.b, t._row_id, t.b] }
└─StreamFilter { predicate: ($expr1 > $expr2) }
└─StreamHashJoin { type: Inner, predicate: t.b IS NOT DISTINCT FROM t.b, output: all }
├─StreamExchange { dist: HashShard(t.b) }
│ └─StreamProject { exprs: [t.a, t.b, t.a::Decimal as $expr1, t._row_id] }
│ └─StreamTableScan { table: t, columns: [t.a, t.b, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamProject { exprs: [t.b, Coalesce((sum(t2.c)::Decimal / count(t2.c)::Decimal), 114514:Decimal) as $expr2] }
└─StreamHashAgg { group_key: [t.b], aggs: [sum(t2.c), count(t2.c), count] }
└─StreamHashJoin { type: LeftOuter, predicate: t.b IS NOT DISTINCT FROM t2.d, output: [t.b, t2.c, t2._row_id] }
├─StreamProject { exprs: [t.b], noop_update_hint: true }
│ └─StreamHashAgg { group_key: [t.b], aggs: [count] }
│ └─StreamExchange { dist: HashShard(t.b) }
│ └─StreamTableScan { table: t, columns: [t.b, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamExchange { dist: HashShard(t2.d) }
└─StreamProject { exprs: [t2.d, t2.c, t2._row_id] }
└─StreamFilter { predicate: IsNotNull(t2.d) }
└─StreamTableScan { table: t2, columns: [t2.c, t2.d, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) }
- name: a case can't be optimized by PullUpCorrelatedPredicateAggRule
sql: |
CREATE TABLE T (A INT, B INT);
CREATE TABLE T2 (C INT, D INT);
SELECT * FROM T
-- Can apply if any of the conjunction is null
WHERE null AND T.A > (SELECT coalesce(avg(c), 114514) FROM T2 WHERE B = D);
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchHashJoin { type: Inner, predicate: t.b = t2.d AND ($expr1 > $expr2), output: [t.a, t.b] }
├─BatchExchange { order: [], dist: HashShard(t.b) }
│ └─BatchProject { exprs: [t.a, t.b, t.a::Decimal as $expr1] }
│ └─BatchFilter { predicate: null:Boolean }
│ └─BatchScan { table: t, columns: [t.a, t.b], distribution: SomeShard }
└─BatchProject { exprs: [Coalesce((sum(t2.c)::Decimal / count(t2.c)::Decimal), 114514:Decimal) as $expr2, t2.d] }
└─BatchHashAgg { group_key: [t2.d], aggs: [sum(t2.c), count(t2.c)] }
└─BatchExchange { order: [], dist: HashShard(t2.d) }
└─BatchFilter { predicate: null:Boolean }
└─BatchScan { table: t2, columns: [t2.c, t2.d], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [a, b, t._row_id(hidden), t2.d(hidden)], stream_key: [t._row_id, b], pk_columns: [t._row_id, b], pk_conflict: NoCheck }
└─StreamProject { exprs: [t.a, t.b, t._row_id, t2.d] }
└─StreamFilter { predicate: ($expr1 > $expr2) }
└─StreamHashJoin { type: Inner, predicate: t.b = t2.d, output: all }
├─StreamExchange { dist: HashShard(t.b) }
│ └─StreamProject { exprs: [t.a, t.b, t.a::Decimal as $expr1, t._row_id] }
│ └─StreamFilter { predicate: null:Boolean }
│ └─StreamTableScan { table: t, columns: [t.a, t.b, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamProject { exprs: [Coalesce((sum(t2.c)::Decimal / count(t2.c)::Decimal), 114514:Decimal) as $expr2, t2.d] }
└─StreamHashAgg { group_key: [t2.d], aggs: [sum(t2.c), count(t2.c), count] }
└─StreamExchange { dist: HashShard(t2.d) }
└─StreamFilter { predicate: null:Boolean }
└─StreamTableScan { table: t2, columns: [t2.c, t2.d, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) }
- name: improve multi scalar subqueries optimization time. issue 16952. case 1.
sql: |
create table t1(a int, b int);
Original file line number Diff line number Diff line change
@@ -153,26 +153,19 @@ impl Rule for PullUpCorrelatedPredicateAggRule {

let new_bottom_proj: PlanRef = LogicalProject::new(filter, bottom_proj_exprs).into();

// If there is a count aggregate, bail out and leave for general subquery unnesting to deal.
// When group by is empty, count would return 0 instead of null.
// Unless we can prove that the row generated by the empty group could be eliminated by the top filter.
//
// A typical example is the count is generate by avg, because avg = sum / count. In this case, when input is empty,
// sum is null, so avg is null. And null-rejected expression will be false, so we can still apply this rule and we don't need to generate a 0 value for count.
let count_exists = agg_calls
.iter()
.any(|agg_call| matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Count)));

if count_exists {
// When group input is empty, not count agg would return null.
// We can apply this rule only if:
// 1. The `group by + proj` returns null for empty input
// 2. OR the top filter is null for empty input
{
// When group input is empty, if the agg is not `count`, it would return null.
let null_agg_pos = agg_calls
.iter()
.positions(|agg_call| {
!matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Count))
})
.collect_vec();

// If no null agg, bail out.
// We don't have null args, so the output will never be null. Bail out.
if null_agg_pos.is_empty() {
return None;
}
@@ -185,21 +178,25 @@ impl Rule for PullUpCorrelatedPredicateAggRule {

// Shift the top project expressions to the right by apply_left schema len, because it is used to check null-rejected by the top filter.
let apply_left_schema = apply_left.schema().len();
let mut top_proj_all_null = true;
let mut top_proj_null_bitset =
FixedBitSet::with_capacity(top_project.base.schema().len() + apply_left_schema);
for (i, expr) in top_proj_exprs.iter().enumerate() {
if Strong::is_null(expr, agg_null_bitset.clone()) {
top_proj_null_bitset.insert(i + apply_left_schema);
} else {
top_proj_all_null = false;
}
}

// Only all expr in conjunctions are null, we can apply this rule, otherwise, bail out.
if top_filter
let top_filter_any_null = top_filter
.predicate()
.conjunctions
.iter()
.any(|expr| !Strong::is_null(expr, top_proj_null_bitset.clone()))
{
.any(|expr| Strong::is_null(expr, top_proj_null_bitset.clone()));
let can_apply = top_proj_all_null || top_filter_any_null;

if !can_apply {
return None;
}
}