[BugFix] Prune aggregate non-required columns after mv transparent rewrite (backport #55286) (#55337)

Signed-off-by: shuming.li <[email protected]>
Co-authored-by: shuming.li <[email protected]>
mergify[bot] and LiShuMing authored Jan 22, 2025
1 parent 96717c5 commit 4cf0f20
Showing 7 changed files with 375 additions and 49 deletions.
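
At a high level, mv transparent rewrite replaces a read of the materialized view with a UNION ALL of two branches: a scan of the mv's already-refreshed partitions and a recomputation of the mv's defined query over the to-refresh partitions. The compensation branch re-runs the mv's aggregation, so it produces every output column of the defined query even when the outer query needs only a few of them; the changes below mark such aggregates so that MVColumnPruner can drop the unneeded columns. A minimal toy model of the plan shape, using a made-up Plan record rather than StarRocks' OptExpression:

    import java.util.List;

    public class TransparentPlanSketch {
        // Illustrative stand-in for OptExpression: an operator label plus children.
        record Plan(String op, List<Plan> children) {}

        public static void main(String[] args) {
            Plan mvScan = new Plan("SCAN mv (refreshed partitions)", List.of());
            Plan compensation = new Plan(
                    "AGGREGATE k1, k2, sum(v1), max(v2) over base tables (stale partitions)",
                    List.of());
            // If the query only needs k1 and sum(v1), then max(v2) in the compensation
            // branch is computed for nothing unless the aggregate is pruned.
            Plan transparentPlan = new Plan("UNION ALL", List.of(mvScan, compensation));
            System.out.println(transparentPlan);
        }
    }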
@@ -30,4 +30,6 @@ public class OpRuleBit {
     public static final int OP_MV_TRANSPARENT_REWRITE = 2;
     // Operator has been partition pruned or not, if partition pruned, no need to prune again.
     public static final int OP_PARTITION_PRUNED = 3;
+    // Operator has been mv transparent union rewrite and needs to prune agg columns.
+    public static final int OP_MV_AGG_PRUNE_COLUMNS = 4;
 }
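
OP_MV_AGG_PRUNE_COLUMNS extends the rule markers above. The diff does not show how Operator stores these bits; here is a minimal sketch, assuming each constant is a bit position in an int mask behind the setOpRuleBit/isOpRuleBitSet calls used later in this commit (the OperatorSketch class itself is hypothetical):

    public class OperatorSketch {
        private int opRuleMask = 0;

        // Record that a rewrite has been applied, e.g. setOpRuleBit(OP_MV_AGG_PRUNE_COLUMNS).
        public void setOpRuleBit(int bit) {
            opRuleMask |= (1 << bit);
        }

        // Test the marker so a rewrite or prune is not applied twice.
        public boolean isOpRuleBitSet(int bit) {
            return (opRuleMask & (1 << bit)) != 0;
        }
    }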
@@ -32,6 +32,7 @@
 import com.starrocks.sql.optimizer.OptimizerContext;
 import com.starrocks.sql.optimizer.QueryMaterializationContext;
 import com.starrocks.sql.optimizer.Utils;
+import com.starrocks.sql.optimizer.base.ColumnRefSet;
 import com.starrocks.sql.optimizer.operator.OperatorType;
 import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
 import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
@@ -216,7 +217,7 @@ private OptExpression redirectToMVDefinedQuery(OptimizerContext context,
     /**
      * Get transparent plan if possible.
      * What's the transparent plan?
-     * see {@link MvPartitionCompensator#getMvTransparentPlan(MaterializationContext, MVCompensation, List)
+     * see {@link MvPartitionCompensator#getMvTransparentPlan(MaterializationContext, MVCompensation, List, ColumnRefSet)}
      */
     private OptExpression getMvTransparentPlan(MaterializationContext mvContext,
                                                OptExpression input,
@@ -246,7 +247,7 @@ private OptExpression getMvTransparentPlan(MaterializationContext mvContext,
             return null;
         }
         OptExpression transparentPlan = MvPartitionCompensator.getMvTransparentPlan(mvContext, mvCompensation,
-                expectOutputColumns);
+                expectOutputColumns, false);
         return transparentPlan;
     }
@@ -14,7 +14,6 @@
 
 package com.starrocks.sql.optimizer.rule.transformation.materialization;
 
-import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
@@ -42,10 +41,13 @@
 import java.util.Set;
 import java.util.stream.Collectors;
 
+import static com.starrocks.sql.optimizer.operator.OpRuleBit.OP_MV_AGG_PRUNE_COLUMNS;
+
 public class MVColumnPruner {
     private ColumnRefSet requiredOutputColumns;
 
-    public OptExpression pruneColumns(OptExpression queryExpression) {
+    public OptExpression pruneColumns(OptExpression queryExpression, ColumnRefSet requiredOutputColumns) {
+        this.requiredOutputColumns = requiredOutputColumns;
         if (queryExpression.getOp() instanceof LogicalFilterOperator) {
             OptExpression newQueryOptExpression = doPruneColumns(queryExpression.inputAt(0));
             Operator filterOp = queryExpression.getOp();
@@ -58,13 +60,13 @@ public OptExpression pruneColumns(OptExpression queryExpression) {
     }
 
     public OptExpression doPruneColumns(OptExpression optExpression) {
+        // TODO: remove this check after we support more operators.
         Projection projection = optExpression.getOp().getProjection();
+        // OptExpression after mv rewrite must have projection.
+        if (projection == null) {
+            return optExpression;
+        }
-        Preconditions.checkState(projection != null);
-        requiredOutputColumns = new ColumnRefSet(projection.getOutputColumns());
-        // OptExpression after mv rewrite must have projection.
         return optExpression.getOp().accept(new ColumnPruneVisitor(), optExpression, null);
     }
 
@@ -124,20 +126,60 @@ public OptExpression visitLogicalTableScan(OptExpression optExpression, Void context) {
 
         public OptExpression visitLogicalAggregate(OptExpression optExpression, Void context) {
             LogicalAggregationOperator aggregationOperator = (LogicalAggregationOperator) optExpression.getOp();
-            if (aggregationOperator.getProjection() != null) {
-                Projection projection = aggregationOperator.getProjection();
-                projection.getColumnRefMap().values().forEach(s -> requiredOutputColumns.union(s.getUsedColumns()));
-            }
             if (aggregationOperator.getPredicate() != null) {
                 requiredOutputColumns.union(Utils.extractColumnRef(aggregationOperator.getPredicate()));
             }
-            requiredOutputColumns.union(aggregationOperator.getGroupingKeys());
-            for (Map.Entry<ColumnRefOperator, CallOperator> entry : aggregationOperator.getAggregations().entrySet()) {
-                requiredOutputColumns.union(entry.getKey());
-                requiredOutputColumns.union(Utils.extractColumnRef(entry.getValue()));
-            }
+            // It's safe to prune columns if the aggregation operator has been rewritten by mv since the rewritten
+            // mv plan should be rollup from the original plan.
+            // TODO: We can do this in more normal ways rather than only mv rewrite later,
+            // issue: https://github.com/StarRocks/starrocks/issues/55285
+            if (aggregationOperator.isOpRuleBitSet(OP_MV_AGG_PRUNE_COLUMNS)) {
+                // project
+                Projection newProjection = null;
+                if (aggregationOperator.getProjection() != null) {
+                    newProjection = new Projection(aggregationOperator.getProjection().getColumnRefMap());
+                    newProjection.getColumnRefMap().values().forEach(s -> requiredOutputColumns.union(s.getUsedColumns()));
+                }
+                // group by
+                final List<ColumnRefOperator> newGroupByKeys = aggregationOperator.getGroupingKeys()
+                        .stream()
+                        .filter(col -> requiredOutputColumns.contains(col))
+                        .collect(Collectors.toList());
+                requiredOutputColumns.union(newGroupByKeys);
+                // partition by
+                final List<ColumnRefOperator> newPartitionByKeys = aggregationOperator.getPartitionByColumns()
+                        .stream()
+                        .filter(col -> requiredOutputColumns.contains(col))
+                        .collect(Collectors.toList());
+                requiredOutputColumns.union(newPartitionByKeys);
+                // aggregations
+                final Map<ColumnRefOperator, CallOperator> newAggregations = aggregationOperator.getAggregations()
+                        .entrySet()
+                        .stream()
+                        .filter(e -> requiredOutputColumns.contains(e.getKey()))
+                        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+                newAggregations.values().stream().forEach(s -> requiredOutputColumns.union(s.getUsedColumns()));
+                final LogicalAggregationOperator newAggOp = new LogicalAggregationOperator.Builder()
+                        .withOperator(aggregationOperator)
+                        .setProjection(newProjection)
+                        .setGroupingKeys(newGroupByKeys)
+                        .setPartitionByColumns(newPartitionByKeys)
+                        .setAggregations(newAggregations)
+                        .build();
+                return OptExpression.create(newAggOp, visitChildren(optExpression));
+            } else {
+                if (aggregationOperator.getProjection() != null) {
+                    Projection projection = aggregationOperator.getProjection();
+                    projection.getColumnRefMap().values().forEach(s -> requiredOutputColumns.union(s.getUsedColumns()));
+                }
+                requiredOutputColumns.union(aggregationOperator.getGroupingKeys());
+                for (Map.Entry<ColumnRefOperator, CallOperator> entry : aggregationOperator.getAggregations().entrySet()) {
+                    requiredOutputColumns.union(entry.getKey());
+                    requiredOutputColumns.union(Utils.extractColumnRef(entry.getValue()));
+                }
+                List<OptExpression> children = visitChildren(optExpression);
+                return OptExpression.create(aggregationOperator, children);
+            }
-            List<OptExpression> children = visitChildren(optExpression);
-            return OptExpression.create(aggregationOperator, children);
         }
 
         public OptExpression visitLogicalUnion(OptExpression optExpression, Void context) {
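
The new OP_MV_AGG_PRUNE_COLUMNS branch keeps a grouping key, partition-by key, or aggregate call only if it appears in requiredOutputColumns, then rebuilds the aggregation; this is safe only because, as the comment notes, the rewritten mv plan should be a rollup of the original plan. The else branch preserves the old keep-everything behavior. A self-contained sketch of the filtering step, with plain strings standing in for ColumnRefOperator and CallOperator:

    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    import java.util.stream.Collectors;

    public class AggPruneSketch {
        public static void main(String[] args) {
            // The mv's defined query has two grouping keys and three aggregates...
            List<String> groupingKeys = List.of("k1", "k2");
            Map<String, String> aggregations = new LinkedHashMap<>();
            aggregations.put("sum_v1", "sum(v1)");
            aggregations.put("max_v2", "max(v2)");
            aggregations.put("count_v3", "count(v3)");

            // ...but the rewritten query only requires k1 and sum_v1.
            Set<String> requiredOutputColumns = Set.of("k1", "sum_v1");

            // Keep only the grouping keys and aggregate calls that are required.
            List<String> newGroupingKeys = groupingKeys.stream()
                    .filter(requiredOutputColumns::contains)
                    .collect(Collectors.toList());
            Map<String, String> newAggregations = aggregations.entrySet().stream()
                    .filter(e -> requiredOutputColumns.contains(e.getKey()))
                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,
                            (a, b) -> a, LinkedHashMap::new));

            System.out.println(newGroupingKeys);   // [k1]
            System.out.println(newAggregations);   // {sum_v1=sum(v1)}
        }
    }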
@@ -174,14 +216,11 @@ public OptExpression visitLogicalUnion(OptExpression optExpression, Void context) {
             }
             List<List<ColumnRefOperator>> newChildOutputColumns = Lists.newArrayList();
             for (int childIdx = 0; childIdx < optExpression.arity(); ++childIdx) {
-                List<ColumnRefOperator> childOutputCols = unionOperator.getChildOutputColumns().get(childIdx);
-                List<ColumnRefOperator> newChildOutputCols = Lists.newArrayList();
-                newUnionOutputIdxes.stream()
+                final List<ColumnRefOperator> childOutputCols = unionOperator.getChildOutputColumns().get(childIdx);
+                final List<ColumnRefOperator> newChildOutputCols = newUnionOutputIdxes.stream()
                         .map(idx -> childOutputCols.get(idx))
-                        .forEach(x -> {
-                            requiredOutputColumns.union(x);
-                            newChildOutputCols.add(x);
-                        });
+                        .collect(Collectors.toList());
+                requiredOutputColumns.union(newChildOutputCols);
                 newChildOutputColumns.add(newChildOutputCols);
             }
             LogicalUnionOperator newUnionOperator = new LogicalUnionOperator(newUnionOutputColRefs, newChildOutputColumns,
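
The union visitor change above is mostly mechanical: a stream that mutated a list from inside .forEach is replaced with a direct .collect, mapping each surviving union output index onto the corresponding column of each child. A small sketch of that index mapping (the column names are illustrative):

    import java.util.List;
    import java.util.stream.Collectors;

    public class UnionPruneSketch {
        public static void main(String[] args) {
            // One union child outputs c0..c3; only output indexes 0 and 2 survive pruning.
            List<String> childOutputCols = List.of("c0", "c1", "c2", "c3");
            List<Integer> newUnionOutputIdxes = List.of(0, 2);

            // Map the kept indexes onto this child's own output columns.
            List<String> newChildOutputCols = newUnionOutputIdxes.stream()
                    .map(childOutputCols::get)
                    .collect(Collectors.toList());

            System.out.println(newChildOutputCols);  // [c0, c2]
        }
    }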
@@ -552,7 +552,8 @@ public OptExpression postRewrite(OptimizerContext optimizerContext,
         if (candidate == null) {
             return null;
         }
-        candidate = new MVColumnPruner().pruneColumns(candidate);
+        final ColumnRefSet requiredOutputColumns = optimizerContext.getTaskContext().getRequiredColumns();
+        candidate = new MVColumnPruner().pruneColumns(candidate, requiredOutputColumns);
         candidate = new MVPartitionPruner(optimizerContext, mvRewriteContext).prunePartition(candidate);
         return candidate;
     }
@@ -1270,7 +1271,7 @@ private OptExpression buildMVScanOptExpression(MaterializationContext materializationContext,
         final List<ColumnRefOperator> originalOutputColumns = MvUtils.getMvScanOutputColumnRefs(mv, mvScanOperator);
         // build mv scan opt expression with or without compensate
         final OptExpression mvScanOptExpression = mvCompensation.isTransparentRewrite() ?
-                getMvTransparentPlan(materializationContext, mvCompensation, originalOutputColumns) :
+                getMvTransparentPlan(materializationContext, mvCompensation, originalOutputColumns, true) :
                 getMVScanPlanWithoutCompensate(rewriteContext, columnRewriter, mvColumnRefToScalarOp);
         return mvScanOptExpression;
     }
@@ -59,6 +59,7 @@
 import com.starrocks.sql.optimizer.operator.OperatorBuilderFactory;
 import com.starrocks.sql.optimizer.operator.OperatorType;
 import com.starrocks.sql.optimizer.operator.ScanOperatorPredicates;
+import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator;
 import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
 import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
 import com.starrocks.sql.optimizer.operator.logical.LogicalUnionOperator;
@@ -89,6 +90,7 @@
 import java.util.stream.Collectors;
 
 import static com.starrocks.sql.optimizer.OptimizerTraceUtil.logMVRewrite;
+import static com.starrocks.sql.optimizer.operator.OpRuleBit.OP_MV_AGG_PRUNE_COLUMNS;
 import static com.starrocks.sql.optimizer.operator.OpRuleBit.OP_MV_UNION_REWRITE;
 import static com.starrocks.sql.optimizer.rule.transformation.materialization.MvUtils.deriveLogicalProperty;
 import static com.starrocks.sql.optimizer.rule.transformation.materialization.MvUtils.mergeRanges;
@@ -169,7 +171,7 @@ public static MVCompensation getMvCompensation(OptExpression queryPlan,
      * @param mvContext: materialized view context
      * @return: a pair of compensated mv scan plan(refreshed partitions) and its output columns
      */
-    private static Pair<OptExpression, List<ColumnRefOperator>> getMvScanPlan(MaterializationContext mvContext) {
+    private static Pair<OptExpression, List<ColumnRefOperator>> getMVScanPlan(MaterializationContext mvContext) {
         // NOTE: mv's scan operator has already been partition pruned by filtering refreshed partitions,
         // see MvRewritePreprocessor#createScanMvOperator.
         final LogicalOlapScanOperator mvScanOperator = mvContext.getScanMvOperator();
@@ -198,8 +200,11 @@ private static Pair<OptExpression, List<ColumnRefOperator>> getMvScanPlan(MaterializationContext mvContext) {
      * @param mvCompensation: materialized view's compensation info
      * @return: a pair of compensated mv query scan plan(to-refreshed partitions) and its output columns
      */
-    private static Pair<OptExpression, List<ColumnRefOperator>> getMvQueryPlan(MaterializationContext mvContext,
-                                                                               MVCompensation mvCompensation) {
+    private static Pair<OptExpression, List<ColumnRefOperator>> getMVCompensationPlan(
+            MaterializationContext mvContext,
+            MVCompensation mvCompensation,
+            List<ColumnRefOperator> originalOutputColumns,
+            boolean isMVRewrite) {
         final OptExpression mvQueryPlan = mvContext.getMvExpression();
         OptExpression compensateMvQueryPlan = getMvCompensateQueryPlan(mvContext, mvCompensation, mvQueryPlan);
         if (compensateMvQueryPlan == null) {
@@ -216,7 +221,21 @@ private static Pair<OptExpression, List<ColumnRefOperator>> getMvQueryPlan(MaterializationContext mvContext,
         List<ColumnRefOperator> orgMvQueryOutputColumnRefs = mvContext.getMvOutputColumnRefs();
         List<ColumnRefOperator> mvQueryOutputColumnRefs = duplicator.getMappedColumns(orgMvQueryOutputColumnRefs);
         newMvQueryPlan.getOp().setOpRuleBit(OP_MV_UNION_REWRITE);
-        return Pair.create(newMvQueryPlan, mvQueryOutputColumnRefs);
+        if (isMVRewrite) {
+            // NOTE: mvScanPlan and mvCompensatePlan will output all columns of the mv's defined query,
+            // it may contain more columns than the requiredColumns.
+            // 1. For simple non-blocking operators(scan/join), it can be pruned by normal rules, but for
+            // aggregate operators, it should be handled in MVColumnPruner.
+            // 2. For mv rewrite, it's safe to prune aggregate columns in mv compensate plan, but it cannot determine
+            // required columns in the transparent rule.
+            List<LogicalAggregationOperator> list = Lists.newArrayList();
+            Utils.extractOperator(newMvQueryPlan, list, op -> op instanceof LogicalAggregationOperator);
+            list.stream().forEach(op -> op.setOpRuleBit(OP_MV_AGG_PRUNE_COLUMNS));
+        }
+        // Adjust query output columns to mv's output columns to make sure the output columns are the same as
+        // expectOutputColumns which are mv scan operator's output columns.
+        return adjustOptExpressionOutputColumnType(mvContext.getQueryRefFactory(),
+                newMvQueryPlan, mvQueryOutputColumnRefs, originalOutputColumns);
     }
 
     public static OptExpression getMvCompensateQueryPlan(MaterializationContext mvContext,
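
When called from the mv rewrite path (isMVRewrite is true), getMVCompensationPlan collects every LogicalAggregationOperator in the compensation plan and tags it with OP_MV_AGG_PRUNE_COLUMNS, which is exactly the bit MVColumnPruner's visitLogicalAggregate checks. A toy version of that extract-then-tag traversal; Node and its fields are illustrative, not the real Utils.extractOperator signature:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.Predicate;

    public class ExtractOperatorSketch {
        static final class Node {
            final String kind;
            final List<Node> children;
            boolean pruneAggColumns = false;

            Node(String kind, List<Node> children) {
                this.kind = kind;
                this.children = children;
            }
        }

        // Collect every node matching pred, depth-first.
        static void extract(Node node, List<Node> out, Predicate<Node> pred) {
            if (pred.test(node)) {
                out.add(node);
            }
            node.children.forEach(child -> extract(child, out, pred));
        }

        public static void main(String[] args) {
            Node scan = new Node("SCAN", List.of());
            Node agg = new Node("AGGREGATE", List.of(scan));
            Node root = new Node("PROJECT", List.of(agg));

            List<Node> aggs = new ArrayList<>();
            extract(root, aggs, n -> "AGGREGATE".equals(n.kind));
            // Analogous to op.setOpRuleBit(OP_MV_AGG_PRUNE_COLUMNS) on each match.
            aggs.forEach(n -> n.pruneAggColumns = true);
            System.out.println(aggs.size() + " aggregate operator(s) tagged");
        }
    }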
@@ -249,32 +268,30 @@ public static OptExpression getMvCompensateQueryPlan(MaterializationContext mvContext,
      */
     public static OptExpression getMvTransparentPlan(MaterializationContext mvContext,
                                                      MVCompensation mvCompensation,
-                                                     List<ColumnRefOperator> originalOutputColumns) {
+                                                     List<ColumnRefOperator> originalOutputColumns,
+                                                     boolean isMVRewrite) {
         Preconditions.checkArgument(originalOutputColumns != null);
         Preconditions.checkState(mvCompensation.getState().isCompensate());
 
-        Pair<OptExpression, List<ColumnRefOperator>> mvScanPlans = getMvScanPlan(mvContext);
-        if (mvScanPlans == null) {
+        final Pair<OptExpression, List<ColumnRefOperator>> mvScanPlan = getMVScanPlan(mvContext);
+        if (mvScanPlan == null) {
             logMVRewrite(mvContext, "Get mv scan transparent plan failed");
             return null;
         }
 
-        Pair<OptExpression, List<ColumnRefOperator>> mvQueryPlans = getMvQueryPlan(mvContext, mvCompensation);
-        if (mvQueryPlans == null) {
+        final Pair<OptExpression, List<ColumnRefOperator>> mvCompensationPlan = getMVCompensationPlan(mvContext,
+                mvCompensation, originalOutputColumns, isMVRewrite);
+        if (mvCompensationPlan == null) {
             logMVRewrite(mvContext, "Get mv query transparent plan failed");
             return null;
         }
-        // Adjust query output columns to mv's output columns to make sure the output columns are the same as
-        // expectOutputColumns which are mv scan operator's output columns.
-        mvQueryPlans = adjustOptExpressionOutputColumnType(mvContext.getQueryRefFactory(),
-                mvQueryPlans.first, mvQueryPlans.second, originalOutputColumns);
 
         LogicalUnionOperator unionOperator = new LogicalUnionOperator.Builder()
                 .setOutputColumnRefOp(originalOutputColumns)
-                .setChildOutputColumns(Lists.newArrayList(mvScanPlans.second, mvQueryPlans.second))
+                .setChildOutputColumns(Lists.newArrayList(mvScanPlan.second, mvCompensationPlan.second))
                 .isUnionAll(true)
                 .build();
-        OptExpression result = OptExpression.create(unionOperator, mvScanPlans.first, mvQueryPlans.first);
+        OptExpression result = OptExpression.create(unionOperator, mvScanPlan.first, mvCompensationPlan.first);
         deriveLogicalProperty(result);
         return result;
     }
@@ -128,6 +128,10 @@ public String getFragmentPlan(String sql, String traceModule) throws Exception {
         return getFragmentPlan(sql, TExplainLevel.NORMAL, traceModule);
     }
 
+    public String getFragmentPlan(String sql, TExplainLevel level) throws Exception {
+        return getFragmentPlan(sql, level, null);
+    }
+
     public String getFragmentPlan(String sql, TExplainLevel level, String traceModule) throws Exception {
         Pair<String, Pair<ExecPlan, String>> result =
                 UtFrameUtils.getFragmentPlanWithTrace(connectContext, sql, traceModule);
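
The new two-argument overload simply defaults traceModule to null so tests can pick an explain level without enabling tracing. A hypothetical call site inside a test that extends this helper (the SQL is illustrative):

    // Before: choosing an explain level always required an explicit trace module.
    String traced = getFragmentPlan("select k1, sum(v1) from t group by k1", TExplainLevel.NORMAL, "MV");
    // After: choose the explain level alone.
    String plan = getFragmentPlan("select k1, sum(v1) from t group by k1", TExplainLevel.VERBOSE);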