[BugFix] Prune aggregate non-required columns after mv transparent rewrite #55286

Merged · 5 commits · Jan 22, 2025
@@ -30,4 +30,6 @@ public class OpRuleBit {
     public static final int OP_MV_TRANSPARENT_REWRITE = 2;
     // Operator has been partition pruned or not; if partition pruned, no need to prune again.
     public static final int OP_PARTITION_PRUNED = 3;
+    // Operator has been rewritten by mv transparent union rewrite and needs to prune agg columns.
+    public static final int OP_MV_AGG_PRUNE_COLUMNS = 4;
 }
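
These constants are bit positions that operators record via setOpRuleBit and test via isOpRuleBitSet (both appear later in this diff). A minimal sketch of that bitmask pattern, assuming an int mask field on the operator; the field name and exact semantics here are hypothetical, for illustration only:

    // Illustrative sketch: models how integer rule-bit constants like
    // OP_MV_AGG_PRUNE_COLUMNS = 4 can be stored and tested as a bitmask.
    public class RuleBitDemo {
        private int opRuleMask = 0;

        public void setOpRuleBit(int bit) {
            opRuleMask |= (1 << bit);      // record that this rewrite touched the operator
        }

        public boolean isOpRuleBitSet(int bit) {
            return (opRuleMask & (1 << bit)) != 0;
        }

        public static void main(String[] args) {
            RuleBitDemo op = new RuleBitDemo();
            op.setOpRuleBit(4);                        // OP_MV_AGG_PRUNE_COLUMNS
            System.out.println(op.isOpRuleBitSet(4));  // true
            System.out.println(op.isOpRuleBitSet(2));  // false (OP_MV_TRANSPARENT_REWRITE)
        }
    }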
@@ -32,6 +32,7 @@
 import com.starrocks.sql.optimizer.OptimizerContext;
 import com.starrocks.sql.optimizer.QueryMaterializationContext;
 import com.starrocks.sql.optimizer.Utils;
+import com.starrocks.sql.optimizer.base.ColumnRefSet;
 import com.starrocks.sql.optimizer.operator.OperatorType;
 import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
 import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;

@@ -217,7 +218,7 @@ private OptExpression redirectToMVDefinedQuery(OptimizerContext context,
     /**
      * Get transparent plan if possible.
      * What's the transparent plan?
-     * see {@link MvPartitionCompensator#getMvTransparentPlan(MaterializationContext, MVCompensation, List)
+     * see {@link MvPartitionCompensator#getMvTransparentPlan(MaterializationContext, MVCompensation, List, boolean)}
      */
     private OptExpression getMvTransparentPlan(MaterializationContext mvContext,
                                                OptExpression input,

@@ -247,7 +248,7 @@ private OptExpression getMvTransparentPlan(MaterializationContext mvContext,
             return null;
         }
         OptExpression transparentPlan = MvPartitionCompensator.getMvTransparentPlan(mvContext, mvCompensation,
-                expectOutputColumns);
+                expectOutputColumns, false);
         return transparentPlan;
     }
@@ -14,7 +14,6 @@
 
 package com.starrocks.sql.optimizer.rule.transformation.materialization;
 
-import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;

@@ -42,10 +41,13 @@
 import java.util.Set;
 import java.util.stream.Collectors;
 
+import static com.starrocks.sql.optimizer.operator.OpRuleBit.OP_MV_AGG_PRUNE_COLUMNS;
+
 public class MVColumnPruner {
     private ColumnRefSet requiredOutputColumns;
 
-    public OptExpression pruneColumns(OptExpression queryExpression) {
+    public OptExpression pruneColumns(OptExpression queryExpression, ColumnRefSet requiredOutputColumns) {
+        this.requiredOutputColumns = requiredOutputColumns;
         if (queryExpression.getOp() instanceof LogicalFilterOperator) {
             OptExpression newQueryOptExpression = doPruneColumns(queryExpression.inputAt(0));
             Operator filterOp = queryExpression.getOp();

@@ -58,13 +60,13 @@ public OptExpression pruneColumns(OptExpression queryExpression) {
     }
 
     public OptExpression doPruneColumns(OptExpression optExpression) {
+        // TODO: remove this check after we support more operators.
         Projection projection = optExpression.getOp().getProjection();
+        // OptExpression after mv rewrite must have projection.
+        if (projection == null) {
+            return optExpression;
+        }
-        Preconditions.checkState(projection != null);
-        requiredOutputColumns = new ColumnRefSet(projection.getOutputColumns());
-        // OptExpression after mv rewrite must have projection.
         return optExpression.getOp().accept(new ColumnPruneVisitor(), optExpression, null);
     }

@@ -124,20 +126,60 @@ public OptExpression visitLogicalTableScan(OptExpression optExpression, Void context) {
 
         public OptExpression visitLogicalAggregate(OptExpression optExpression, Void context) {
             LogicalAggregationOperator aggregationOperator = (LogicalAggregationOperator) optExpression.getOp();
-            if (aggregationOperator.getProjection() != null) {
-                Projection projection = aggregationOperator.getProjection();
-                projection.getColumnRefMap().values().forEach(s -> requiredOutputColumns.union(s.getUsedColumns()));
-            }
             if (aggregationOperator.getPredicate() != null) {
                 requiredOutputColumns.union(Utils.extractColumnRef(aggregationOperator.getPredicate()));
             }
-            requiredOutputColumns.union(aggregationOperator.getGroupingKeys());
-            for (Map.Entry<ColumnRefOperator, CallOperator> entry : aggregationOperator.getAggregations().entrySet()) {
-                requiredOutputColumns.union(entry.getKey());
-                requiredOutputColumns.union(Utils.extractColumnRef(entry.getValue()));
-            }
+            // It's safe to prune columns if the aggregation operator has been rewritten by mv, since the
+            // rewritten mv plan should be a rollup of the original plan.
+            // TODO: Do this in a more general way rather than only for mv rewrite later,
+            // issue: https://github.com/StarRocks/starrocks/issues/55285
+            if (aggregationOperator.isOpRuleBitSet(OP_MV_AGG_PRUNE_COLUMNS)) {
+                // project
+                Projection newProjection = null;
+                if (aggregationOperator.getProjection() != null) {
+                    newProjection = new Projection(aggregationOperator.getProjection().getColumnRefMap());
+                    newProjection.getColumnRefMap().values().forEach(s -> requiredOutputColumns.union(s.getUsedColumns()));
+                }
+                // group by
+                final List<ColumnRefOperator> newGroupByKeys = aggregationOperator.getGroupingKeys()
+                        .stream()
+                        .filter(col -> requiredOutputColumns.contains(col))
+                        .collect(Collectors.toList());
+                requiredOutputColumns.union(newGroupByKeys);
+                // partition by
+                final List<ColumnRefOperator> newPartitionByKeys = aggregationOperator.getPartitionByColumns()
+                        .stream()
+                        .filter(col -> requiredOutputColumns.contains(col))
+                        .collect(Collectors.toList());
+                requiredOutputColumns.union(newPartitionByKeys);
+                // aggregations
+                final Map<ColumnRefOperator, CallOperator> newAggregations = aggregationOperator.getAggregations()
+                        .entrySet()
+                        .stream()
+                        .filter(e -> requiredOutputColumns.contains(e.getKey()))
+                        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+                newAggregations.values().forEach(s -> requiredOutputColumns.union(s.getUsedColumns()));
+                final LogicalAggregationOperator newAggOp = new LogicalAggregationOperator.Builder()
+                        .withOperator(aggregationOperator)
+                        .setProjection(newProjection)
+                        .setGroupingKeys(newGroupByKeys)
+                        .setPartitionByColumns(newPartitionByKeys)
+                        .setAggregations(newAggregations)
+                        .build();
+                return OptExpression.create(newAggOp, visitChildren(optExpression));
+            } else {
+                if (aggregationOperator.getProjection() != null) {
+                    Projection projection = aggregationOperator.getProjection();
+                    projection.getColumnRefMap().values().forEach(s -> requiredOutputColumns.union(s.getUsedColumns()));
+                }
+                requiredOutputColumns.union(aggregationOperator.getGroupingKeys());
+                for (Map.Entry<ColumnRefOperator, CallOperator> entry : aggregationOperator.getAggregations().entrySet()) {
+                    requiredOutputColumns.union(entry.getKey());
+                    requiredOutputColumns.union(Utils.extractColumnRef(entry.getValue()));
+                }
+                List<OptExpression> children = visitChildren(optExpression);
+                return OptExpression.create(aggregationOperator, children);
+            }
-            List<OptExpression> children = visitChildren(optExpression);
-            return OptExpression.create(aggregationOperator, children);
         }

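To make the new branch concrete: once an aggregate is marked with OP_MV_AGG_PRUNE_COLUMNS, any group-by key or aggregate call the parent never reads can be dropped, because the rewritten mv plan is a rollup of the original. A toy sketch of that filtering with plain collections (hypothetical names, not the optimizer's types):

    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    import java.util.stream.Collectors;

    // Toy model: prune group-by keys and aggregate calls that the parent never reads.
    // Only safe here because the caller marked the operator as a rollup of an mv plan.
    public class AggPruneDemo {
        public static void main(String[] args) {
            List<String> groupByKeys = List.of("a", "b");
            Map<String, String> aggregations = Map.of("sum_x", "sum(x)", "cnt_y", "count(y)");
            Set<String> required = Set.of("a", "sum_x");

            List<String> keptKeys = groupByKeys.stream()
                    .filter(required::contains)
                    .collect(Collectors.toList());
            Map<String, String> keptAggs = aggregations.entrySet().stream()
                    .filter(e -> required.contains(e.getKey()))
                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

            System.out.println(keptKeys); // [a]
            System.out.println(keptAggs); // {sum_x=sum(x)}
        }
    }
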
@@ -174,14 +216,11 @@ public OptExpression visitLogicalUnion(OptExpression optExpression, Void context) {
             }
             List<List<ColumnRefOperator>> newChildOutputColumns = Lists.newArrayList();
             for (int childIdx = 0; childIdx < optExpression.arity(); ++childIdx) {
-                List<ColumnRefOperator> childOutputCols = unionOperator.getChildOutputColumns().get(childIdx);
-                List<ColumnRefOperator> newChildOutputCols = Lists.newArrayList();
-                newUnionOutputIdxes.stream()
+                final List<ColumnRefOperator> childOutputCols = unionOperator.getChildOutputColumns().get(childIdx);
+                final List<ColumnRefOperator> newChildOutputCols = newUnionOutputIdxes.stream()
                         .map(idx -> childOutputCols.get(idx))
-                        .forEach(x -> {
-                            requiredOutputColumns.union(x);
-                            newChildOutputCols.add(x);
-                        });
+                        .collect(Collectors.toList());
+                requiredOutputColumns.union(newChildOutputCols);
                 newChildOutputColumns.add(newChildOutputCols);
             }
             LogicalUnionOperator newUnionOperator = new LogicalUnionOperator(newUnionOutputColRefs, newChildOutputColumns,
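
The reworked loop keeps only the union output positions that survive pruning and projects every child's output list through those same positions. A self-contained sketch of the index-mapping idea (toy types, not the optimizer's classes):

    import java.util.List;
    import java.util.stream.Collectors;

    // Toy model: given which union output positions are still required,
    // project each child's output columns through the same positions.
    public class UnionPruneDemo {
        public static void main(String[] args) {
            List<Integer> keptIdxes = List.of(0, 2);          // output position 1 was pruned
            List<List<String>> childOutputs = List.of(
                    List.of("a0", "a1", "a2"),                // child 0 (e.g. mv scan)
                    List.of("b0", "b1", "b2"));               // child 1 (e.g. compensation plan)

            for (List<String> childCols : childOutputs) {
                List<String> kept = keptIdxes.stream()
                        .map(childCols::get)
                        .collect(Collectors.toList());
                System.out.println(kept);                     // [a0, a2] then [b0, b2]
            }
        }
    }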
@@ -552,7 +552,8 @@ public OptExpression postRewrite(OptimizerContext optimizerContext,
         if (candidate == null) {
             return null;
         }
-        candidate = new MVColumnPruner().pruneColumns(candidate);
+        final ColumnRefSet requiredOutputColumns = optimizerContext.getTaskContext().getRequiredColumns();
+        candidate = new MVColumnPruner().pruneColumns(candidate, requiredOutputColumns);
         candidate = new MVPartitionPruner(optimizerContext, mvRewriteContext).prunePartition(candidate);
         return candidate;
     }

@@ -1270,7 +1271,7 @@ private OptExpression buildMVScanOptExpression(MaterializationContext materializationContext,
         final List<ColumnRefOperator> originalOutputColumns = MvUtils.getMvScanOutputColumnRefs(mv, mvScanOperator);
         // build mv scan opt expression with or without compensate
         final OptExpression mvScanOptExpression = mvCompensation.isTransparentRewrite() ?
-                getMvTransparentPlan(materializationContext, mvCompensation, originalOutputColumns) :
+                getMvTransparentPlan(materializationContext, mvCompensation, originalOutputColumns, true) :
                 getMVScanPlanWithoutCompensate(rewriteContext, columnRewriter, mvColumnRefToScalarOp);
         return mvScanOptExpression;
     }
@@ -59,6 +59,7 @@
 import com.starrocks.sql.optimizer.operator.OperatorBuilderFactory;
 import com.starrocks.sql.optimizer.operator.OperatorType;
 import com.starrocks.sql.optimizer.operator.ScanOperatorPredicates;
+import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator;
 import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
 import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
 import com.starrocks.sql.optimizer.operator.logical.LogicalUnionOperator;

@@ -88,6 +89,7 @@
 import java.util.stream.Collectors;
 
 import static com.starrocks.sql.optimizer.OptimizerTraceUtil.logMVRewrite;
+import static com.starrocks.sql.optimizer.operator.OpRuleBit.OP_MV_AGG_PRUNE_COLUMNS;
 import static com.starrocks.sql.optimizer.operator.OpRuleBit.OP_MV_UNION_REWRITE;
 import static com.starrocks.sql.optimizer.rule.transformation.materialization.MvUtils.deriveLogicalProperty;
 import static com.starrocks.sql.optimizer.rule.transformation.materialization.MvUtils.mergeRanges;

@@ -187,7 +189,7 @@ public static MVCompensation getMvCompensation(OptExpression queryPlan,
      * @param mvContext: materialized view context
      * @return: a pair of compensated mv scan plan (refreshed partitions) and its output columns
      */
-    private static Pair<OptExpression, List<ColumnRefOperator>> getMvScanPlan(MaterializationContext mvContext) {
+    private static Pair<OptExpression, List<ColumnRefOperator>> getMVScanPlan(MaterializationContext mvContext) {
         // NOTE: mv's scan operator has already been partition pruned by filtering refreshed partitions,
         // see MvRewritePreprocessor#createScanMvOperator.
         final LogicalOlapScanOperator mvScanOperator = mvContext.getScanMvOperator();

@@ -216,8 +218,11 @@ private static Pair<OptExpression, List<ColumnRefOperator>> getMvScanPlan(MaterializationContext mvContext) {
      * @param mvCompensation: materialized view's compensation info
      * @return: a pair of compensated mv query scan plan (to-refresh partitions) and its output columns
      */
-    private static Pair<OptExpression, List<ColumnRefOperator>> getMvQueryPlan(MaterializationContext mvContext,
-                                                                               MVCompensation mvCompensation) {
+    private static Pair<OptExpression, List<ColumnRefOperator>> getMVCompensationPlan(
+            MaterializationContext mvContext,
+            MVCompensation mvCompensation,
+            List<ColumnRefOperator> originalOutputColumns,
+            boolean isMVRewrite) {
         final OptExpression mvQueryPlan = mvContext.getMvExpression();
         OptExpression compensateMvQueryPlan = getMvCompensateQueryPlan(mvContext, mvCompensation, mvQueryPlan);
         if (compensateMvQueryPlan == null) {

@@ -234,7 +239,21 @@ private static Pair<OptExpression, List<ColumnRefOperator>> getMvQueryPlan(MaterializationContext mvContext,
         List<ColumnRefOperator> orgMvQueryOutputColumnRefs = mvContext.getMvOutputColumnRefs();
         List<ColumnRefOperator> mvQueryOutputColumnRefs = duplicator.getMappedColumns(orgMvQueryOutputColumnRefs);
         newMvQueryPlan.getOp().setOpRuleBit(OP_MV_UNION_REWRITE);
-        return Pair.create(newMvQueryPlan, mvQueryOutputColumnRefs);
+        if (isMVRewrite) {
+            // NOTE: mvScanPlan and mvCompensatePlan will output all columns of the mv's defined query,
+            // which may be more columns than the required ones.
+            // 1. For simple non-blocking operators (scan/join), extra columns can be pruned by the normal
+            //    rules; for aggregate operators, they must be handled in MVColumnPruner.
+            // 2. For mv rewrite it's safe to prune aggregate columns in the mv compensation plan, but the
+            //    transparent rule cannot determine the required columns on its own, so we only mark here.
+            List<LogicalAggregationOperator> list = Lists.newArrayList();
+            Utils.extractOperator(newMvQueryPlan, list, op -> op instanceof LogicalAggregationOperator);
+            list.forEach(op -> op.setOpRuleBit(OP_MV_AGG_PRUNE_COLUMNS));
+        }
+        // Adjust query output columns to the mv's output columns to make sure they are the same as
+        // expectOutputColumns, which are the mv scan operator's output columns.
+        return adjustOptExpressionOutputColumnType(mvContext.getQueryRefFactory(),
+                newMvQueryPlan, mvQueryOutputColumnRefs, originalOutputColumns);
     }
 
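The marking step collects every aggregate in the compensation plan and sets the prune bit on it; MVColumnPruner later reacts to that bit. A minimal collect-then-mark sketch over a toy operator tree (hypothetical types; the real Utils.extractOperator signature may differ in detail):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.Predicate;

    // Toy operator tree: collect nodes matching a predicate, then mark them.
    public class MarkAggDemo {
        static class Op {
            final String name;
            final List<Op> children = new ArrayList<>();
            int ruleMask = 0;
            Op(String name, Op... kids) { this.name = name; children.addAll(List.of(kids)); }
        }

        // Analogous to Utils.extractOperator: depth-first collect into `out`.
        static void extract(Op root, List<Op> out, Predicate<Op> pred) {
            if (pred.test(root)) {
                out.add(root);
            }
            root.children.forEach(c -> extract(c, out, pred));
        }

        public static void main(String[] args) {
            Op plan = new Op("union", new Op("agg", new Op("scan")), new Op("agg", new Op("scan")));
            List<Op> aggs = new ArrayList<>();
            extract(plan, aggs, op -> op.name.equals("agg"));
            aggs.forEach(op -> op.ruleMask |= (1 << 4));   // OP_MV_AGG_PRUNE_COLUMNS = 4
            System.out.println(aggs.size());               // 2
        }
    }
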
@@ -266,32 +285,30 @@ public static OptExpression getMvCompensateQueryPlan(MaterializationContext mvContext,
      */
     public static OptExpression getMvTransparentPlan(MaterializationContext mvContext,
                                                      MVCompensation mvCompensation,
-                                                     List<ColumnRefOperator> originalOutputColumns) {
+                                                     List<ColumnRefOperator> originalOutputColumns,
+                                                     boolean isMVRewrite) {
         Preconditions.checkArgument(originalOutputColumns != null);
         Preconditions.checkState(mvCompensation.getState().isCompensate());
 
-        Pair<OptExpression, List<ColumnRefOperator>> mvScanPlans = getMvScanPlan(mvContext);
-        if (mvScanPlans == null) {
+        final Pair<OptExpression, List<ColumnRefOperator>> mvScanPlan = getMVScanPlan(mvContext);
+        if (mvScanPlan == null) {
             logMVRewrite(mvContext, "Get mv scan transparent plan failed");
             return null;
         }
 
-        Pair<OptExpression, List<ColumnRefOperator>> mvQueryPlans = getMvQueryPlan(mvContext, mvCompensation);
-        if (mvQueryPlans == null) {
+        final Pair<OptExpression, List<ColumnRefOperator>> mvCompensationPlan = getMVCompensationPlan(mvContext,
+                mvCompensation, originalOutputColumns, isMVRewrite);
+        if (mvCompensationPlan == null) {
             logMVRewrite(mvContext, "Get mv query transparent plan failed");
             return null;
         }
-        // Adjust query output columns to mv's output columns to make sure the output columns are the same as
-        // expectOutputColumns which are mv scan operator's output columns.
-        mvQueryPlans = adjustOptExpressionOutputColumnType(mvContext.getQueryRefFactory(),
-                mvQueryPlans.first, mvQueryPlans.second, originalOutputColumns);
 
         LogicalUnionOperator unionOperator = new LogicalUnionOperator.Builder()
                 .setOutputColumnRefOp(originalOutputColumns)
-                .setChildOutputColumns(Lists.newArrayList(mvScanPlans.second, mvQueryPlans.second))
+                .setChildOutputColumns(Lists.newArrayList(mvScanPlan.second, mvCompensationPlan.second))
                 .isUnionAll(true)
                 .build();
-        OptExpression result = OptExpression.create(unionOperator, mvScanPlans.first, mvQueryPlans.first);
+        OptExpression result = OptExpression.create(unionOperator, mvScanPlan.first, mvCompensationPlan.first);
         deriveLogicalProperty(result);
         return result;
     }
@@ -129,6 +129,10 @@ public String getFragmentPlan(String sql, String traceModule) throws Exception {
         return getFragmentPlan(sql, TExplainLevel.NORMAL, traceModule);
     }
 
+    public String getFragmentPlan(String sql, TExplainLevel level) throws Exception {
+        return getFragmentPlan(sql, level, null);
+    }
+
     public String getFragmentPlan(String sql, TExplainLevel level, String traceModule) throws Exception {
         Pair<String, Pair<ExecPlan, String>> result =
                 UtFrameUtils.getFragmentPlanWithTrace(connectContext, sql, traceModule);
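
A usage sketch for the new two-argument overload inside a plan test (the SQL and table are illustrative, and it assumes a VERBOSE value exists on TExplainLevel; the overload simply delegates with a null trace module):

    // Illustrative only: exercise the new overload without an explicit null argument.
    String plan = getFragmentPlan("select v1, sum(v2) from t0 group by v1", TExplainLevel.VERBOSE);
    // equivalent to: getFragmentPlan(sql, TExplainLevel.VERBOSE, null)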