Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 125 additions & 22 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,36 +618,56 @@ impl BinaryExpr {
}
}

/// Formats one operand of a binary expression, surrounding it with
/// parentheses only when dropping them would change how the output reads.
///
/// * `f` - destination formatter.
/// * `expr` - the operand to write.
/// * `precedence` - precedence of the parent operator.
/// * `is_right` - `true` when `expr` is the parent's right-hand operand.
/// * `fmt_expr` - callback that renders an expression into `f`; this lets
///   different display flavours share the same parenthesization rules.
///
/// An operand that is itself a binary expression is parenthesized when its
/// operator binds less tightly than the parent, or when it binds equally
/// tightly and sits on the right-hand side (so `a - (b - c)` keeps its
/// parentheses while `(a - b) - c` prints as `a - b - c`). Any other kind of
/// operand is written without parentheses.
fn write_binary_child(
    f: &mut Formatter<'_>,
    expr: &Expr,
    precedence: u8,
    is_right: bool,
    fmt_expr: impl Fn(&Expr, &mut Formatter<'_>) -> fmt::Result,
) -> fmt::Result {
    let wrap_in_parens = if let Expr::BinaryExpr(child) = expr {
        let child_precedence = child.op.precedence();
        // A zero precedence is a defensive case: the operator definitions are
        // expected to assign every operator a non-zero precedence, so this
        // only matters if a new operator is added without one.
        child_precedence == 0
            || child_precedence < precedence
            || (is_right && child_precedence == precedence)
    } else {
        false
    };

    if !wrap_in_parens {
        return fmt_expr(expr, f);
    }

    f.write_str("(")?;
    fmt_expr(expr, f)?;
    f.write_str(")")
}

impl Display for BinaryExpr {
/// Writes the expression as `<left> <op> <right>`, adding parentheses around
/// child binary expressions where operator precedence calls for them.
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
// Put parentheses around child binary expressions so that we can see the difference
// between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed,
// based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are
// equivalent and the parentheses are not necessary.

// NOTE(review): this hunk appears to interleave the removed and added sides
// of a diff — the nested `write_child` helper and its two call sites look
// like the removed side, the `write_binary_child` calls like the added side.
// Confirm against the raw diff before relying on this text as-is.
//
// Nested helper: parenthesizes a child `BinaryExpr` when its precedence is 0
// or lower than the parent's; any other expression is written unchanged.
fn write_child(
f: &mut Formatter<'_>,
expr: &Expr,
precedence: u8,
) -> fmt::Result {
match expr {
Expr::BinaryExpr(child) => {
let p = child.op.precedence();
if p == 0 || p < precedence {
write!(f, "({child})")?;
} else {
write!(f, "{child}")?;
}
}
_ => write!(f, "{expr}")?,
}
Ok(())
}

let precedence = self.op.precedence();
write_child(f, self.left.as_ref(), precedence)?;
// Left operand: passes `is_right = false`.
write_binary_child(f, self.left.as_ref(), precedence, false, |e, f| {
write!(f, "{e}")
})?;
write!(f, " {} ", self.op)?;
write_child(f, self.right.as_ref(), precedence)
// Right operand: passes `is_right = true`, so an equal-precedence child is
// parenthesized too.
write_binary_child(f, self.right.as_ref(), precedence, true, |e, f| {
write!(f, "{e}")
})
}
}

