From 0325eaa0a5f0cc1d1259a3b2506c366bcf660300 Mon Sep 17 00:00:00 2001 From: yujun Date: Wed, 10 Jun 2026 09:01:25 +0800 Subject: [PATCH 1/2] [fix](nereids) Require injectivity for SimplifyAggGroupBy simplification MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The rule simplified GROUP BY f(x) to GROUP BY x without verifying that f(x) is injective (one-to-one). This caused wrong results when: - literal is NULL (any op: a+NULL → always NULL) - Multiply/Divide by zero (a*0 → always 0) - Multiply/Divide with float/double operands (precision loss) Also adds proper handling of implicit lossless widening casts (integral→integral, float→double, integral→decimal, decimal→decimal) and removes the now-unused ExpressionUtils.extractSlotOrCastOnSlot. Co-Authored-By: Claude Opus 4.7 --- .../rules/rewrite/SimplifyAggGroupBy.java | 80 ++++++- .../doris/nereids/util/ExpressionUtils.java | 35 --- .../rules/rewrite/SimplifyAggGroupByTest.java | 216 ++++++++++++++++++ 3 files changed, 292 insertions(+), 39 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java index 37d4d4806f087a..77360835305dca 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java @@ -19,15 +19,19 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.coercion.FractionalType; +import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; @@ -71,7 +75,7 @@ public Rule build() { } @VisibleForTesting - protected static boolean isBinaryArithmeticSlot(TreeNode expr) { + protected static boolean isBinaryArithmeticSlot(Expression expr) { if (expr instanceof Slot) { return true; } @@ -81,7 +85,75 @@ protected static boolean isBinaryArithmeticSlot(TreeNode expr) { if (!supportedFunctions.contains(expr.getClass())) { return false; } - return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() && expr.child(1) instanceof Literal - || ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) instanceof Literal; + + // Multiply / Divide: both sides must not be float/double + if (expr instanceof Multiply || expr instanceof Divide) { + if (expr.child(0).getDataType().isFloatLikeType() + || expr.child(1).getDataType().isFloatLikeType()) { + return false; + } + } + + Expression slotExpr; + Literal literal; + if (expr.child(0) instanceof Literal) { + literal = (Literal) expr.child(0); + slotExpr = expr.child(1); + } else if (expr.child(1) instanceof Literal) { + literal = (Literal) expr.child(1); + slotExpr = expr.child(0); + } else { + return false; + } + + if (!canExtractSlot(slotExpr)) { + return false; + } + + return checkLiteral(expr, literal); + } + + @VisibleForTesting + protected static boolean checkLiteral(Expression expr, Literal literal) { + if (literal.isNullLiteral()) { + return false; + } + if (expr instanceof Multiply || expr instanceof Divide) { + if (literal.isZero()) { + return false; + } + } + return true; + } + + @VisibleForTesting + protected static boolean canExtractSlot(Expression expr) { + while (expr instanceof Cast) { + Cast cast = (Cast) expr; + Expression inner = cast.child(); + if (!isLosslessWidening(inner.getDataType(), cast.getDataType())) { + return false; + } + expr = inner; + } + return expr instanceof Slot; + } + + @VisibleForTesting + protected static boolean isLosslessWidening(DataType src, DataType tgt) { + if (src instanceof IntegralType && tgt instanceof IntegralType) { + return src.width() <= tgt.width(); + } + if (src instanceof FractionalType && tgt instanceof FractionalType) { + return src.width() <= tgt.width(); + } + if (src.isDecimalLikeType() && tgt.isDecimalLikeType()) { + return DecimalV3Type.forType(src).getRange() <= DecimalV3Type.forType(tgt).getRange() + && DecimalV3Type.forType(src).getScale() <= DecimalV3Type.forType(tgt).getScale(); + } + if (src instanceof IntegralType && tgt.isDecimalLikeType()) { + return ((IntegralType) src).range() <= DecimalV3Type.forType(tgt).getRange(); + } + return false; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 53ccdb2d403447..12c1833f5ceb43 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -42,7 +42,6 @@ import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; -import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.IsNull; @@ -422,40 +421,6 @@ public static S selectMinimumColumn(Collection sl } /** - * Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot} - * or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot} - *

- * for example: - * - SlotReference to a column: - * col - * - Cast on SlotReference: - * cast(int_col as string) - * cast(cast(int_col as long) as string) - * - * @param expr input expression - * @return Return Optional[ExprId] of underlying slot reference if input expression is a slot or cast on slot. - * Otherwise, return empty optional result. - */ - public static Optional isSlotOrCastOnSlot(Expression expr) { - return extractSlotOrCastOnSlot(expr).map(Slot::getExprId); - } - - /** - * Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot} - * or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot} - */ - public static Optional extractSlotOrCastOnSlot(Expression expr) { - while (expr instanceof Cast) { - expr = expr.child(0); - } - - if (expr instanceof SlotReference) { - return Optional.of((Slot) expr); - } else { - return Optional.empty(); - } - } - /** * Generate replaceMap Slot -> Expression from NamedExpression[Expression as name] */ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java index 32c2cc4356d048..572e82b298f69a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java @@ -18,18 +18,28 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Mod; import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs; +import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; @@ -156,4 +166,210 @@ void testisBinaryArithmeticSlot() { Divide divide = new Divide(id, Literal.of(2)); Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(divide)); } + + // ========== new tests for injectivity checks ========== + + @Test + void testMultiplyByZero() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(id, Literal.of(0)))); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(Literal.of(0), id))); + } + + @Test + void testDivideZeroNumerator() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Divide(Literal.of(0), id))); + } + + @Test + void testDivideByZero() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Divide(id, Literal.of(0)))); + } + + @Test + void testNullLiteral() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Add(id, NullLiteral.INSTANCE))); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(id, NullLiteral.INSTANCE))); + } + + @Test + void testMultiplyWithDoubleLiteral() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(id, new DoubleLiteral(0.1)))); + } + + @Test + void testDivideWithDoubleLiteral() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Divide(id, new DoubleLiteral(2.0)))); + } + + @Test + void testMultiplyWithFloatSlot() { + Slot floatSlot = new SlotReference("f", FloatType.INSTANCE); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(floatSlot, Literal.of(2)))); + } + + @Test + void testMultiplyDoubleSlotWithIntLiteral() { + Slot doubleSlot = new SlotReference("d", DoubleType.INSTANCE); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(doubleSlot, Literal.of(2)))); + } + + @Test + void testAddWithDoubleLiteral() { + // Add/Subtract should still be allowed with float/double + Slot id = scan1.getOutput().get(0); + Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Add(id, new DoubleLiteral(1.0)))); + } + + // ========== tests for isLosslessWidening ========== + + @Test + void testIntegerWidening() { + Assertions.assertTrue(SimplifyAggGroupBy.isLosslessWidening( + TinyIntType.INSTANCE, IntegerType.INSTANCE)); + Assertions.assertTrue(SimplifyAggGroupBy.isLosslessWidening( + IntegerType.INSTANCE, BigIntType.INSTANCE)); + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + IntegerType.INSTANCE, TinyIntType.INSTANCE)); + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + BigIntType.INSTANCE, IntegerType.INSTANCE)); + } + + @Test + void testFloatWidening() { + Assertions.assertTrue(SimplifyAggGroupBy.isLosslessWidening( + FloatType.INSTANCE, DoubleType.INSTANCE)); + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + DoubleType.INSTANCE, FloatType.INSTANCE)); + } + + @Test + void testDecimalWidening() { + Assertions.assertTrue(SimplifyAggGroupBy.isLosslessWidening( + DecimalV3Type.createDecimalV3Type(5, 2), + DecimalV3Type.createDecimalV3Type(10, 4))); + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + DecimalV3Type.createDecimalV3Type(10, 4), + DecimalV3Type.createDecimalV3Type(5, 2))); + } + + @Test + void testIntegralToDecimalWidening() { + Assertions.assertTrue(SimplifyAggGroupBy.isLosslessWidening( + TinyIntType.INSTANCE, DecimalV3Type.createDecimalV3Type(10, 0))); + // BigInt has 19 digits, DECIMAL(5,0) only has 5 integer digits + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + BigIntType.INSTANCE, DecimalV3Type.createDecimalV3Type(5, 0))); + } + + @Test + void testCrossFamilyRejected() { + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + IntegerType.INSTANCE, FloatType.INSTANCE)); + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + FloatType.INSTANCE, IntegerType.INSTANCE)); + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + IntegerType.INSTANCE, DoubleType.INSTANCE)); + } + + // ========== tests for canExtractSlot ========== + + @Test + void testCanExtractSlotBare() { + Slot id = scan1.getOutput().get(0); + Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(id)); + } + + @Test + void testCanExtractSlotWidening() { + Slot id = scan1.getOutput().get(0); + // INT->BIGINT is lossless widening + Expression cast = new Cast(id, BigIntType.INSTANCE); + Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast)); + } + + @Test + void testCanExtractSlotExplicitCast() { + Slot id = scan1.getOutput().get(0); + // explicit cast should also be acceptable if lossless + Expression cast = new Cast(id, BigIntType.INSTANCE, true); + Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast)); + } + + @Test + void testCanExtractSlotNarrowing() { + Slot id = scan1.getOutput().get(0); + // INT -> TINYINT is narrowing, should be rejected + Expression cast = new Cast(id, TinyIntType.INSTANCE); + Assertions.assertFalse(SimplifyAggGroupBy.canExtractSlot(cast)); + } + + // ========== integration tests via PlanChecker ========== + + @Test + void testMultiplyByZeroNotSimplified() { + Slot id = scan1.getOutput().get(0); + List output = ImmutableList.of(id, new Count().alias("cnt")); + List groupBy = ImmutableList.of(id, new Multiply(id, Literal.of(0))); + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .agg(groupBy, output) + .build(); + ConnectContext connectContext = MemoTestUtils.createConnectContext(); + connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false); + PlanChecker.from(connectContext, agg) + .applyTopDown(new SimplifyAggGroupBy()) + .matchesFromRoot( + logicalAggregate().when(a -> a.equals(agg)) + ); + } + + @Test + void testNullLiteralNotSimplified() { + Slot id = scan1.getOutput().get(0); + List output = ImmutableList.of(id, new Count().alias("cnt")); + List groupBy = ImmutableList.of(id, new Add(id, NullLiteral.INSTANCE)); + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .agg(groupBy, output) + .build(); + ConnectContext connectContext = MemoTestUtils.createConnectContext(); + connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false); + PlanChecker.from(connectContext, agg) + .applyTopDown(new SimplifyAggGroupBy()) + .matchesFromRoot( + logicalAggregate().when(a -> a.equals(agg)) + ); + } + + @Test + void testMultiplyDoubleLiteralNotSimplified() { + Slot id = scan1.getOutput().get(0); + List output = ImmutableList.of(id, new Count().alias("cnt")); + List groupBy = ImmutableList.of(id, new Multiply(id, new DoubleLiteral(0.1))); + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .agg(groupBy, output) + .build(); + ConnectContext connectContext = MemoTestUtils.createConnectContext(); + connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false); + PlanChecker.from(connectContext, agg) + .applyTopDown(new SimplifyAggGroupBy()) + .matchesFromRoot( + logicalAggregate().when(a -> a.equals(agg)) + ); + } } From f5f80c17436cda102e27ff538f6431876ed2355d Mon Sep 17 00:00:00 2001 From: yujun Date: Wed, 10 Jun 2026 16:54:54 +0800 Subject: [PATCH 2/2] [fix](simplify agg) Extend float/double reject to all arithmetic ops Float/double precision loss affects Add/Subtract as well as Multiply/Divide (e.g., 1e16 + 1.0 = 1e16 in DOUBLE). Reject all float/double arithmetic uniformly. Decimal ops are allowed as precision overflow is too extreme to worry about in practice. Remove the now-unused FractionalType widening case. Co-Authored-By: Claude Opus 4.7 --- .../rules/rewrite/SimplifyAggGroupBy.java | 14 ++--- .../rules/rewrite/SimplifyAggGroupByTest.java | 53 +++++++++++++++---- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java index 77360835305dca..360e74fcfe0b10 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java @@ -30,7 +30,6 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DecimalV3Type; -import org.apache.doris.nereids.types.coercion.FractionalType; import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; @@ -86,12 +85,10 @@ protected static boolean isBinaryArithmeticSlot(Expression expr) { return false; } - // Multiply / Divide: both sides must not be float/double - if (expr instanceof Multiply || expr instanceof Divide) { - if (expr.child(0).getDataType().isFloatLikeType() - || expr.child(1).getDataType().isFloatLikeType()) { - return false; - } + // Float/double arithmetic: precision loss for all operations + if (expr.child(0).getDataType().isFloatLikeType() + || expr.child(1).getDataType().isFloatLikeType()) { + return false; } Expression slotExpr; @@ -144,9 +141,6 @@ protected static boolean isLosslessWidening(DataType src, DataType tgt) { if (src instanceof IntegralType && tgt instanceof IntegralType) { return src.width() <= tgt.width(); } - if (src instanceof FractionalType && tgt instanceof FractionalType) { - return src.width() <= tgt.width(); - } if (src.isDecimalLikeType() && tgt.isDecimalLikeType()) { return DecimalV3Type.forType(src).getRange() <= DecimalV3Type.forType(tgt).getRange() && DecimalV3Type.forType(src).getScale() <= DecimalV3Type.forType(tgt).getScale(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java index 572e82b298f69a..da27684e04fd16 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java @@ -29,7 +29,9 @@ import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs; +import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; +import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; @@ -51,6 +53,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.math.BigDecimal; import java.util.List; class SimplifyAggGroupByTest implements MemoPatternMatchSupported { @@ -231,12 +234,50 @@ void testMultiplyDoubleSlotWithIntLiteral() { @Test void testAddWithDoubleLiteral() { - // Add/Subtract should still be allowed with float/double + // Float/double arithmetic may be imprecise, reject for all ops Slot id = scan1.getOutput().get(0); - Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot( + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( new Add(id, new DoubleLiteral(1.0)))); } + @Test + void testAddWithFloatLiteral() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Add(id, new FloatLiteral(1.0f)))); + } + + @Test + void testSubtractWithDoubleLiteral() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Subtract(id, new DoubleLiteral(1.0)))); + } + + @Test + void testMultiplyWithDecimalLiteral() { + // Small decimal multiply should pass (precision fits) + Slot id = scan1.getOutput().get(0); + Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(id, new DecimalLiteral(new BigDecimal("2.0"))))); + } + + @Test + void testDivideWithDecimalLiteral() { + // Divide with decimal: precision overflow too extreme to worry about + Slot id = scan1.getOutput().get(0); + Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Divide(id, new DecimalLiteral(new BigDecimal("2.0"))))); + } + + @Test + void testAddWithDecimalLiteral() { + // Add/Subtract with decimal are exact, should pass + Slot id = scan1.getOutput().get(0); + Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Add(id, new DecimalLiteral(new BigDecimal("1.0"))))); + } + // ========== tests for isLosslessWidening ========== @Test @@ -251,14 +292,6 @@ void testIntegerWidening() { BigIntType.INSTANCE, IntegerType.INSTANCE)); } - @Test - void testFloatWidening() { - Assertions.assertTrue(SimplifyAggGroupBy.isLosslessWidening( - FloatType.INSTANCE, DoubleType.INSTANCE)); - Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( - DoubleType.INSTANCE, FloatType.INSTANCE)); - } - @Test void testDecimalWidening() { Assertions.assertTrue(SimplifyAggGroupBy.isLosslessWidening(