Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.doris.nereids.exceptions;

import java.util.Optional;

/**
* cast exception.
*/
Expand All @@ -27,7 +25,7 @@ public class CastException extends AnalysisException {
private final String message;

public CastException(String message) {
super(ErrorCode.NONE, message, Optional.of(0), Optional.of(0), Optional.empty());
super(ErrorCode.NONE, message);
this.message = message;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;

/** extract unique function expression which exist multiple times, and add them to a new project child.
* for example:
Expand Down Expand Up @@ -207,20 +209,20 @@ public Rule build() {
allConjuncts.addAll(join.getHashJoinConjuncts());
allConjuncts.addAll(join.getOtherJoinConjuncts());
allConjuncts.addAll(join.getMarkJoinConjuncts());
Optional<Pair<List<Expression>, LogicalProject<Plan>>> rewrittenOpt
= rewriteExpressions(join, allConjuncts);
Optional<JoinRewriteResult> rewrittenOpt = rewriteJoinExpressions(join, allConjuncts);
if (!rewrittenOpt.isPresent()) {
return join;
}

LogicalProject<Plan> newLeftChild = rewrittenOpt.get().second;
List<Expression> newAllConjuncts = rewrittenOpt.get().first;
Plan newLeftChild = rewrittenOpt.get().left;
Plan newRightChild = rewrittenOpt.get().right;
List<Expression> newAllConjuncts = rewrittenOpt.get().newConjuncts;
List<Expression> newHashOtherConjuncts = newAllConjuncts.subList(0, hashOtherConjunctsSize);
List<Expression> newMarkJoinConjuncts = ImmutableList.copyOf(
newAllConjuncts.subList(hashOtherConjunctsSize, totalConjunctsSize));
// TODO: code from FindHashConditionForJoin
Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(
newLeftChild.getOutput(), join.right().getOutput(), newHashOtherConjuncts);
newLeftChild.getOutput(), newRightChild.getOutput(), newHashOtherConjuncts);
List<Expression> newHashJoinConjuncts = pair.first;
List<Expression> newOtherJoinConjuncts = pair.second;
JoinType joinType = join.getJoinType();
Expand All @@ -233,7 +235,7 @@ public Rule build() {
newMarkJoinConjuncts,
join.getDistributeHint(),
join.getMarkJoinSlotReference(),
ImmutableList.of(newLeftChild, join.right()),
ImmutableList.of(newLeftChild, newRightChild),
join.getJoinReorderContext());
}).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
}
Expand Down Expand Up @@ -269,6 +271,85 @@ public <T extends Expression> Optional<Pair<List<T>, LogicalProject<Plan>>> rewr
return Optional.of(Pair.of(newTargetsBuilder.build(), new LogicalProject<>(projects, plan.child(0))));
}

private Optional<JoinRewriteResult> rewriteJoinExpressions(LogicalJoin<Plan, Plan> join,
Collection<Expression> targets) {
Map<Expression, Integer> volatileExpressionCounter = Maps.newLinkedHashMap();
Map<Expression, Set<Slot>> volatileExpressionSlots = Maps.newLinkedHashMap();
for (Expression target : targets) {
target.foreach(e -> {
Expression expr = (Expression) e;
if (expr.isVolatile()) {
volatileExpressionCounter.merge(expr, 1, Integer::sum);
Set<Slot> volatileInputSlots = expr.getInputSlots();
volatileExpressionSlots
.computeIfAbsent(expr, ignored -> Sets.newLinkedHashSet())
.addAll(volatileInputSlots.isEmpty() ? target.getInputSlots() : volatileInputSlots);
}
});
}

ImmutableList.Builder<NamedExpression> leftAliases = ImmutableList.builder();
ImmutableList.Builder<NamedExpression> rightAliases = ImmutableList.builder();
Map<Expression, Slot> replaceMap = Maps.newHashMap();
Set<Slot> leftOutputSet = join.left().getOutputSet();
Set<Slot> rightOutputSet = join.right().getOutputSet();
for (Entry<Expression, Integer> entry : volatileExpressionCounter.entrySet()) {
if (entry.getValue() <= 1) {
continue;
}
Set<Slot> inputSlots = volatileExpressionSlots.get(entry.getKey());
Set<Slot> volatileInputSlots = entry.getKey().getInputSlots();
if (!volatileInputSlots.isEmpty()
&& !leftOutputSet.containsAll(inputSlots)
&& !rightOutputSet.containsAll(inputSlots)) {
continue;
}
ExprId exprId = StatementScopeIdGenerator.newExprId();
String functionName = entry.getKey() instanceof Function
? ((Function) entry.getKey()).getName() : "volatile";
Alias alias = new Alias(exprId, entry.getKey(), "$_" + functionName + "_" + exprId.asInt() + "_$");
replaceMap.put(alias.child(0), alias.toSlot());
// Join can not add a project at join-pair scope, but repeated volatile expressions
// still need one materialized value. Slot-free volatile functions use the containing
// conjunct's slots to choose a side, so t2.k + rand() can project rand() on the right.
// Volatile functions with input slots use their own slots to avoid projecting
// volatile_udf(t2.k) on the left only because its containing conjunct also uses t1.
// Volatile functions whose own slots span both join children cannot be projected into
// either child, so they are not rewritten here.
// Put right-only expressions on the right child; otherwise keep the previous
// left-child behavior as the conservative default.
if (!inputSlots.isEmpty() && rightOutputSet.containsAll(inputSlots)) {
rightAliases.add(alias);
} else {
leftAliases.add(alias);
}
}
if (replaceMap.isEmpty()) {
return Optional.empty();
}

List<NamedExpression> leftAliasList = leftAliases.build();
List<NamedExpression> rightAliasList = rightAliases.build();
Plan left = appendProjectIfNeeded(join.left(), leftAliasList);
Plan right = appendProjectIfNeeded(join.right(), rightAliasList);
ImmutableList.Builder<Expression> newTargetsBuilder = ImmutableList.builderWithExpectedSize(targets.size());
for (Expression target : targets) {
newTargetsBuilder.add(ExpressionUtils.replace(target, replaceMap));
}
return Optional.of(new JoinRewriteResult(newTargetsBuilder.build(), left, right));
}

private Plan appendProjectIfNeeded(Plan child, List<NamedExpression> aliases) {
if (aliases.isEmpty()) {
return child;
}
List<NamedExpression> projects = ImmutableList.<NamedExpression>builder()
.addAll(child.getOutput())
.addAll(aliases)
.build();
return new LogicalProject<>(projects, child);
}

/**
* if a unique function exists multiple times in the targets, then add a project to alias it.
*/
Expand Down Expand Up @@ -296,4 +377,16 @@ public List<NamedExpression> tryGenUniqueFunctionAlias(Collection<? extends Expr

return builder.build();
}

private static class JoinRewriteResult {
private final List<Expression> newConjuncts;
private final Plan left;
private final Plan right;

private JoinRewriteResult(List<Expression> newConjuncts, Plan left, Plan right) {
this.newConjuncts = newConjuncts;
this.left = left;
this.right = right;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ private Plan inferNewPredicate(Plan plan, Set<Expression> expressions) {
Set<Expression> predicates = new LinkedHashSet<>();
Set<Slot> planOutputs = plan.getOutputSet();
for (Expression expr : expressions) {
if (expr.containsVolatileExpression()) {
// Volatile expressions (e.g. rand(), uuid()) must not be cloned into
// subtrees that did not already evaluate them. Otherwise, callers that perform
// slot substitution (e.g. SetOp visitors below) would introduce a fresh
// per-row evaluation of the volatile expression on a sibling branch, changing
// query semantics (see EXCEPT/INTERSECT regression cases).
continue;
}
Set<Slot> slots = expr.getInputSlots();
if (!slots.isEmpty() && planOutputs.containsAll(slots)) {
predicates.add(expr);
Expand All @@ -242,6 +250,11 @@ private Plan inferNewPredicateRemoveUselessIsNull(Plan plan, Set<Expression> exp
Set<Expression> predicates = new LinkedHashSet<>();
Set<Slot> planOutputs = plan.getOutputSet();
for (Expression expr : expressions) {
if (expr.containsVolatileExpression()) {
// See inferNewPredicate for rationale: never clone volatile
// predicates into a subtree that did not already evaluate them.
continue;
}
Set<Slot> slots = expr.getInputSlots();
if (slots.isEmpty() || !planOutputs.containsAll(slots)) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ public Expression visit(Expression expression, ReplacerContext ctx) {
if (input.isEmpty() || expression instanceof Slot) {
return expression;
}
// A mixed expression like `t1.a + rand() > t2.b` has inputSlots={t1.a}; if we alias
// it into a child Project, rand()'s evaluation granularity changes from "per join
// pair" to "per row of that child", which silently changes results. Keep such
// expressions inline in otherJoinConjuncts, but still recurse to extract deterministic
// child expressions.
if (expression.containsVolatileExpression()) {
return super.visit(expression, ctx);
}
if (ctx.leftSlots.containsAll(input)) {
Alias alias = ctx.aliasMap.computeIfAbsent(expression, o -> new Alias(o));
ctx.leftAlias.add(alias);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,13 @@ public Rule build() {
}
for (Expression expr : filter.getConjuncts()) {
Set<Slot> exprInputSlots = expr.getInputSlots();
if (partitionKeySlots.containsAll(exprInputSlots)) {
// A conjunct containing a volatile function such as rand()/uuid()
// must NOT be pushed below the partition top-N. It would filter base rows before
// top-N selection, replacing "top-N then random filter" with "random filter then
// top-N", and the surviving rows of every partition would no longer be the true
// top-N. Empty-input-slot predicates like `rand() > 0.5` would also bypass the
// `containsAll` check otherwise.
if (!expr.containsVolatileExpression() && partitionKeySlots.containsAll(exprInputSlots)) {
bottomConjunctsBuilder.add(expr);
} else {
upperConjunctsBuilder.add(expr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
Expand All @@ -44,6 +45,7 @@

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -61,11 +63,50 @@ public Rule build() {
.when(s -> s.arity() > 0
|| (s instanceof LogicalUnion && !((LogicalUnion) s).getConstantExprsList().isEmpty())))
.thenApply(ctx -> {
LogicalFilter<LogicalSetOperation> filter = ctx.root;
LogicalSetOperation setOperation = filter.child();
LogicalFilter<LogicalSetOperation> origFilter = ctx.root;
LogicalSetOperation setOperation = origFilter.child();

// Pushing a conjunct that contains a volatile expression (rand/uuid/random_bytes/...)
// into each branch changes semantics for every set-op except UNION ALL.
// - UNION ALL: each branch row = exactly one output row (1:1), so evaluating
// rand() once per branch row still matches the per-output-row semantic.
// - UNION DISTINCT / INTERSECT / EXCEPT: the set-op semantics depend on the
// full branch row sets before dedup/intersect/except. Sampling rows in each
// branch independently changes which rows participate (e.g. INTERSECT becomes
// "half of A intersect half of B" instead of "half of (A intersect B)").
boolean canPushVolatileExpr = setOperation instanceof LogicalUnion
&& setOperation.getQualifier() == Qualifier.ALL;
Set<Expression> pushableConjuncts;
Set<Expression> keptAboveConjuncts;
boolean allConjunctsPushable;
if (canPushVolatileExpr) {
pushableConjuncts = origFilter.getConjuncts();
keptAboveConjuncts = ImmutableSet.of();
allConjunctsPushable = true;
} else {
pushableConjuncts = new LinkedHashSet<>();
Set<Expression> kept = new LinkedHashSet<>();
for (Expression c : origFilter.getConjuncts()) {
if (c.containsVolatileExpression()) {
kept.add(c);
} else {
pushableConjuncts.add(c);
}
}
keptAboveConjuncts = kept;
if (pushableConjuncts.isEmpty()) {
return null;
}
allConjunctsPushable = false;
}
LogicalFilter<LogicalSetOperation> filter = allConjunctsPushable
? origFilter
: new LogicalFilter<>(ImmutableSet.copyOf(pushableConjuncts), setOperation);

List<Plan> newChildren = new ArrayList<>();
List<List<SlotReference>> newRegularChildrenOutputs = Lists.newArrayList();
CascadesContext cascadesContext = ctx.cascadesContext;
Plan rewritten;
if (setOperation instanceof LogicalUnion) {
List<List<NamedExpression>> constantExprs = ((LogicalUnion) setOperation).getConstantExprsList();
StatementContext statementContext = ctx.statementContext;
Expand All @@ -85,7 +126,7 @@ public Rule build() {

List<NamedExpression> setOutputs = setOperation.getOutputs();
if (newChildren.isEmpty() && newConstantExprs.isEmpty()) {
return new LogicalEmptyRelation(
rewritten = new LogicalEmptyRelation(
statementContext.getNextRelationId(), setOutputs
);
} else if (newChildren.isEmpty() && newConstantExprs.size() == 1) {
Expand All @@ -104,27 +145,32 @@ public Rule build() {
}
newOneRowRelationOutput.add(oneRowRelationOutput);
}
return new LogicalOneRowRelation(
rewritten = new LogicalOneRowRelation(
ctx.statementContext.getNextRelationId(), newOneRowRelationOutput.build()
);
}
} else {
Builder<List<SlotReference>> newChildrenOutput
= ImmutableList.builderWithExpectedSize(newChildren.size());
for (Plan newChild : newChildren) {
newChildrenOutput.add((List) newChild.getOutput());
}

Builder<List<SlotReference>> newChildrenOutput
= ImmutableList.builderWithExpectedSize(newChildren.size());
for (Plan newChild : newChildren) {
newChildrenOutput.add((List) newChild.getOutput());
rewritten = ((LogicalUnion) setOperation).withChildrenAndConstExprsList(
newChildren, newRegularChildrenOutputs, newConstantExprs);
}

return ((LogicalUnion) setOperation).withChildrenAndConstExprsList(
newChildren, newRegularChildrenOutputs, newConstantExprs);
} else {
addFiltersToNewChildren(setOperation, filter, setOperation.children(),
setOperation.getRegularChildrenOutputs(),
cascadesContext, newChildren, newRegularChildrenOutputs, null,
(rowIndex, columnIndex) -> setOperation.getRegularChildOutput(rowIndex).get(columnIndex),
Function.identity());
rewritten = setOperation.withChildren(newChildren);
}

addFiltersToNewChildren(setOperation, filter, setOperation.children(),
setOperation.getRegularChildrenOutputs(),
cascadesContext, newChildren, newRegularChildrenOutputs, null,
(rowIndex, columnIndex) -> setOperation.getRegularChildOutput(rowIndex).get(columnIndex),
Function.identity());
return setOperation.withChildren(newChildren);
if (keptAboveConjuncts.isEmpty()) {
return rewritten;
}
return new LogicalFilter<>(ImmutableSet.copyOf(keptAboveConjuncts), rewritten);
}).toRule(RuleType.PUSH_DOWN_FILTER_THROUGH_SET_OPERATION);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,18 @@ public Rule build() {
}).toRule(RuleType.PUSH_DOWN_FILTER_THROUGH_WINDOW);
}

/**
* Returns whether {@code conjunct} can be pushed below a window operator with the given
* common partition keys.
*/
public static boolean canPushDown(Expression conjunct, Set<SlotReference> commonPartitionKeys) {
return commonPartitionKeys.containsAll(conjunct.getInputSlots());
// A conjunct that contains a volatile function such as rand()/uuid()
// must NOT be pushed below the window node. Pushing it down filters base rows before
// window evaluation, which changes which rows belong to each partition and therefore
// changes the value of every window function (row_number, rank, sum, ...). In addition,
// a predicate like `rand() > 0.5` has empty input slots, so `containsAll(emptySet)`
// would otherwise wrongly return true.
return !conjunct.containsVolatileExpression()
&& commonPartitionKeys.containsAll(conjunct.getInputSlots());
}
}
Loading
Loading