Expand Down Expand Up @@ -2858,7 +2878,14 @@ impl Display for SchemaDisplay<'_> {
}
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
write!(f, "{} {op} {}", SchemaDisplay(left), SchemaDisplay(right),)
let precedence = op.precedence();
write_binary_child(f, left.as_ref(), precedence, false, |e, f| {
write!(f, "{}", SchemaDisplay(e))
})?;
write!(f, " {op} ")?;
write_binary_child(f, right.as_ref(), precedence, true, |e, f| {
write!(f, "{}", SchemaDisplay(e))
})
}
Expr::Case(Case {
expr,
Expand Down Expand Up @@ -4043,6 +4070,82 @@ mod test {
);
}

// Test helper: wraps `left <op> right` into an `Expr::BinaryExpr` node,
// boxing both operands.
fn bin(left: Expr, op: Operator, right: Expr) -> Expr {
    let (lhs, rhs) = (Box::new(left), Box::new(right));
    let node = BinaryExpr::new(lhs, op, rhs);
    Expr::BinaryExpr(node)
}

#[test]
fn test_binary_expr_parenthesization() {
    use Operator::{Divide, Minus, Multiply, Plus};

    // Fresh operands on demand, so no explicit clones are needed below.
    let a = || col("a");
    let b = || col("b");
    let c = || col("c");
    let int = |value: i64| lit(value);

    // Each entry: (expression tree, expected `Display`, expected `SchemaDisplay`).
    let cases = [
        // Columns: equal precedence on the right keeps its parentheses.
        (
            bin(a(), Minus, bin(b(), Minus, c())),
            "a - (b - c)",
            "a - (b - c)",
        ),
        // Columns: equal precedence on the right, different operators.
        (
            bin(a(), Divide, bin(b(), Multiply, c())),
            "a / (b * c)",
            "a / (b * c)",
        ),
        // Columns: equal precedence on the left needs no parentheses.
        (
            bin(bin(a(), Minus, b()), Minus, c()),
            "a - b - c",
            "a - b - c",
        ),
        // Lower-precedence left child must be parenthesized (issue #16054).
        (
            bin(bin(int(1), Plus, int(2)), Multiply, int(3)),
            "(Int64(1) + Int64(2)) * Int64(3)",
            "(Int64(1) + Int64(2)) * Int64(3)",
        ),
        // Literals: nested subtraction on the right.
        (
            bin(int(1), Minus, bin(int(2), Minus, int(3))),
            "Int64(1) - (Int64(2) - Int64(3))",
            "Int64(1) - (Int64(2) - Int64(3))",
        ),
        // Literals: nested subtraction on the left.
        (
            bin(bin(int(1), Minus, int(2)), Minus, int(3)),
            "Int64(1) - Int64(2) - Int64(3)",
            "Int64(1) - Int64(2) - Int64(3)",
        ),
        // Literals: multiplication nested on the right of a division.
        (
            bin(int(6), Divide, bin(int(2), Multiply, int(3))),
            "Int64(6) / (Int64(2) * Int64(3))",
            "Int64(6) / (Int64(2) * Int64(3))",
        ),
        // Literals: division nested on the left of a multiplication.
        (
            bin(bin(int(6), Divide, int(2)), Multiply, int(3)),
            "Int64(6) / Int64(2) * Int64(3)",
            "Int64(6) / Int64(2) * Int64(3)",
        ),
    ];

    for (expr, expected_display, expected_schema) in &cases {
        let rendered = expr.to_string();
        assert_eq!(
            rendered, *expected_display,
            "Display mismatch for expected: {expected_display}"
        );

        let rendered_schema = SchemaDisplay(expr).to_string();
        assert_eq!(
            rendered_schema, *expected_schema,
            "SchemaDisplay mismatch for expected: {expected_schema}"
        );
    }
}

fn wildcard_options(
opt_ilike: Option<IlikeSelectItem>,
opt_exclude: Option<ExcludeSelectItem>,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ mod test {
assert_optimized_plan_equal!(
plan,
@ r"
Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * (Int32(1) - test.b)), sum(__common_expr_1 AS test.a * (Int32(1) - test.b) * (Int32(1) + test.c))]]
Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
TableScan: test
"
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ mod tests {
assert_optimized_plan_equal!(
plan,
@ r"
Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
Filter: t1.a = t2.a AND t2.c < UInt32(15) OR (t1.a = t2.a OR t2.c = UInt32(688)) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
Expand Down Expand Up @@ -857,7 +857,7 @@ mod tests {
assert_optimized_plan_equal!(
plan,
@ r"
Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
Filter: t3.a = t1.a AND t4.c < UInt32(15) OR (t3.a = t1.a OR t4.c = UInt32(688)) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
Expand Down Expand Up @@ -937,7 +937,7 @@ mod tests {
Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR (t3.a = t4.a OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2247,7 +2247,7 @@ mod tests {
// not part of the test, just good to know:
assert_snapshot!(plan,
@r"
Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
Filter: sum(test.c) > Int64(10) AND (b > Int64(10) AND sum(test.c) < Int64(20))
Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
Projection: test.a AS b, test.c
TableScan: test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ mod tests {
.build()?;

let actual = get_optimized_plan_formatted(plan, &time);
let expected = "Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\
let expected = "Projection: NOT test.a AS (Boolean(true) OR Boolean(false)) != test.a\
\n TableScan: test";

assert_eq!(expected, actual);
Expand Down