Skip to content

Commit

Permalink
[CALCITE-6652] RelDecorrelator can't decorrelate query with limit 1
Browse files Browse the repository at this point in the history
  • Loading branch information
suibianwanwank authored and rubenada committed Feb 26, 2025
1 parent 4ee2e41 commit 0a13d99
Show file tree
Hide file tree
Showing 6 changed files with 975 additions and 62 deletions.
166 changes: 156 additions & 10 deletions core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.BiRel;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
Expand Down Expand Up @@ -73,6 +74,7 @@
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.runtime.PairList;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlExplainFormat;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlFunction;
Expand All @@ -94,6 +96,7 @@
import org.apache.calcite.util.mapping.Mappings;
import org.apache.calcite.util.trace.CalciteTrace;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
Expand Down Expand Up @@ -523,6 +526,19 @@ protected RexNode removeCorrelationExpr(
return null;
}

if (isCorVarDefined && (rel.fetch != null || rel.offset != null)) {
if (rel.fetch != null
&& rel.offset == null
&& RexLiteral.intValue(rel.fetch) == 1) {
return decorrelateFetchOneSort(rel, frame);
}
// Can not decorrelate if the sort has per-correlate-key attributes like
// offset or fetch limit, because these attributes scope would change to
// global after decorrelation. They should take effect within the scope
// of the correlation key actually.
return null;
}

final RelNode newInput = frame.r;

Mappings.TargetMapping mapping =
Expand Down Expand Up @@ -767,16 +783,6 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex,
public @Nullable Frame getInvoke(RelNode r, boolean isCorVarDefined, @Nullable RelNode parent) {
final Frame frame = dispatcher.invoke(r, isCorVarDefined);
currentRel = parent;
if (frame != null && isCorVarDefined && r instanceof Sort) {
final Sort sort = (Sort) r;
// Can not decorrelate if the sort has per-correlate-key attributes like
// offset or fetch limit, because these attributes scope would change to
// global after decorrelation. They should take effect within the scope
// of the correlation key actually.
if (sort.offset != null || sort.fetch != null) {
return null;
}
}
if (frame != null) {
map.put(r, frame);
}
Expand All @@ -795,6 +801,146 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex,
return null;
}

protected @Nullable Frame decorrelateFetchOneSort(Sort sort, final Frame frame) {
Frame aggFrame = decorrelateSortAsAggregate(sort, frame);
if (aggFrame != null) {
return aggFrame;
}
//
// Rewrite logic:
//
// If sorted without offset and fetch = 1 (enforced by the caller), rewrite the sort to be
// Aggregate(group=(corVar.. , field..))
// project(first_value(field) over (partition by corVar order by (sort collation)))
// input
//
// 1. For the original sorted input, apply the FIRST_VALUE window function to produce
// the result of sorting with LIMIT 1, and the same as the decorrelate of aggregate,
// add correlated variables in partition list to maintain semantic consistency.
// 2. To ensure that there is at most one row of output for
// any combination of correlated variables, distinct for correlated variables.
// 3. Since we have partitioned by all correlated variables
// in the sorted output field window, so for any combination of correlated variables,
// all other field values are unique. So the following two are equivalent:
// - group by corVar1, covVar2, field1, field2
// - any_value(fields1), any_value(fields2) group by corVar1, covVar2
// Here we use the first.
final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();

final PairList<RexNode, String> corVarProjects = PairList.of();
List<RelDataTypeField> fieldList = frame.r.getRowType().getFieldList();
for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
corDefOutputs.put(entry.getKey(),
sort.getRowType().getFieldCount() + corVarProjects.size());
RexInputRef.add2(corVarProjects, entry.getValue(), fieldList);
}

final List<RexNode> sortExprs =
new ArrayList<>(sort.getCollation().getFieldCollations().size());
for (RelFieldCollation collation : sort.getCollation().getFieldCollations()) {
Integer newIdx = requireNonNull(frame.oldToNewOutputs.get(collation.getFieldIndex()));
RexNode node = RexInputRef.of(newIdx, fieldList);
if (collation.direction == RelFieldCollation.Direction.DESCENDING) {
node = relBuilder.desc(node);
}
if (collation.nullDirection == RelFieldCollation.NullDirection.FIRST) {
node = relBuilder.nullsFirst(node);
} else if (collation.nullDirection == RelFieldCollation.NullDirection.LAST) {
node = relBuilder.nullsLast(node);
}
sortExprs.add(node);
}

final PairList<RexNode, String> newProjExprs = PairList.of();
for (RelDataTypeField field : sort.getRowType().getFieldList()) {
final int newIdx =
requireNonNull(frame.oldToNewOutputs.get(field.getIndex()));

RelBuilder.AggCall aggCall =
relBuilder.aggregateCall(SqlStdOperatorTable.FIRST_VALUE,
RexInputRef.of(newIdx, fieldList));

// Convert each field from the sorted output to a window function that partitions by
// correlated variables, orders by the collation, and return the first_value.
RexNode winCall = aggCall.over()
.orderBy(sortExprs)
.partitionBy(corVarProjects.leftList())
.toRex();
mapOldToNewOutputs.put(newProjExprs.size(), newProjExprs.size());
newProjExprs.add(winCall, field.getName());
}
newProjExprs.addAll(corVarProjects);
RelNode result = relBuilder.push(frame.r)
.project(newProjExprs.leftList(), newProjExprs.rightList())
.distinct().build();

return register(sort, result, mapOldToNewOutputs, corDefOutputs);
}

