Skip to content

Commit

Permalink
Backport fixes avoiding to CSE volatile and short-cutting expressions…
Browse files Browse the repository at this point in the history
… to v34 (#227)

* fix: don't extract common sub expr in `CASE WHEN` clause (apache#8833)

* fix: don't extract common sub expr in CASE WHEN clause

* fix ci

* fix

* fix: common_subexpr_eliminate rule should not apply to short-circuit expression (apache#8928)

* fix: common_subexpr_eliminate rule should not apply to short-circuit expression

* add more tests

* format

* minor

* apply reviews

* add some comment

* fmt

---------

Co-authored-by: Huaijin <[email protected]>
  • Loading branch information
joroKr21 and haohuaijin authored Apr 4, 2024
1 parent d88e414 commit 587db3b
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 64 deletions.
48 changes: 48 additions & 0 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,54 @@ impl Expr {
Ok(Transformed::Yes(expr))
})
}

/// Returns true if some of this `exprs` subexpressions may not be evaluated
/// and thus any side effects (like divide by zero) may not be encountered
pub fn short_circuits(&self) -> bool {
match self {
Expr::ScalarFunction(ScalarFunction { func_def, .. }) => {
matches!(func_def, ScalarFunctionDefinition::BuiltIn(fun) if *fun == BuiltinScalarFunction::Coalesce)
}
Expr::BinaryExpr(BinaryExpr { op, .. }) => {
matches!(op, Operator::And | Operator::Or)
}
Expr::Case { .. } => true,
// Use explicit pattern match instead of a default
// implementation, so that in the future if someone adds
// new Expr types, they will check here as well
Expr::AggregateFunction(..)
| Expr::Alias(..)
| Expr::Between(..)
| Expr::Cast(..)
| Expr::Column(..)
| Expr::Exists(..)
| Expr::GetIndexedField(..)
| Expr::GroupingSet(..)
| Expr::InList(..)
| Expr::InSubquery(..)
| Expr::IsFalse(..)
| Expr::IsNotFalse(..)
| Expr::IsNotNull(..)
| Expr::IsNotTrue(..)
| Expr::IsNotUnknown(..)
| Expr::IsNull(..)
| Expr::IsTrue(..)
| Expr::IsUnknown(..)
| Expr::Like(..)
| Expr::ScalarSubquery(..)
| Expr::ScalarVariable(_, _)
| Expr::SimilarTo(..)
| Expr::Not(..)
| Expr::Negative(..)
| Expr::OuterReferenceColumn(_, _)
| Expr::TryCast(..)
| Expr::Wildcard { .. }
| Expr::WindowFunction(..)
| Expr::Literal(..)
| Expr::Sort(..)
| Expr::Placeholder(..) => false,
}
}
}

// modifies expr if it is a placeholder with datatype of right
Expand Down
37 changes: 23 additions & 14 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;

use crate::utils::is_volatile_expression;
use crate::{utils, OptimizerConfig, OptimizerRule};

use arrow::datatypes::DataType;
Expand All @@ -29,7 +30,7 @@ use datafusion_common::tree_node::{
use datafusion_common::{
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
};
use datafusion_expr::expr::{is_volatile, Alias};
use datafusion_expr::expr::Alias;
use datafusion_expr::logical_plan::{
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
};
Expand Down Expand Up @@ -518,7 +519,7 @@ enum ExprMask {
}

impl ExprMask {
fn ignores(&self, expr: &Expr) -> Result<bool> {
fn ignores(&self, expr: &Expr) -> bool {
let is_normal_minus_aggregates = matches!(
expr,
Expr::Literal(..)
Expand All @@ -529,14 +530,12 @@ impl ExprMask {
| Expr::Wildcard { .. }
);

let is_volatile = is_volatile(expr)?;

let is_aggr = matches!(expr, Expr::AggregateFunction(..));

Ok(match self {
Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr,
Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates,
})
match self {
Self::Normal => is_normal_minus_aggregates || is_aggr,
Self::NormalAndAggregates => is_normal_minus_aggregates,
}
}
}

Expand Down Expand Up @@ -614,7 +613,12 @@ impl ExprIdentifierVisitor<'_> {
impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
type N = Expr;

fn pre_visit(&mut self, _expr: &Expr) -> Result<VisitRecursion> {
fn pre_visit(&mut self, expr: &Expr) -> Result<VisitRecursion> {
// related to https://github.com/apache/arrow-datafusion/issues/8814
// If the expr contains a volatile expression or is a short-circuit expression, skip it.
if expr.short_circuits() || is_volatile_expression(expr)? {
return Ok(VisitRecursion::Skip);
}
self.visit_stack
.push(VisitRecord::EnterMark(self.node_count));
self.node_count += 1;
Expand All @@ -628,7 +632,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {

let (idx, sub_expr_desc) = self.pop_enter_mark();
// skip exprs that should not be recognized.
if self.expr_mask.ignores(expr)? {
if self.expr_mask.ignores(expr) {
self.id_array[idx].0 = self.series_number;
let desc = Self::desc_expr(expr);
self.visit_stack.push(VisitRecord::ExprItem(desc));
Expand Down Expand Up @@ -691,7 +695,13 @@ struct CommonSubexprRewriter<'a> {
impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
type N = Expr;

fn pre_visit(&mut self, _: &Expr) -> Result<RewriteRecursion> {
fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
// the `id_array`, which records the expr's identifier used to rewrite expr. So if we
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
if expr.short_circuits() || is_volatile_expression(expr)? {
return Ok(RewriteRecursion::Stop);
}
if self.curr_index >= self.id_array.len()
|| self.max_series_number > self.id_array[self.curr_index].0
{
Expand Down Expand Up @@ -1247,12 +1257,11 @@ mod test {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan)
.filter(lit(1).gt(col("a")).and(lit(1).gt(col("a"))))?
.filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
.build()?;

let expected = "Projection: test.a, test.b, test.c\
\n Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\
\n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\
\n Filter: Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a - Int32(10) > Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand Down
69 changes: 23 additions & 46 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,28 @@
//! [`PushDownFilter`] Moves filters so they are applied as early as possible in
//! the plan.
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use itertools::Itertools;

use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
use datafusion_common::{
internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned};
use datafusion_expr::Volatility;
use datafusion_expr::{
and,
expr_rewriter::replace_col,
logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union},
or, BinaryExpr, Expr, Filter, Operator, ScalarFunctionDefinition,
TableProviderFilterPushDown,
};
use itertools::Itertools;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use crate::optimizer::ApplyOrder;
use crate::utils::is_volatile_expression;
use crate::{OptimizerConfig, OptimizerRule};

/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
Expand Down Expand Up @@ -734,7 +737,9 @@ impl OptimizerRule for PushDownFilter {

(field.qualified_name(), expr)
})
.partition(|(_, value)| is_volatile_expression(value));
.partition(|(_, value)| {
is_volatile_expression(value).unwrap_or(true)
});

let mut push_predicates = vec![];
let mut keep_predicates = vec![];
Expand Down Expand Up @@ -973,38 +978,6 @@ pub fn replace_cols_by_name(
})
}

/// Checks whether the expression contains a volatile scalar function.
///
/// Walks `e` with `apply` and stops at the first built-in function or UDF
/// whose volatility is `Volatility::Volatile`.
///
/// If the walk fails, which happens when an unresolved
/// `ScalarFunctionDefinition::Name` is encountered, the expression is
/// conservatively reported as volatile instead of panicking: an expression
/// that cannot be classified must not be treated as safe to move.
fn is_volatile_expression(e: &Expr) -> bool {
    let mut is_volatile = false;
    let walk = e.apply(&mut |expr| {
        Ok(match expr {
            Expr::ScalarFunction(f) => match &f.func_def {
                ScalarFunctionDefinition::BuiltIn(fun)
                    if fun.volatility() == Volatility::Volatile =>
                {
                    is_volatile = true;
                    VisitRecursion::Stop
                }
                ScalarFunctionDefinition::UDF(fun)
                    if fun.signature().volatility == Volatility::Volatile =>
                {
                    is_volatile = true;
                    VisitRecursion::Stop
                }
                ScalarFunctionDefinition::Name(_) => {
                    return internal_err!(
                        "Function `Expr` with name should be resolved."
                    );
                }
                _ => VisitRecursion::Continue,
            },
            _ => VisitRecursion::Continue,
        })
    });
    // A failed walk means volatility could not be determined; err on the
    // side of "volatile" rather than unwrapping and aborting the optimizer.
    is_volatile || walk.is_err()
}

/// check whether the expression uses the columns in `check_map`.
fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
let mut is_contain = false;
Expand All @@ -1027,22 +1000,26 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {

#[cfg(test)]
mod tests {
use super::*;
use crate::optimizer::Optimizer;
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::test::*;
use crate::OptimizerContext;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;

use datafusion_common::{DFSchema, DFSchemaRef};
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, random, sum,
BinaryExpr, Expr, Extension, LogicalPlanBuilder, Operator, TableSource,
TableType, UserDefinedLogicalNodeCore,
};
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use crate::optimizer::Optimizer;
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::test::*;
use crate::OptimizerContext;

use super::*;

fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
crate::test::assert_optimized_plan_eq(
Expand Down
16 changes: 16 additions & 0 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
//! Collection of utility functions that are leveraged by the query optimizer rules
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::{Column, DFSchemaRef};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr::is_volatile;
use datafusion_expr::expr_rewriter::replace_col;
use datafusion_expr::utils as expr_utils;
use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator};
Expand Down Expand Up @@ -92,6 +94,20 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) {
trace!("{description}::\n{}\n", plan.display_indent_schema());
}

/// Checks whether the expression is, or contains, a volatile expression.
///
/// Walks `e` with `TreeNode::apply`, testing each visited node with
/// `is_volatile` and stopping the walk at the first volatile node.
///
/// # Errors
///
/// Propagates any error returned by `is_volatile` or by the traversal.
pub(crate) fn is_volatile_expression(e: &Expr) -> Result<bool> {
    let mut found_volatile = false;
    e.apply(&mut |node| {
        if is_volatile(node)? {
            found_volatile = true;
            return Ok(VisitRecursion::Stop);
        }
        Ok(VisitRecursion::Continue)
    })?;
    Ok(found_volatile)
}

/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
///
/// See [`split_conjunction_owned`] for more details and an example.
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,6 @@ NULL

# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away
query B
SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0)
SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0)
----
false
63 changes: 63 additions & 0 deletions datafusion/sqllogictest/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1104,3 +1104,66 @@ SELECT abs(x), abs(x) + abs(y) FROM t;

statement ok
DROP TABLE t;

# related to https://github.com/apache/arrow-datafusion/issues/8814
statement ok
create table t(x int, y int) as values (1,1), (2,2), (3,3), (0,0), (4,0);

query II
SELECT
CASE WHEN B.x > 0 THEN A.x / B.x ELSE 0 END AS value1,
CASE WHEN B.x > 0 AND B.y > 0 THEN A.x / B.x ELSE 0 END AS value3
FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B;
----
0 0
0 0
0 0
0 0
0 0

# Expressions that short circuit should not be refactored out as that may cause side effects (divide by zero)
# at plan time that would not actually happen during execution, so the following three queries should not have
# their common sub-expressions extracted
query TT
explain select coalesce(1, y/x), coalesce(2, y/x) from t;
----
logical_plan
Projection: coalesce(Int64(1), CAST(t.y / t.x AS Int64)), coalesce(Int64(2), CAST(t.y / t.x AS Int64))
--TableScan: t projection=[x, y]
physical_plan
ProjectionExec: expr=[coalesce(1, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(1),t.y / t.x), coalesce(2, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(2),t.y / t.x)]
--MemoryExec: partitions=1, partition_sizes=[1]

query TT
EXPLAIN SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t;
----
logical_plan
Projection: t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x
--TableScan: t projection=[x, y]
physical_plan
ProjectionExec: expr=[y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 as t.y > Int64(0) AND Int64(1) / t.y < Int64(1), x@0 > 0 AND y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x]
--MemoryExec: partitions=1, partition_sizes=[1]

query TT
EXPLAIN SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t;
----
logical_plan
Projection: t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y = Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x
--TableScan: t projection=[x, y]
physical_plan
ProjectionExec: expr=[y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 as t.y = Int64(0) OR Int64(1) / t.y < Int64(1), x@0 = 0 OR y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x]
--MemoryExec: partitions=1, partition_sizes=[1]

# due to the reason described in https://github.com/apache/arrow-datafusion/issues/8927,
# the following queries will fail
query error
select coalesce(1, y/x), coalesce(2, y/x) from t;

query error
SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t;

query error
SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t;

statement ok
DROP TABLE t;
6 changes: 3 additions & 3 deletions datafusion/sqllogictest/test_files/tpch/q14.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ where
----
logical_plan
Projection: Float64(100) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue
--Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
----Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, part.p_type
--Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
----Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_type
------Inner Join: lineitem.l_partkey = part.p_partkey
--------Projection: lineitem.l_partkey, lineitem.l_extendedprice, lineitem.l_discount
----------Filter: lineitem.l_shipdate >= Date32("9374") AND lineitem.l_shipdate < Date32("9404")
Expand All @@ -45,7 +45,7 @@ ProjectionExec: expr=[100 * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%")
--AggregateExec: mode=Final, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
----CoalescePartitionsExec
------AggregateExec: mode=Partial, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
--------ProjectionExec: expr=[l_extendedprice@1 * (Some(1),20,0 - l_discount@2) as lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, p_type@4 as p_type]
--------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, p_type@4 as p_type]
----------CoalesceBatchesExec: target_batch_size=8192
------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)]
--------------CoalesceBatchesExec: target_batch_size=8192
Expand Down

0 comments on commit 587db3b

Please sign in to comment.