Skip to content

Commit f1246a9

Browse files
authored
perf: Optimize CASE for any WHEN false (#17835)
* Implement WHEN false logic for CASE statements
* Fix tests to use valid column names (e.g., c3 instead of a)
* Add comments and add a negative test case
* Clean up
* Fix negative test to avoid (case -> or/and) simplification
* Delete the `if let` guard comment (no longer useful)
* Modify logic to move all kept elements to a new vector instead of removing them from the original vector
1 parent 35f45b5 commit f1246a9

File tree

1 file changed

+128
-64
lines changed

1 file changed

+128
-64
lines changed

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 128 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,33 +1436,49 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
14361436

14371437
// CASE WHEN true THEN A ... END --> A
14381438
// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1439+
// CASE WHEN false THEN A END --> NULL
1440+
// CASE WHEN false THEN A ELSE B END --> B
1441+
// CASE WHEN X THEN A WHEN false THEN B END --> CASE WHEN X THEN A END
14391442
Expr::Case(Case {
14401443
expr: None,
1441-
mut when_then_expr,
1442-
else_expr: _,
1443-
// if let guard is not stabilized so we can't use it yet: https://github.com/rust-lang/rust/issues/51114
1444-
// Once it's supported we can avoid searching through when_then_expr twice in the below .any() and .position() calls
1445-
// }) if let Some(i) = when_then_expr.iter().position(|(when, _)| is_true(when.as_ref())) => {
1444+
when_then_expr,
1445+
mut else_expr,
14461446
}) if when_then_expr
14471447
.iter()
1448-
.any(|(when, _)| is_true(when.as_ref())) =>
1448+
.any(|(when, _)| is_true(when.as_ref()) || is_false(when.as_ref())) =>
14491449
{
1450-
let i = when_then_expr
1451-
.iter()
1452-
.position(|(when, _)| is_true(when.as_ref()))
1453-
.unwrap();
1454-
let (_, then_) = when_then_expr.swap_remove(i);
1455-
// CASE WHEN true THEN A ... END --> A
1456-
if i == 0 {
1457-
return Ok(Transformed::yes(*then_));
1450+
let out_type = info.get_data_type(&when_then_expr[0].1)?;
1451+
let mut new_when_then_expr = Vec::with_capacity(when_then_expr.len());
1452+
1453+
for (when, then) in when_then_expr.into_iter() {
1454+
if is_true(when.as_ref()) {
1455+
// Skip adding the rest of the when-then expressions after WHEN true
1456+
// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1457+
else_expr = Some(then);
1458+
break;
1459+
} else if !is_false(when.as_ref()) {
1460+
new_when_then_expr.push((when, then));
1461+
}
1462+
// else: skip WHEN false cases
1463+
}
1464+
1465+
// Exclude CASE statement altogether if there are no when-then expressions left
1466+
if new_when_then_expr.is_empty() {
1467+
// CASE WHEN false THEN A ELSE B END --> B
1468+
if let Some(else_expr) = else_expr {
1469+
return Ok(Transformed::yes(*else_expr));
1470+
// CASE WHEN false THEN A END --> NULL
1471+
} else {
1472+
let null =
1473+
Expr::Literal(ScalarValue::try_new_null(&out_type)?, None);
1474+
return Ok(Transformed::yes(null));
1475+
}
14581476
}
14591477

1460-
// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1461-
when_then_expr.truncate(i);
14621478
Transformed::yes(Expr::Case(Case {
14631479
expr: None,
1464-
when_then_expr,
1465-
else_expr: Some(then_),
1480+
when_then_expr: new_when_then_expr,
1481+
else_expr,
14661482
}))
14671483
}
14681484

@@ -3810,53 +3826,53 @@ mod tests {
38103826

38113827
#[test]
38123828
fn simplify_expr_case_when_first_true() {
3813-
// CASE WHEN true THEN 1 ELSE x END --> 1
3829+
// CASE WHEN true THEN 1 ELSE c1 END --> 1
38143830
assert_eq!(
38153831
simplify(Expr::Case(Case::new(
38163832
None,
38173833
vec![(Box::new(lit(true)), Box::new(lit(1)),)],
3818-
Some(Box::new(col("x"))),
3834+
Some(Box::new(col("c1"))),
38193835
))),
38203836
lit(1)
38213837
);
38223838

3823-
// CASE WHEN true THEN col("a") ELSE col("b") END --> col("a")
3839+
// CASE WHEN true THEN 'a' ELSE 'b' END --> 'a'
38243840
assert_eq!(
38253841
simplify(Expr::Case(Case::new(
38263842
None,
3827-
vec![(Box::new(lit(true)), Box::new(col("a")),)],
3828-
Some(Box::new(col("b"))),
3843+
vec![(Box::new(lit(true)), Box::new(lit("a")),)],
3844+
Some(Box::new(lit("b"))),
38293845
))),
3830-
col("a")
3846+
lit("a")
38313847
);
38323848

3833-
// CASE WHEN true THEN col("a") WHEN col("x") > 5 THEN col("b") ELSE col("c") END --> col("a")
3849+
// CASE WHEN true THEN 'a' WHEN 'x' > 5 THEN 'b' ELSE 'c' END --> 'a'
38343850
assert_eq!(
38353851
simplify(Expr::Case(Case::new(
38363852
None,
38373853
vec![
3838-
(Box::new(lit(true)), Box::new(col("a"))),
3839-
(Box::new(col("x").gt(lit(5))), Box::new(col("b"))),
3854+
(Box::new(lit(true)), Box::new(lit("a"))),
3855+
(Box::new(lit("x").gt(lit(5))), Box::new(lit("b"))),
38403856
],
3841-
Some(Box::new(col("c"))),
3857+
Some(Box::new(lit("c"))),
38423858
))),
3843-
col("a")
3859+
lit("a")
38443860
);
38453861

3846-
// CASE WHEN true THEN col("a") END --> col("a") (no else clause)
3862+
// CASE WHEN true THEN 'a' END --> 'a' (no else clause)
38473863
assert_eq!(
38483864
simplify(Expr::Case(Case::new(
38493865
None,
3850-
vec![(Box::new(lit(true)), Box::new(col("a")),)],
3866+
vec![(Box::new(lit(true)), Box::new(lit("a")),)],
38513867
None,
38523868
))),
3853-
col("a")
3869+
lit("a")
38543870
);
38553871

3856-
// Negative test: CASE WHEN a THEN 1 ELSE 2 END should not be simplified
3872+
// Negative test: CASE WHEN c2 THEN 1 ELSE 2 END should not be simplified
38573873
let expr = Expr::Case(Case::new(
38583874
None,
3859-
vec![(Box::new(col("a")), Box::new(lit(1)))],
3875+
vec![(Box::new(col("c2")), Box::new(lit(1)))],
38603876
Some(Box::new(lit(2))),
38613877
));
38623878
assert_eq!(simplify(expr.clone()), expr);
@@ -3869,87 +3885,135 @@ mod tests {
38693885
));
38703886
assert_ne!(simplify(expr), lit(1));
38713887

3872-
// Negative test: CASE WHEN col("x") > 5 THEN 1 ELSE 2 END should not be simplified
3888+
// Negative test: CASE WHEN col('c1') > 5 THEN 1 ELSE 2 END should not be simplified
38733889
let expr = Expr::Case(Case::new(
38743890
None,
3875-
vec![(Box::new(col("x").gt(lit(5))), Box::new(lit(1)))],
3891+
vec![(Box::new(col("c1").gt(lit(5))), Box::new(lit(1)))],
38763892
Some(Box::new(lit(2))),
38773893
));
38783894
assert_eq!(simplify(expr.clone()), expr);
38793895
}
38803896

38813897
#[test]
38823898
fn simplify_expr_case_when_any_true() {
3883-
// CASE WHEN x > 0 THEN a WHEN true THEN b ELSE c END --> CASE WHEN x > 0 THEN a ELSE b END
3899+
// CASE WHEN c3 > 0 THEN 'a' WHEN true THEN 'b' ELSE 'c' END --> CASE WHEN c3 > 0 THEN 'a' ELSE 'b' END
38843900
assert_eq!(
38853901
simplify(Expr::Case(Case::new(
38863902
None,
38873903
vec![
3888-
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
3889-
(Box::new(lit(true)), Box::new(col("b"))),
3904+
(Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
3905+
(Box::new(lit(true)), Box::new(lit("b"))),
38903906
],
3891-
Some(Box::new(col("c"))),
3907+
Some(Box::new(lit("c"))),
38923908
))),
38933909
Expr::Case(Case::new(
38943910
None,
3895-
vec![(Box::new(col("x").gt(lit(0))), Box::new(col("a")))],
3896-
Some(Box::new(col("b"))),
3911+
vec![(Box::new(col("c3").gt(lit(0))), Box::new(lit("a")))],
3912+
Some(Box::new(lit("b"))),
38973913
))
38983914
);
38993915

3900-
// CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c WHEN z = 0 THEN d ELSE e END
3901-
// --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END
3916+
// CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' WHEN true THEN 'c' WHEN c3 = 0 THEN 'd' ELSE 'e' END
3917+
// --> CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' ELSE 'c' END
39023918
assert_eq!(
39033919
simplify(Expr::Case(Case::new(
39043920
None,
39053921
vec![
3906-
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
3907-
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
3908-
(Box::new(lit(true)), Box::new(col("c"))),
3909-
(Box::new(col("z").eq(lit(0))), Box::new(col("d"))),
3922+
(Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
3923+
(Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))),
3924+
(Box::new(lit(true)), Box::new(lit("c"))),
3925+
(Box::new(col("c3").eq(lit(0))), Box::new(lit("d"))),
39103926
],
3911-
Some(Box::new(col("e"))),
3927+
Some(Box::new(lit("e"))),
39123928
))),
39133929
Expr::Case(Case::new(
39143930
None,
39153931
vec![
3916-
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
3917-
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
3932+
(Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
3933+
(Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))),
39183934
],
3919-
Some(Box::new(col("c"))),
3935+
Some(Box::new(lit("c"))),
39203936
))
39213937
);
39223938

3923-
// CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c END (no else)
3924-
// --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END
3939+
// CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 WHEN true THEN 3 END (no else)
3940+
// --> CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 ELSE 3 END
39253941
assert_eq!(
39263942
simplify(Expr::Case(Case::new(
39273943
None,
39283944
vec![
3929-
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
3930-
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
3931-
(Box::new(lit(true)), Box::new(col("c"))),
3945+
(Box::new(col("c3").gt(lit(0))), Box::new(lit(1))),
3946+
(Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
3947+
(Box::new(lit(true)), Box::new(lit(3))),
39323948
],
39333949
None,
39343950
))),
39353951
Expr::Case(Case::new(
39363952
None,
39373953
vec![
3938-
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
3939-
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
3954+
(Box::new(col("c3").gt(lit(0))), Box::new(lit(1))),
3955+
(Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
39403956
],
3941-
Some(Box::new(col("c"))),
3957+
Some(Box::new(lit(3))),
39423958
))
39433959
);
39443960

3945-
// Negative test: CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END should not be simplified
3961+
// Negative test: CASE WHEN c3 > 0 THEN c3 WHEN c4 < 0 THEN 2 ELSE 3 END should not be simplified
39463962
let expr = Expr::Case(Case::new(
39473963
None,
39483964
vec![
3949-
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
3950-
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
3965+
(Box::new(col("c3").gt(lit(0))), Box::new(col("c3"))),
3966+
(Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
39513967
],
3952-
Some(Box::new(col("c"))),
3968+
Some(Box::new(lit(3))),
3969+
));
3970+
assert_eq!(simplify(expr.clone()), expr);
3971+
}
3972+
3973+
#[test]
3974+
fn simplify_expr_case_when_any_false() {
3975+
// CASE WHEN false THEN 'a' END --> NULL
3976+
assert_eq!(
3977+
simplify(Expr::Case(Case::new(
3978+
None,
3979+
vec![(Box::new(lit(false)), Box::new(lit("a")))],
3980+
None,
3981+
))),
3982+
Expr::Literal(ScalarValue::Utf8(None), None)
3983+
);
3984+
3985+
// CASE WHEN false THEN 2 ELSE 1 END --> 1
3986+
assert_eq!(
3987+
simplify(Expr::Case(Case::new(
3988+
None,
3989+
vec![(Box::new(lit(false)), Box::new(lit(2)))],
3990+
Some(Box::new(lit(1))),
3991+
))),
3992+
lit(1),
3993+
);
3994+
3995+
// CASE WHEN c3 < 10 THEN 'b' WHEN false THEN c3 ELSE c4 END --> CASE WHEN c3 < 10 THEN 'b' ELSE c4 END
3996+
assert_eq!(
3997+
simplify(Expr::Case(Case::new(
3998+
None,
3999+
vec![
4000+
(Box::new(col("c3").lt(lit(10))), Box::new(lit("b"))),
4001+
(Box::new(lit(false)), Box::new(col("c3"))),
4002+
],
4003+
Some(Box::new(col("c4"))),
4004+
))),
4005+
Expr::Case(Case::new(
4006+
None,
4007+
vec![(Box::new(col("c3").lt(lit(10))), Box::new(lit("b")))],
4008+
Some(Box::new(col("c4"))),
4009+
))
4010+
);
4011+
4012+
// Negative test: CASE WHEN c3 = 4 THEN 1 ELSE 2 END should not be simplified
4013+
let expr = Expr::Case(Case::new(
4014+
None,
4015+
vec![(Box::new(col("c3").eq(lit(4))), Box::new(lit(1)))],
4016+
Some(Box::new(lit(2))),
39534017
));
39544018
assert_eq!(simplify(expr.clone()), expr);
39554019
}

0 commit comments

Comments
 (0)