diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRule.java new file mode 100644 index 0000000000000..92ff33b379b2d --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRule.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.planner.calcite.FlinkTypeFactory; +import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; +import org.apache.flink.table.planner.plan.utils.FlinkRexUtil; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.utils.LogicalTypeCasts; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Calc; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinInfo; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.util.mapping.IntPair; +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Predicate; + +/** + * Removes redundant equi-join key references from COALESCE calls above joins. + * + *

In an equi-join {@code ON a.k = b.k}, the non-preserved side's key is either equal to the + * preserved side's key (matched) or NULL (unmatched). This makes it redundant in a COALESCE when it + * appears adjacent-before or anywhere after the preserved side's key: + * + *

+ * + *

For INNER joins both keys are always non-null, so the later-occurring one is always + * unreachable and can be removed regardless of position. FULL OUTER joins are not handled because + * both sides can generate nulls. + * + *

Matches a {@link Project} or {@link Calc} on top of a {@link Join} and uses a {@link + * RexShuttle} to recursively simplify COALESCE calls, including nested ones (e.g., {@code + * CAST(COALESCE(b.k, a.k) AS VARCHAR)}). + */ +@Internal +@Value.Enclosing +public class SimplifyCoalesceWithEquiJoinConditionRule + extends RelRule { + + public static final RelRule PROJECT_INSTANCE = Config.DEFAULT.withProject().toRule(); + + public static final RelRule CALC_INSTANCE = Config.DEFAULT.withCalc().toRule(); + + public SimplifyCoalesceWithEquiJoinConditionRule(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + final RelNode relNode = call.rel(0); + final Join join = call.rel(1); + + final JoinInfo joinInfo = join.analyzeCondition(); + if (joinInfo.pairs().isEmpty()) { + return; + } + + final RexBuilder rexBuilder = join.getCluster().getRexBuilder(); + final int leftFieldCount = join.getLeft().getRowType().getFieldCount(); + + final EquiJoinCoalesceSimplifier shuttle = + new EquiJoinCoalesceSimplifier( + rexBuilder, joinInfo, join.getJoinType(), leftFieldCount); + + final RelNode transformed = relNode.accept(shuttle); + if (shuttle.isSimplified()) { + call.transformTo(transformed); + } + } + + // -------------------------------------------------------------------------------------------- + + /** Traverses expressions bottom-up, removing redundant equi-join refs from COALESCE calls. */ + private static class EquiJoinCoalesceSimplifier extends RexShuttle { + + private final RexBuilder rexBuilder; + private final JoinInfo joinInfo; + private final JoinRelType joinType; + private final int leftFieldCount; + private boolean simplified = false; + + private EquiJoinCoalesceSimplifier( + RexBuilder rexBuilder, + JoinInfo joinInfo, + JoinRelType joinType, + int leftFieldCount) { + this.rexBuilder = rexBuilder; + this.joinInfo = joinInfo; + this.joinType = joinType; + this.leftFieldCount = leftFieldCount; + } + + boolean isSimplified() { + return simplified; + } + + @Override + public RexNode visitCall(RexCall call) { + call = (RexCall) super.visitCall(call); + + if (!operatorIsCoalesce(call.getOperator()) || call.getOperands().size() < 2) { + return call; + } + + final List operands = new ArrayList<>(call.getOperands()); + for (final IntPair pair : joinInfo.pairs()) { + tryRemoveRedundantRef(operands, pair); + } + + final boolean changed = operands.size() != call.getOperands().size(); + if (!changed) { + return call; + } + + simplified = true; + + if (operands.size() == 1) { + return castIfNeeded(operands.get(0), call); + } + return call.clone(call.getType(), operands); + } + + /** + * For a given equi-join pair, finds the two key references in the operand list and removes + * the redundant one if safe. + */ + private void tryRemoveRedundantRef(List operands, IntPair equiJoinPair) { + final int leftPos = findRefPosition(operands, equiJoinPair.source); + final int rightPos = findRefPosition(operands, equiJoinPair.target + leftFieldCount); + if (leftPos == -1 || rightPos == -1) { + return; + } + + final int removablePos = findRemovablePosition(leftPos, rightPos); + if (removablePos != -1) { + operands.remove(removablePos); + } + } + + /** Returns the position of the first {@link RexInputRef} with the given index, or -1. */ + private static int findRefPosition(List operands, int inputRefIndex) { + for (int i = 0; i < operands.size(); i++) { + if (operands.get(i) instanceof RexInputRef + && ((RexInputRef) operands.get(i)).getIndex() == inputRefIndex) { + return i; + } + } + return -1; + } + + /** + * Determines which of the two equi-join key positions can be safely removed, or returns -1. + */ + private int findRemovablePosition(int leftPos, int rightPos) { + switch (joinType) { + case INNER: + // Both keys are non-null; the later one is unreachable + return Math.max(leftPos, rightPos); + case LEFT: + return canSafelyRemove(rightPos, leftPos) ? rightPos : -1; + case RIGHT: + return canSafelyRemove(leftPos, rightPos) ? leftPos : -1; + default: + return -1; + } + } + + /** + * The non-preserved ref can be safely removed when it is adjacent-before or anywhere after + * the preserved ref. The only unsafe case is when the non-preserved ref appears earlier + * with other operands in between - removing it would change which intermediate value + * COALESCE returns. + */ + private static boolean canSafelyRemove(int nonPreservedPos, int preservedPos) { + return nonPreservedPos >= preservedPos - 1; + } + + private RexNode castIfNeeded(RexNode node, RexCall originalCall) { + final LogicalType nodeType = FlinkTypeFactory.toLogicalType(node.getType()); + final LogicalType targetType = FlinkTypeFactory.toLogicalType(originalCall.getType()); + if (LogicalTypeCasts.supportsImplicitCast(nodeType, targetType)) { + return node; + } + return rexBuilder.makeCast(originalCall.getType(), node); + } + } + + // -------------------------------------------------------------------------------------------- + + private static boolean operatorIsCoalesce(SqlOperator op) { + return (op instanceof BridgingSqlFunction + && ((BridgingSqlFunction) op) + .getDefinition() + .equals(BuiltInFunctionDefinitions.COALESCE)) + || op.getKind() == SqlKind.COALESCE; + } + + private static boolean hasCoalesceInvocation(RexNode node) { + return FlinkRexUtil.hasOperatorCallMatching( + node, SimplifyCoalesceWithEquiJoinConditionRule::operatorIsCoalesce); + } + + private static boolean isApplicableJoin(Join join) { + final JoinRelType joinType = join.getJoinType(); + return joinType == JoinRelType.LEFT + || joinType == JoinRelType.RIGHT + || joinType == JoinRelType.INNER; + } + + // -------------------------------------------------------------------------------------------- + + /** Configuration for {@link SimplifyCoalesceWithEquiJoinConditionRule}. */ + @Value.Immutable(singleton = false) + public interface Config extends RelRule.Config { + + Config DEFAULT = + ImmutableSimplifyCoalesceWithEquiJoinConditionRule.Config.builder() + .build() + .as(Config.class); + + @Override + default SimplifyCoalesceWithEquiJoinConditionRule toRule() { + return new SimplifyCoalesceWithEquiJoinConditionRule(this); + } + + default Config withProject() { + final Predicate projectPredicate = + p -> + p.getProjects().stream() + .anyMatch( + SimplifyCoalesceWithEquiJoinConditionRule + ::hasCoalesceInvocation); + final RelRule.OperandTransform projectTransform = + operandBuilder -> + operandBuilder + .operand(Project.class) + .predicate(projectPredicate) + .oneInput( + b -> + b.operand(Join.class) + .predicate( + SimplifyCoalesceWithEquiJoinConditionRule + ::isApplicableJoin) + .anyInputs()); + return withOperandSupplier(projectTransform).as(Config.class); + } + + default Config withCalc() { + final Predicate calcPredicate = + c -> + c.getProgram().getExprList().stream() + .anyMatch( + SimplifyCoalesceWithEquiJoinConditionRule + ::hasCoalesceInvocation); + final RelRule.OperandTransform calcTransform = + operandBuilder -> + operandBuilder + .operand(Calc.class) + .predicate(calcPredicate) + .oneInput( + b -> + b.operand(Join.class) + .predicate( + SimplifyCoalesceWithEquiJoinConditionRule + ::isApplicableJoin) + .anyInputs()); + return withOperandSupplier(calcTransform).as(Config.class); + } + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala index 546081a9c5e6a..a6a094ae8889c 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala @@ -75,7 +75,9 @@ object FlinkBatchRuleSets { RemoveUnreachableCoalesceArgumentsRule.PROJECT_INSTANCE, RemoveUnreachableCoalesceArgumentsRule.FILTER_INSTANCE, RemoveUnreachableCoalesceArgumentsRule.JOIN_INSTANCE, - RemoveUnreachableCoalesceArgumentsRule.CALC_INSTANCE + RemoveUnreachableCoalesceArgumentsRule.CALC_INSTANCE, + SimplifyCoalesceWithEquiJoinConditionRule.PROJECT_INSTANCE, + SimplifyCoalesceWithEquiJoinConditionRule.CALC_INSTANCE ) private val LIMIT_RULES: RuleSet = RuleSets.ofList( diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala index 15372999cd1ac..03fe6ebeae3cf 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala @@ -84,7 +84,9 @@ object FlinkStreamRuleSets { RemoveUnreachableCoalesceArgumentsRule.PROJECT_INSTANCE, RemoveUnreachableCoalesceArgumentsRule.FILTER_INSTANCE, RemoveUnreachableCoalesceArgumentsRule.JOIN_INSTANCE, - RemoveUnreachableCoalesceArgumentsRule.CALC_INSTANCE + RemoveUnreachableCoalesceArgumentsRule.CALC_INSTANCE, + SimplifyCoalesceWithEquiJoinConditionRule.PROJECT_INSTANCE, + SimplifyCoalesceWithEquiJoinConditionRule.CALC_INSTANCE ) /** RuleSet to simplify predicate expressions in filters and joins */ diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.java new file mode 100644 index 0000000000000..0a3d8c7b81bbd --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.planner.utils.StreamTableTestUtil; +import org.apache.flink.table.planner.utils.TableTestBase; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** Test rule {@link SimplifyCoalesceWithEquiJoinConditionRule}. */ +class SimplifyCoalesceWithEquiJoinConditionRuleTest extends TableTestBase { + + private StreamTableTestUtil util; + + @BeforeEach + void before() { + util = streamTestUtil(TableConfig.getDefault()); + + util.tableEnv() + .executeSql( + "CREATE TABLE orders (" + + " order_id BIGINT NOT NULL," + + " user_id BIGINT NOT NULL," + + " amount DOUBLE," + + " PRIMARY KEY (order_id) NOT ENFORCED" + + ") WITH ('connector' = 'values')"); + + util.tableEnv() + .executeSql( + "CREATE TABLE order_details (" + + " order_id BIGINT NOT NULL," + + " detail STRING," + + " PRIMARY KEY (order_id) NOT ENFORCED" + + ") WITH ('connector' = 'values')"); + + util.tableEnv() + .executeSql( + "CREATE TABLE composite_key_table (" + + " k1 BIGINT NOT NULL," + + " k2 BIGINT NOT NULL," + + " val STRING," + + " PRIMARY KEY (k1, k2) NOT ENFORCED" + + ") WITH ('connector' = 'values')"); + + util.tableEnv() + .executeSql( + "CREATE TABLE composite_key_details (" + + " k1 BIGINT NOT NULL," + + " k2 BIGINT NOT NULL," + + " info STRING," + + " PRIMARY KEY (k1, k2) NOT ENFORCED" + + ") WITH ('connector' = 'values')"); + + util.tableEnv() + .executeSql( + "CREATE TABLE order_details_row (" + + " r ROW NOT NULL, " + + " detail STRING," + + " PRIMARY KEY (r) NOT ENFORCED" + + ") WITH ('connector' = 'values')"); + } + + @Test + void testCoalesceOnLeftJoinEquiKey() { + util.verifyRelPlan( + "SELECT COALESCE(b.order_id, a.order_id) AS order_id, a.amount " + + "FROM orders a LEFT JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testCoalesceReversedArgsOnLeftJoin() { + util.verifyRelPlan( + "SELECT COALESCE(a.order_id, b.order_id) AS order_id, a.amount " + + "FROM orders a LEFT JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testCoalesceOnInnerJoinEquiKey() { + util.verifyRelPlan( + "SELECT COALESCE(b.order_id, a.order_id) AS order_id " + + "FROM orders a INNER JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testCoalesceOnRightJoinEquiKey() { + util.verifyRelPlan( + "SELECT COALESCE(a.order_id, b.order_id) AS order_id " + + "FROM orders a RIGHT JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testCoalesceOnFullJoinNotSimplified() { + util.verifyRelPlan( + "SELECT COALESCE(b.order_id, a.order_id) AS order_id " + + "FROM orders a FULL JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testCoalesceOnNonEquiColumnsNotSimplified() { + util.verifyRelPlan( + "SELECT COALESCE(b.detail, CAST(a.amount AS STRING)) AS val " + + "FROM orders a LEFT JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testCoalesceWithThreeArgs() { + util.verifyRelPlan( + "SELECT COALESCE(b.order_id, a.order_id, 0) AS order_id " + + "FROM orders a LEFT JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testMultipleCoalesceOnCompositeKey() { + util.verifyRelPlan( + "SELECT COALESCE(b.k1, a.k1) AS k1, COALESCE(b.k2, a.k2) AS k2, a.val " + + "FROM composite_key_table a " + + "LEFT JOIN composite_key_details b ON a.k1 = b.k1 AND a.k2 = b.k2"); + } + + @Test + void testCoalesceThreeArgsAdjacentPair() { + util.verifyRelPlan( + "SELECT COALESCE(b.order_id, a.order_id, a.user_id) AS val " + + "FROM orders a LEFT JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testCoalesceThreeArgsNonPreservedAfterPreserved() { + util.verifyRelPlan( + "SELECT COALESCE(a.user_id, a.order_id, b.order_id) AS val " + + "FROM orders a LEFT JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testCoalesceThreeArgsNonPreservedBeforeWithGapNotSimplified() { + util.verifyRelPlan( + "SELECT COALESCE(b.order_id, a.user_id, a.order_id) AS val " + + "FROM orders a LEFT JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testCoalesceThreeArgsInnerJoin() { + util.verifyRelPlan( + "SELECT COALESCE(a.user_id, b.order_id, a.order_id) AS val " + + "FROM orders a INNER JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testCoalesceWrappedInCast() { + util.verifyRelPlan( + "SELECT CAST(COALESCE(b.order_id, a.order_id) AS STRING) AS order_id_str " + + "FROM orders a LEFT JOIN order_details b ON a.order_id = b.order_id"); + } + + @Test + void testCoalesceOnNestedRowScalarField() { + util.verifyRelPlan( + "SELECT CAST(COALESCE(b.r.order_id, a.order_id) AS STRING) AS order_id_str " + + "FROM orders a LEFT JOIN order_details_row b ON a.order_id = b.r.order_id"); + } +} diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.xml new file mode 100644 index 0000000000000..cf6cbf60a2df1 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.xml @@ -0,0 +1,305 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +