@@ -1436,33 +1436,49 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
1436
1436
1437
1437
// CASE WHEN true THEN A ... END --> A
1438
1438
// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1439
+ // CASE WHEN false THEN A END --> NULL
1440
+ // CASE WHEN false THEN A ELSE B END --> B
1441
+ // CASE WHEN X THEN A WHEN false THEN B END --> CASE WHEN X THEN A ELSE B END
1439
1442
Expr :: Case ( Case {
1440
1443
expr : None ,
1441
- mut when_then_expr,
1442
- else_expr : _,
1443
- // if let guard is not stabilized so we can't use it yet: https://github.com/rust-lang/rust/issues/51114
1444
- // Once it's supported we can avoid searching through when_then_expr twice in the below .any() and .position() calls
1445
- // }) if let Some(i) = when_then_expr.iter().position(|(when, _)| is_true(when.as_ref())) => {
1444
+ when_then_expr,
1445
+ mut else_expr,
1446
1446
} ) if when_then_expr
1447
1447
. iter ( )
1448
- . any ( |( when, _) | is_true ( when. as_ref ( ) ) ) =>
1448
+ . any ( |( when, _) | is_true ( when. as_ref ( ) ) || is_false ( when . as_ref ( ) ) ) =>
1449
1449
{
1450
- let i = when_then_expr
1451
- . iter ( )
1452
- . position ( |( when, _) | is_true ( when. as_ref ( ) ) )
1453
- . unwrap ( ) ;
1454
- let ( _, then_) = when_then_expr. swap_remove ( i) ;
1455
- // CASE WHEN true THEN A ... END --> A
1456
- if i == 0 {
1457
- return Ok ( Transformed :: yes ( * then_) ) ;
1450
+ let out_type = info. get_data_type ( & when_then_expr[ 0 ] . 1 ) ?;
1451
+ let mut new_when_then_expr = Vec :: with_capacity ( when_then_expr. len ( ) ) ;
1452
+
1453
+ for ( when, then) in when_then_expr. into_iter ( ) {
1454
+ if is_true ( when. as_ref ( ) ) {
1455
+ // Skip adding the rest of the when-then expressions after WHEN true
1456
+ // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1457
+ else_expr = Some ( then) ;
1458
+ break ;
1459
+ } else if !is_false ( when. as_ref ( ) ) {
1460
+ new_when_then_expr. push ( ( when, then) ) ;
1461
+ }
1462
+ // else: skip WHEN false cases
1463
+ }
1464
+
1465
+ // Exclude CASE statement altogether if there are no when-then expressions left
1466
+ if new_when_then_expr. is_empty ( ) {
1467
+ // CASE WHEN false THEN A ELSE B END --> B
1468
+ if let Some ( else_expr) = else_expr {
1469
+ return Ok ( Transformed :: yes ( * else_expr) ) ;
1470
+ // CASE WHEN false THEN A END --> NULL
1471
+ } else {
1472
+ let null =
1473
+ Expr :: Literal ( ScalarValue :: try_new_null ( & out_type) ?, None ) ;
1474
+ return Ok ( Transformed :: yes ( null) ) ;
1475
+ }
1458
1476
}
1459
1477
1460
- // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1461
- when_then_expr. truncate ( i) ;
1462
1478
Transformed :: yes ( Expr :: Case ( Case {
1463
1479
expr : None ,
1464
- when_then_expr,
1465
- else_expr : Some ( then_ ) ,
1480
+ when_then_expr : new_when_then_expr ,
1481
+ else_expr,
1466
1482
} ) )
1467
1483
}
1468
1484
@@ -3810,53 +3826,53 @@ mod tests {
3810
3826
3811
3827
#[ test]
3812
3828
fn simplify_expr_case_when_first_true ( ) {
3813
- // CASE WHEN true THEN 1 ELSE x END --> 1
3829
+ // CASE WHEN true THEN 1 ELSE c1 END --> 1
3814
3830
assert_eq ! (
3815
3831
simplify( Expr :: Case ( Case :: new(
3816
3832
None ,
3817
3833
vec![ ( Box :: new( lit( true ) ) , Box :: new( lit( 1 ) ) , ) ] ,
3818
- Some ( Box :: new( col( "x " ) ) ) ,
3834
+ Some ( Box :: new( col( "c1 " ) ) ) ,
3819
3835
) ) ) ,
3820
3836
lit( 1 )
3821
3837
) ;
3822
3838
3823
- // CASE WHEN true THEN col("a" ) ELSE col("b" ) END --> col("a" )
3839
+ // CASE WHEN true THEN col('a' ) ELSE col('b' ) END --> col('a' )
3824
3840
assert_eq ! (
3825
3841
simplify( Expr :: Case ( Case :: new(
3826
3842
None ,
3827
- vec![ ( Box :: new( lit( true ) ) , Box :: new( col ( "a" ) ) , ) ] ,
3828
- Some ( Box :: new( col ( "b" ) ) ) ,
3843
+ vec![ ( Box :: new( lit( true ) ) , Box :: new( lit ( "a" ) ) , ) ] ,
3844
+ Some ( Box :: new( lit ( "b" ) ) ) ,
3829
3845
) ) ) ,
3830
- col ( "a" )
3846
+ lit ( "a" )
3831
3847
) ;
3832
3848
3833
- // CASE WHEN true THEN col("a" ) WHEN col("x" ) > 5 THEN col("b" ) ELSE col("c" ) END --> col("a" )
3849
+ // CASE WHEN true THEN col('a' ) WHEN col('x' ) > 5 THEN col('b' ) ELSE col('c' ) END --> col('a' )
3834
3850
assert_eq ! (
3835
3851
simplify( Expr :: Case ( Case :: new(
3836
3852
None ,
3837
3853
vec![
3838
- ( Box :: new( lit( true ) ) , Box :: new( col ( "a" ) ) ) ,
3839
- ( Box :: new( col ( "x" ) . gt( lit( 5 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3854
+ ( Box :: new( lit( true ) ) , Box :: new( lit ( "a" ) ) ) ,
3855
+ ( Box :: new( lit ( "x" ) . gt( lit( 5 ) ) ) , Box :: new( lit ( "b" ) ) ) ,
3840
3856
] ,
3841
- Some ( Box :: new( col ( "c" ) ) ) ,
3857
+ Some ( Box :: new( lit ( "c" ) ) ) ,
3842
3858
) ) ) ,
3843
- col ( "a" )
3859
+ lit ( "a" )
3844
3860
) ;
3845
3861
3846
- // CASE WHEN true THEN col("a" ) END --> col("a" ) (no else clause)
3862
+ // CASE WHEN true THEN col('a' ) END --> col('a' ) (no else clause)
3847
3863
assert_eq ! (
3848
3864
simplify( Expr :: Case ( Case :: new(
3849
3865
None ,
3850
- vec![ ( Box :: new( lit( true ) ) , Box :: new( col ( "a" ) ) , ) ] ,
3866
+ vec![ ( Box :: new( lit( true ) ) , Box :: new( lit ( "a" ) ) , ) ] ,
3851
3867
None ,
3852
3868
) ) ) ,
3853
- col ( "a" )
3869
+ lit ( "a" )
3854
3870
) ;
3855
3871
3856
- // Negative test: CASE WHEN a THEN 1 ELSE 2 END should not be simplified
3872
+ // Negative test: CASE WHEN c2 THEN 1 ELSE 2 END should not be simplified
3857
3873
let expr = Expr :: Case ( Case :: new (
3858
3874
None ,
3859
- vec ! [ ( Box :: new( col( "a " ) ) , Box :: new( lit( 1 ) ) ) ] ,
3875
+ vec ! [ ( Box :: new( col( "c2 " ) ) , Box :: new( lit( 1 ) ) ) ] ,
3860
3876
Some ( Box :: new ( lit ( 2 ) ) ) ,
3861
3877
) ) ;
3862
3878
assert_eq ! ( simplify( expr. clone( ) ) , expr) ;
@@ -3869,87 +3885,135 @@ mod tests {
3869
3885
) ) ;
3870
3886
assert_ne ! ( simplify( expr) , lit( 1 ) ) ;
3871
3887
3872
- // Negative test: CASE WHEN col("x" ) > 5 THEN 1 ELSE 2 END should not be simplified
3888
+ // Negative test: CASE WHEN col('c1' ) > 5 THEN 1 ELSE 2 END should not be simplified
3873
3889
let expr = Expr :: Case ( Case :: new (
3874
3890
None ,
3875
- vec ! [ ( Box :: new( col( "x " ) . gt( lit( 5 ) ) ) , Box :: new( lit( 1 ) ) ) ] ,
3891
+ vec ! [ ( Box :: new( col( "c1 " ) . gt( lit( 5 ) ) ) , Box :: new( lit( 1 ) ) ) ] ,
3876
3892
Some ( Box :: new ( lit ( 2 ) ) ) ,
3877
3893
) ) ;
3878
3894
assert_eq ! ( simplify( expr. clone( ) ) , expr) ;
3879
3895
}
3880
3896
3881
3897
#[ test]
3882
3898
fn simplify_expr_case_when_any_true ( ) {
3883
- // CASE WHEN x > 0 THEN a WHEN true THEN b ELSE c END --> CASE WHEN x > 0 THEN a ELSE b END
3899
+ // CASE WHEN c3 > 0 THEN 'a' WHEN true THEN 'b' ELSE 'c' END --> CASE WHEN c3 > 0 THEN 'a' ELSE 'b' END
3884
3900
assert_eq ! (
3885
3901
simplify( Expr :: Case ( Case :: new(
3886
3902
None ,
3887
3903
vec![
3888
- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ,
3889
- ( Box :: new( lit( true ) ) , Box :: new( col ( "b" ) ) ) ,
3904
+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( "a" ) ) ) ,
3905
+ ( Box :: new( lit( true ) ) , Box :: new( lit ( "b" ) ) ) ,
3890
3906
] ,
3891
- Some ( Box :: new( col ( "c" ) ) ) ,
3907
+ Some ( Box :: new( lit ( "c" ) ) ) ,
3892
3908
) ) ) ,
3893
3909
Expr :: Case ( Case :: new(
3894
3910
None ,
3895
- vec![ ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ] ,
3896
- Some ( Box :: new( col ( "b" ) ) ) ,
3911
+ vec![ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( "a" ) ) ) ] ,
3912
+ Some ( Box :: new( lit ( "b" ) ) ) ,
3897
3913
) )
3898
3914
) ;
3899
3915
3900
- // CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c WHEN z = 0 THEN d ELSE e END
3901
- // --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END
3916
+ // CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' WHEN true THEN 'c' WHEN c3 = 0 THEN 'd' ELSE 'e' END
3917
+ // --> CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' ELSE 'c' END
3902
3918
assert_eq ! (
3903
3919
simplify( Expr :: Case ( Case :: new(
3904
3920
None ,
3905
3921
vec![
3906
- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ,
3907
- ( Box :: new( col( "y " ) . lt( lit( 0 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3908
- ( Box :: new( lit( true ) ) , Box :: new( col ( "c" ) ) ) ,
3909
- ( Box :: new( col( "z " ) . eq( lit( 0 ) ) ) , Box :: new( col ( "d" ) ) ) ,
3922
+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( "a" ) ) ) ,
3923
+ ( Box :: new( col( "c4 " ) . lt( lit( 0 ) ) ) , Box :: new( lit ( "b" ) ) ) ,
3924
+ ( Box :: new( lit( true ) ) , Box :: new( lit ( "c" ) ) ) ,
3925
+ ( Box :: new( col( "c3 " ) . eq( lit( 0 ) ) ) , Box :: new( lit ( "d" ) ) ) ,
3910
3926
] ,
3911
- Some ( Box :: new( col ( "e" ) ) ) ,
3927
+ Some ( Box :: new( lit ( "e" ) ) ) ,
3912
3928
) ) ) ,
3913
3929
Expr :: Case ( Case :: new(
3914
3930
None ,
3915
3931
vec![
3916
- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ,
3917
- ( Box :: new( col( "y " ) . lt( lit( 0 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3932
+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( "a" ) ) ) ,
3933
+ ( Box :: new( col( "c4 " ) . lt( lit( 0 ) ) ) , Box :: new( lit ( "b" ) ) ) ,
3918
3934
] ,
3919
- Some ( Box :: new( col ( "c" ) ) ) ,
3935
+ Some ( Box :: new( lit ( "c" ) ) ) ,
3920
3936
) )
3921
3937
) ;
3922
3938
3923
- // CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c END (no else)
3924
- // --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END
3939
+ // CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 WHEN true THEN 3 END (no else)
3940
+ // --> CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 ELSE 3 END
3925
3941
assert_eq ! (
3926
3942
simplify( Expr :: Case ( Case :: new(
3927
3943
None ,
3928
3944
vec![
3929
- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ,
3930
- ( Box :: new( col( "y " ) . lt( lit( 0 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3931
- ( Box :: new( lit( true ) ) , Box :: new( col ( "c" ) ) ) ,
3945
+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( 1 ) ) ) ,
3946
+ ( Box :: new( col( "c4 " ) . lt( lit( 0 ) ) ) , Box :: new( lit ( 2 ) ) ) ,
3947
+ ( Box :: new( lit( true ) ) , Box :: new( lit ( 3 ) ) ) ,
3932
3948
] ,
3933
3949
None ,
3934
3950
) ) ) ,
3935
3951
Expr :: Case ( Case :: new(
3936
3952
None ,
3937
3953
vec![
3938
- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ,
3939
- ( Box :: new( col( "y " ) . lt( lit( 0 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3954
+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( 1 ) ) ) ,
3955
+ ( Box :: new( col( "c4 " ) . lt( lit( 0 ) ) ) , Box :: new( lit ( 2 ) ) ) ,
3940
3956
] ,
3941
- Some ( Box :: new( col ( "c" ) ) ) ,
3957
+ Some ( Box :: new( lit ( 3 ) ) ) ,
3942
3958
) )
3943
3959
) ;
3944
3960
3945
- // Negative test: CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END should not be simplified
3961
+ // Negative test: CASE WHEN c3 > 0 THEN c3 WHEN c4 < 0 THEN 2 ELSE 3 END should not be simplified
3946
3962
let expr = Expr :: Case ( Case :: new (
3947
3963
None ,
3948
3964
vec ! [
3949
- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col( "a " ) ) ) ,
3950
- ( Box :: new( col( "y " ) . lt( lit( 0 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3965
+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( col( "c3 " ) ) ) ,
3966
+ ( Box :: new( col( "c4 " ) . lt( lit( 0 ) ) ) , Box :: new( lit ( 2 ) ) ) ,
3951
3967
] ,
3952
- Some ( Box :: new ( col ( "c" ) ) ) ,
3968
+ Some ( Box :: new ( lit ( 3 ) ) ) ,
3969
+ ) ) ;
3970
+ assert_eq ! ( simplify( expr. clone( ) ) , expr) ;
3971
+ }
3972
+
3973
+ #[ test]
3974
+ fn simplify_expr_case_when_any_false ( ) {
3975
+ // CASE WHEN false THEN 'a' END --> NULL
3976
+ assert_eq ! (
3977
+ simplify( Expr :: Case ( Case :: new(
3978
+ None ,
3979
+ vec![ ( Box :: new( lit( false ) ) , Box :: new( lit( "a" ) ) ) ] ,
3980
+ None ,
3981
+ ) ) ) ,
3982
+ Expr :: Literal ( ScalarValue :: Utf8 ( None ) , None )
3983
+ ) ;
3984
+
3985
+ // CASE WHEN false THEN 2 ELSE 1 END --> 1
3986
+ assert_eq ! (
3987
+ simplify( Expr :: Case ( Case :: new(
3988
+ None ,
3989
+ vec![ ( Box :: new( lit( false ) ) , Box :: new( lit( 2 ) ) ) ] ,
3990
+ Some ( Box :: new( lit( 1 ) ) ) ,
3991
+ ) ) ) ,
3992
+ lit( 1 ) ,
3993
+ ) ;
3994
+
3995
+ // CASE WHEN c3 < 10 THEN 'b' WHEN false then c3 ELSE c4 END --> CASE WHEN c3 < 10 THEN b ELSE c4 END
3996
+ assert_eq ! (
3997
+ simplify( Expr :: Case ( Case :: new(
3998
+ None ,
3999
+ vec![
4000
+ ( Box :: new( col( "c3" ) . lt( lit( 10 ) ) ) , Box :: new( lit( "b" ) ) ) ,
4001
+ ( Box :: new( lit( false ) ) , Box :: new( col( "c3" ) ) ) ,
4002
+ ] ,
4003
+ Some ( Box :: new( col( "c4" ) ) ) ,
4004
+ ) ) ) ,
4005
+ Expr :: Case ( Case :: new(
4006
+ None ,
4007
+ vec![ ( Box :: new( col( "c3" ) . lt( lit( 10 ) ) ) , Box :: new( lit( "b" ) ) ) ] ,
4008
+ Some ( Box :: new( col( "c4" ) ) ) ,
4009
+ ) )
4010
+ ) ;
4011
+
4012
+ // Negative test: CASE WHEN c3 = 4 THEN 1 ELSE 2 END should not be simplified
4013
+ let expr = Expr :: Case ( Case :: new (
4014
+ None ,
4015
+ vec ! [ ( Box :: new( col( "c3" ) . eq( lit( 4 ) ) ) , Box :: new( lit( 1 ) ) ) ] ,
4016
+ Some ( Box :: new ( lit ( 2 ) ) ) ,
3953
4017
) ) ;
3954
4018
assert_eq ! ( simplify( expr. clone( ) ) , expr) ;
3955
4019
}
0 commit comments