diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java index 37d4d4806f087a..360e74fcfe0b10 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java @@ -19,15 +19,18 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; @@ -71,7 +74,7 @@ public Rule build() { } @VisibleForTesting - protected static boolean isBinaryArithmeticSlot(TreeNode expr) { + protected static boolean isBinaryArithmeticSlot(Expression expr) { /* A bare Slot qualifies immediately; any other expression must pass the checks that follow: supported operator class, no float-like operand, one Literal operand, a slot under lossless-widening casts on the other side, and a literal accepted by checkLiteral. */ if (expr instanceof Slot) { return true; } @@ -81,7 +84,70 @@ protected static boolean isBinaryArithmeticSlot(TreeNode expr) { if (!supportedFunctions.contains(expr.getClass())) { return false; } - return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() && expr.child(1) instanceof Literal - || ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) instanceof Literal; + + 
// Float/double arithmetic: precision loss for all operations + if (expr.child(0).getDataType().isFloatLikeType() + || expr.child(1).getDataType().isFloatLikeType()) { + return false; + } + + Expression slotExpr; + Literal literal; /* the left operand is tested for Literal first; whichever operand is not chosen as the literal becomes slotExpr */ + if (expr.child(0) instanceof Literal) { + literal = (Literal) expr.child(0); + slotExpr = expr.child(1); + } else if (expr.child(1) instanceof Literal) { + literal = (Literal) expr.child(1); + slotExpr = expr.child(0); + } else { + return false; + } + + if (!canExtractSlot(slotExpr)) { /* the non-literal side must be a Slot, optionally wrapped only in casts that isLosslessWidening accepts */ + return false; + } + + return checkLiteral(expr, literal); + } + + /** Returns false for a NULL literal, and for a zero literal when expr is a Multiply or Divide; returns true otherwise. */ @VisibleForTesting + protected static boolean checkLiteral(Expression expr, Literal literal) { + if (literal.isNullLiteral()) { + return false; + } + if (expr instanceof Multiply || expr instanceof Divide) { + if (literal.isZero()) { + return false; + } + } + return true; + } + + /** Unwraps nested Casts for as long as each one passes isLosslessWidening (inner type to cast type) and then reports whether the innermost expression is a Slot; the first cast that fails the check makes this return false. */ @VisibleForTesting + protected static boolean canExtractSlot(Expression expr) { + while (expr instanceof Cast) { + Cast cast = (Cast) expr; + Expression inner = cast.child(); + if (!isLosslessWidening(inner.getDataType(), cast.getDataType())) { + return false; + } + expr = inner; + } + return expr instanceof Slot; + } + + /** Decides whether a cast from src to tgt is treated as value-preserving. Integral to integral: the width must not shrink. Decimal-like to decimal-like: neither range nor scale may shrink. Integral to decimal-like: the integral range must fit within the decimal range. Every other pairing returns false (the tests added in this patch expect that for FLOAT and DOUBLE). */ @VisibleForTesting + protected static boolean isLosslessWidening(DataType src, DataType tgt) { + if (src instanceof IntegralType && tgt instanceof IntegralType) { + return src.width() <= tgt.width(); + } + if (src.isDecimalLikeType() && tgt.isDecimalLikeType()) { + return DecimalV3Type.forType(src).getRange() <= DecimalV3Type.forType(tgt).getRange() + && DecimalV3Type.forType(src).getScale() <= DecimalV3Type.forType(tgt).getScale(); + } + if (src instanceof IntegralType && tgt.isDecimalLikeType()) { + return ((IntegralType) src).range() <= DecimalV3Type.forType(tgt).getRange(); + } + return false; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 
53ccdb2d403447..12c1833f5ceb43 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -42,7 +42,6 @@ import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; -import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.IsNull; @@ -422,40 +421,6 @@ public static S selectMinimumColumn(Collection sl } /** - * Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot} - * or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot} - *

- * for example: - * - SlotReference to a column: - * col - * - Cast on SlotReference: - * cast(int_col as string) - * cast(cast(int_col as long) as string) - * - * @param expr input expression - * @return Return Optional[ExprId] of underlying slot reference if input expression is a slot or cast on slot. - * Otherwise, return empty optional result. - */ - public static Optional isSlotOrCastOnSlot(Expression expr) { - return extractSlotOrCastOnSlot(expr).map(Slot::getExprId); - } - - /** - * Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot} - * or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot} - */ - public static Optional extractSlotOrCastOnSlot(Expression expr) { - while (expr instanceof Cast) { - expr = expr.child(0); - } - - if (expr instanceof SlotReference) { - return Optional.of((Slot) expr); - } else { - return Optional.empty(); - } - } - /** * Generate replaceMap Slot -> Expression from NamedExpression[Expression as name] */ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java index 32c2cc4356d048..da27684e04fd16 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java @@ -18,18 +18,30 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Mod; import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.NamedExpression; import 
org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs; +import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; +import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; @@ -41,6 +53,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.math.BigDecimal; import java.util.List; class SimplifyAggGroupByTest implements MemoPatternMatchSupported { @@ -156,4 +169,240 @@ void testisBinaryArithmeticSlot() { Divide divide = new Divide(id, Literal.of(2)); Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(divide)); } + + // ========== new tests for injectivity checks ========== + + /** A zero literal makes Multiply ineligible in either operand order. */ @Test + void testMultiplyByZero() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(id, Literal.of(0)))); + 
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(Literal.of(0), id))); + } + + /** A zero literal used as the dividend is rejected. */ @Test + void testDivideZeroNumerator() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Divide(Literal.of(0), id))); + } + + /** A zero literal used as the divisor is rejected. */ @Test + void testDivideByZero() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Divide(id, Literal.of(0)))); + } + + /** A NULL literal is rejected for both Add and Multiply. */ @Test + void testNullLiteral() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Add(id, NullLiteral.INSTANCE))); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(id, NullLiteral.INSTANCE))); + } + + @Test + void testMultiplyWithDoubleLiteral() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(id, new DoubleLiteral(0.1)))); + } + + @Test + void testDivideWithDoubleLiteral() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Divide(id, new DoubleLiteral(2.0)))); + } + + /** A FLOAT-typed slot operand is rejected even when the literal is an integer. */ @Test + void testMultiplyWithFloatSlot() { + Slot floatSlot = new SlotReference("f", FloatType.INSTANCE); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(floatSlot, Literal.of(2)))); + } + + @Test + void testMultiplyDoubleSlotWithIntLiteral() { + Slot doubleSlot = new SlotReference("d", DoubleType.INSTANCE); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(doubleSlot, Literal.of(2)))); + } + + @Test + void testAddWithDoubleLiteral() { + // Float/double arithmetic may be imprecise, reject for all ops + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Add(id, new DoubleLiteral(1.0)))); + } + + @Test + void testAddWithFloatLiteral() { + Slot id = scan1.getOutput().get(0); 
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Add(id, new FloatLiteral(1.0f)))); + } + + @Test + void testSubtractWithDoubleLiteral() { + Slot id = scan1.getOutput().get(0); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Subtract(id, new DoubleLiteral(1.0)))); + } + + @Test + void testMultiplyWithDecimalLiteral() { + // Small decimal multiply should pass (precision fits) + Slot id = scan1.getOutput().get(0); + Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Multiply(id, new DecimalLiteral(new BigDecimal("2.0"))))); + } + + @Test + void testDivideWithDecimalLiteral() { + // Divide with decimal: precision overflow too extreme to worry about + Slot id = scan1.getOutput().get(0); + Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Divide(id, new DecimalLiteral(new BigDecimal("2.0"))))); + } + + @Test + void testAddWithDecimalLiteral() { + // Add/Subtract with decimal are exact, should pass + Slot id = scan1.getOutput().get(0); + Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot( + new Add(id, new DecimalLiteral(new BigDecimal("1.0"))))); + } + + // ========== tests for isLosslessWidening ========== + + /** Integral casts count as lossless only from the narrower type to the wider one. */ @Test + void testIntegerWidening() { + Assertions.assertTrue(SimplifyAggGroupBy.isLosslessWidening( + TinyIntType.INSTANCE, IntegerType.INSTANCE)); + Assertions.assertTrue(SimplifyAggGroupBy.isLosslessWidening( + IntegerType.INSTANCE, BigIntType.INSTANCE)); + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + IntegerType.INSTANCE, TinyIntType.INSTANCE)); + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + BigIntType.INSTANCE, IntegerType.INSTANCE)); + } + + @Test + void testDecimalWidening() { + Assertions.assertTrue(SimplifyAggGroupBy.isLosslessWidening( + DecimalV3Type.createDecimalV3Type(5, 2), + DecimalV3Type.createDecimalV3Type(10, 4))); + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + 
DecimalV3Type.createDecimalV3Type(10, 4), + DecimalV3Type.createDecimalV3Type(5, 2))); + } + + @Test + void testIntegralToDecimalWidening() { + Assertions.assertTrue(SimplifyAggGroupBy.isLosslessWidening( + TinyIntType.INSTANCE, DecimalV3Type.createDecimalV3Type(10, 0))); + // BigInt has 19 digits, DECIMAL(5,0) only has 5 integer digits + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + BigIntType.INSTANCE, DecimalV3Type.createDecimalV3Type(5, 0))); + } + + /** Casts between integral and FLOAT/DOUBLE types are never reported as lossless in the directions tested here. */ @Test + void testCrossFamilyRejected() { + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + IntegerType.INSTANCE, FloatType.INSTANCE)); + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + FloatType.INSTANCE, IntegerType.INSTANCE)); + Assertions.assertFalse(SimplifyAggGroupBy.isLosslessWidening( + IntegerType.INSTANCE, DoubleType.INSTANCE)); + } + + // ========== tests for canExtractSlot ========== + + @Test + void testCanExtractSlotBare() { + Slot id = scan1.getOutput().get(0); + Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(id)); + } + + @Test + void testCanExtractSlotWidening() { + Slot id = scan1.getOutput().get(0); + // INT->BIGINT is lossless widening + Expression cast = new Cast(id, BigIntType.INSTANCE); + Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast)); + } + + @Test + void testCanExtractSlotExplicitCast() { + Slot id = scan1.getOutput().get(0); + // explicit cast should also be acceptable if lossless + Expression cast = new Cast(id, BigIntType.INSTANCE, true); + Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast)); + } + + @Test + void testCanExtractSlotNarrowing() { + Slot id = scan1.getOutput().get(0); + // INT -> TINYINT is narrowing, should be rejected + Expression cast = new Cast(id, TinyIntType.INSTANCE); + Assertions.assertFalse(SimplifyAggGroupBy.canExtractSlot(cast)); + } + + // ========== integration tests via PlanChecker ========== + + /** GROUP BY (id, id * 0): after applying the rule the root aggregate must still equal the input aggregate. */ @Test + void testMultiplyByZeroNotSimplified() { + Slot id = 
scan1.getOutput().get(0); + List output = ImmutableList.of(id, new Count().alias("cnt")); + List groupBy = ImmutableList.of(id, new Multiply(id, Literal.of(0))); + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .agg(groupBy, output) + .build(); + ConnectContext connectContext = MemoTestUtils.createConnectContext(); + connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false); + PlanChecker.from(connectContext, agg) + .applyTopDown(new SimplifyAggGroupBy()) + .matchesFromRoot( + logicalAggregate().when(a -> a.equals(agg)) /* equality with the input plan means the rule left the aggregate unchanged */ + ); + } + + /** GROUP BY (id, id + NULL): after applying the rule the root aggregate must still equal the input aggregate. */ @Test + void testNullLiteralNotSimplified() { + Slot id = scan1.getOutput().get(0); + List output = ImmutableList.of(id, new Count().alias("cnt")); + List groupBy = ImmutableList.of(id, new Add(id, NullLiteral.INSTANCE)); + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .agg(groupBy, output) + .build(); + ConnectContext connectContext = MemoTestUtils.createConnectContext(); + connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false); + PlanChecker.from(connectContext, agg) + .applyTopDown(new SimplifyAggGroupBy()) + .matchesFromRoot( + logicalAggregate().when(a -> a.equals(agg)) + ); + } + + /** GROUP BY (id, id * 0.1 as a double literal): after applying the rule the root aggregate must still equal the input aggregate. */ @Test + void testMultiplyDoubleLiteralNotSimplified() { + Slot id = scan1.getOutput().get(0); + List output = ImmutableList.of(id, new Count().alias("cnt")); + List groupBy = ImmutableList.of(id, new Multiply(id, new DoubleLiteral(0.1))); + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .agg(groupBy, output) + .build(); + ConnectContext connectContext = MemoTestUtils.createConnectContext(); + connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false); + PlanChecker.from(connectContext, agg) + .applyTopDown(new SimplifyAggGroupBy()) + .matchesFromRoot( + logicalAggregate().when(a -> a.equals(agg)) + ); + } }