protected @Nullable Frame decorrelateSortAsAggregate(Sort sort, final Frame frame) {
final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
if (sort.getCollation().getFieldCollations().size() == 1
&& sort.getRowType().getFieldCount() == 1
&& !frame.corDefOutputs.isEmpty()) {
//
// Rewrite logic:
//
// If sorted with no OFFSET and FETCH = 1, and only one collation field,
// rewrite the Sort as Aggregate using MIN/MAX function.
// Example:
// Sort(sort0=[$0], dir0=[ASC], fetch=[1])
// input
// Rewrite to:
// Aggregate(group=(corVar), agg=[min($0))
//
// Note: MIN/MAX is not strictly equivalent to LIMIT 1. When the input has 0 rows,
// MIN/MAX returns NULL, while LIMIT 1 returns 0 rows.
// However, in the decorrelate, we add correlated variables to the group list
// to ensure equivalence when Correlate is transformed to Join. When the group list
// is non-empty, MIN/MAX will also return 0 rows if the input has 0 rows.
// So in this case, the transformation is legal.
RelFieldCollation collation = Util.first(sort.getCollation().getFieldCollations());

if (collation.nullDirection != RelFieldCollation.NullDirection.LAST) {
return null;
}

SqlAggFunction aggFunction;
switch (collation.getDirection()) {
case ASCENDING:
case STRICTLY_ASCENDING:
aggFunction = SqlStdOperatorTable.MIN;
break;
case DESCENDING:
case STRICTLY_DESCENDING:
aggFunction = SqlStdOperatorTable.MAX;
break;
default:
return null;
}

final int newIdx = requireNonNull(frame.oldToNewOutputs.get(collation.getFieldIndex()));
RelBuilder.AggCall aggCall = relBuilder.push(frame.r)
.aggregateCall(aggFunction, relBuilder.fields(ImmutableList.of(newIdx)));

// As with the aggregate decorrelate, add correlated variables to the group list.
final List<RexInputRef> groupKey = new ArrayList<>();
for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
groupKey.add(RexInputRef.of(entry.getValue(), frame.r.getRowType()));
corDefOutputs.put(entry.getKey(), corDefOutputs.size());
}

RelNode aggregate = relBuilder.aggregate(relBuilder.groupKey(groupKey), aggCall).build();

// Add the mapping for the added aggregate fields.
mapOldToNewOutputs.put(0, groupKey.size());
return register(sort, aggregate, mapOldToNewOutputs, corDefOutputs);
}
return null;
}

public @Nullable Frame decorrelateRel(LogicalProject rel, boolean isCorVarDefined) {
return decorrelateRel((Project) rel, isCorVarDefined);
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/org/apache/calcite/util/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -2086,7 +2086,7 @@ public static <T> Iterable<T> orEmpty(@Nullable Iterable<T> v0) {
*
* @throws java.lang.IndexOutOfBoundsException if the list is empty
*/
public <E> E first(List<E> list) {
public static <E> E first(List<E> list) {
return list.get(0);
}

Expand Down
105 changes: 105 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8577,6 +8577,111 @@ private void checkSemiJoinRuleOnAntiJoin(RelOptRule rule) {
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateProjectWithFetchOne() {
final String query = "SELECT name, "
+ "(SELECT sal FROM emp where dept.deptno = emp.deptno order by sal limit 1) "
+ "FROM dept";
sql(query).withRule(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateProjectWithFetchOneDesc() {
final String query = "SELECT name, "
+ "(SELECT emp.sal FROM emp WHERE dept.deptno = emp.deptno "
+ "ORDER BY emp.sal desc nulls last LIMIT 1) "
+ "FROM dept";
sql(query).withRule(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateFilterWithFetchOne() {
final String query = "SELECT name FROM dept "
+ "WHERE 10 > (SELECT emp.sal FROM emp where dept.deptno = emp.deptno "
+ "ORDER BY emp.sal limit 1)";
sql(query).withRule(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateFilterWithFetchOneDesc() {
final String query = "SELECT name FROM dept "
+ "WHERE 10 > (SELECT emp.sal FROM emp where dept.deptno = emp.deptno "
+ "ORDER BY emp.sal desc nulls last limit 1)";
sql(query).withRule(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateFilterWithFetchOneDesc1() {
final String query = "SELECT name FROM dept "
+ "WHERE 10 > (SELECT emp.sal FROM emp where dept.deptno = emp.deptno "
+ "ORDER BY emp.sal desc limit 1)";
sql(query).withRule(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateProjectWithMultiKeyAndFetchOne() {
final String query = "SELECT name, "
+ "(SELECT sal FROM emp where dept.deptno = emp.deptno "
+ "order by year(hiredate), emp.sal limit 1) FROM dept";
sql(query).withRule(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateProjectWithMultiKeyAndFetchOne1() {
final String query = "SELECT name, "
+ "(SELECT sal FROM emp where dept.deptno = emp.deptno and dept.name = emp.ename "
+ "order by year(hiredate), emp.sal limit 1) FROM dept";
sql(query).withRule(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateFilterWithMultiKeyAndFetchOne() {
final String query = "SELECT name FROM dept "
+ "WHERE 10 > (SELECT emp.sal FROM emp where dept.deptno = emp.deptno "
+ "order by year(hiredate), emp.sal desc limit 1)";
sql(query).withRule(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-434">[CALCITE-434]
* Converting predicates on date dimension columns into date ranges</a>,
Expand Down
Loading

0 comments on commit 0a13d99

Please sign in to comment.