Commit 728177f
fix bugs
Signed-off-by: shuming.li <[email protected]>
LiShuMing committed Jan 22, 2025
1 parent dcd4278 commit 728177f
Showing 2 changed files with 32 additions and 23 deletions.
File 1 of 2:
@@ -18,7 +18,6 @@
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.starrocks.catalog.Column;
-import com.starrocks.common.Pair;
 import com.starrocks.sql.common.UnsupportedException;
 import com.starrocks.sql.optimizer.OptExpression;
 import com.starrocks.sql.optimizer.OptExpressionVisitor;
@@ -129,28 +128,32 @@ public OptExpression visitLogicalAggregate(OptExpression optExpression, Void context
         // TODO: We can do this in more normal ways rather than only mv rewrite later,
         // issue: https://github.com/StarRocks/starrocks/issues/55285
         if (aggregationOperator.isOpRuleBitSet(OP_MV_AGG_PRUNE_COLUMNS)) {
+            // project
             Projection newProjection = null;
             if (aggregationOperator.getProjection() != null) {
-                newProjection = aggregationOperator.getProjection();
+                newProjection = new Projection(aggregationOperator.getProjection().getColumnRefMap());
                 newProjection.getColumnRefMap().values().forEach(s -> requiredOutputColumns.union(s.getUsedColumns()));
             }
-            List<ColumnRefOperator> newGroupByKeys = aggregationOperator.getGroupingKeys()
+            // group by
+            final List<ColumnRefOperator> newGroupByKeys = aggregationOperator.getGroupingKeys()
                     .stream()
                     .filter(col -> requiredOutputColumns.contains(col))
                     .collect(Collectors.toList());
-            List<ColumnRefOperator> newPartitionByKeys = aggregationOperator.getPartitionByColumns()
+            requiredOutputColumns.union(newGroupByKeys);
+            // partition by
+            final List<ColumnRefOperator> newPartitionByKeys = aggregationOperator.getPartitionByColumns()
                     .stream()
                     .filter(col -> requiredOutputColumns.contains(col))
                     .collect(Collectors.toList());
-            Map<ColumnRefOperator, CallOperator> newAggregations = aggregationOperator.getAggregations()
+            requiredOutputColumns.union(newPartitionByKeys);
+            // aggregations
+            final Map<ColumnRefOperator, CallOperator> newAggregations = aggregationOperator.getAggregations()
                     .entrySet()
                     .stream()
                     .filter(e -> requiredOutputColumns.contains(e.getKey()))
-                    .map(e -> {
-                        requiredOutputColumns.union(e.getValue().getUsedColumns());
-                        return Pair.create(e.getKey(), e.getValue());
-                    }).collect(Collectors.toMap(e -> e.first, e -> e.second));
-            LogicalAggregationOperator newAggOp = new LogicalAggregationOperator.Builder()
+                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+            newAggregations.values().stream().forEach(s -> requiredOutputColumns.union(s.getUsedColumns()));
+            final LogicalAggregationOperator newAggOp = new LogicalAggregationOperator.Builder()
                     .withOperator(aggregationOperator)
                     .setProjection(newProjection)
                     .setGroupingKeys(newGroupByKeys)
@@ -209,8 +212,8 @@ public OptExpression visitLogicalUnion(OptExpression optExpression, Void context
         }
         List<List<ColumnRefOperator>> newChildOutputColumns = Lists.newArrayList();
         for (int childIdx = 0; childIdx < optExpression.arity(); ++childIdx) {
-            List<ColumnRefOperator> childOutputCols = unionOperator.getChildOutputColumns().get(childIdx);
-            List<ColumnRefOperator> newChildOutputCols = newUnionOutputIdxes.stream()
+            final List<ColumnRefOperator> childOutputCols = unionOperator.getChildOutputColumns().get(childIdx);
+            final List<ColumnRefOperator> newChildOutputCols = newUnionOutputIdxes.stream()
                     .map(idx -> childOutputCols.get(idx))
                     .collect(Collectors.toList());
             requiredOutputColumns.union(newChildOutputCols);
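
Note on the aggregate-pruning hunk above: the fix replaces a side-effecting .map() step (which needed Pair and unioned used columns mid-stream) with a plain toMap collect, and then re-registers the inputs of every surviving projection expression, group-by key, partition-by key, and aggregation into requiredOutputColumns so downstream pruning keeps them. The following standalone sketch shows the same filter-then-union pattern; it uses plain JDK collections and hypothetical column names rather than StarRocks' ColumnRefOperator/ColumnRefSet:

import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

// Standalone sketch of the filter-then-union pruning pattern. Plain JDK
// types stand in for StarRocks' ColumnRefOperator/ColumnRefSet, and the
// column names are hypothetical.
public class PrunePatternSketch {
    public static void main(String[] args) {
        // Columns the parent plan still needs from this aggregation.
        Set<String> requiredOutputColumns = new HashSet<>(Set.of("sum_v1"));
        // Aggregation output name -> input columns its call expression reads.
        Map<String, Set<String>> aggregations = Map.of(
                "sum_v1", Set.of("v1"),
                "max_v2", Set.of("v2"));

        // Keep only the aggregations whose output is still required. The old
        // code unioned used columns inside a side-effecting .map() step; the
        // fixed version collects first, keeping the pipeline pure...
        Map<String, Set<String>> newAggregations = aggregations.entrySet().stream()
                .filter(e -> requiredOutputColumns.contains(e.getKey()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        // ...and then re-registers the inputs of every surviving aggregation,
        // mirroring requiredOutputColumns.union(s.getUsedColumns()).
        newAggregations.values().forEach(requiredOutputColumns::addAll);

        System.out.println(newAggregations);       // {sum_v1=[v1]}
        System.out.println(requiredOutputColumns); // contains sum_v1 and v1
    }
}

Collecting first and unioning afterwards is also what makes the com.starrocks.common.Pair import removable in the first hunk.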
File 2 of 2:
@@ -59,6 +59,7 @@
 import com.starrocks.sql.optimizer.operator.OperatorBuilderFactory;
 import com.starrocks.sql.optimizer.operator.OperatorType;
 import com.starrocks.sql.optimizer.operator.ScanOperatorPredicates;
+import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator;
 import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
 import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
 import com.starrocks.sql.optimizer.operator.logical.LogicalUnionOperator;
@@ -188,7 +189,7 @@ public static MVCompensation getMvCompensation(OptExpression queryPlan,
      * @param mvContext: materialized view context
      * @return: a pair of compensated mv scan plan(refreshed partitions) and its output columns
      */
-    private static Pair<OptExpression, List<ColumnRefOperator>> getMvScanPlan(MaterializationContext mvContext) {
+    private static Pair<OptExpression, List<ColumnRefOperator>> getMVScanPlan(MaterializationContext mvContext) {
         // NOTE: mv's scan operator has already been partition pruned by filtering refreshed partitions,
         // see MvRewritePreprocessor#createScanMvOperator.
         final LogicalOlapScanOperator mvScanOperator = mvContext.getScanMvOperator();
@@ -237,10 +238,18 @@ private static Pair<OptExpression, List<ColumnRefOperator>> getMVCompensationPlan
         deriveLogicalProperty(newMvQueryPlan);
         List<ColumnRefOperator> orgMvQueryOutputColumnRefs = mvContext.getMvOutputColumnRefs();
         List<ColumnRefOperator> mvQueryOutputColumnRefs = duplicator.getMappedColumns(orgMvQueryOutputColumnRefs);
-        // For mv rewrite, it's safe to prune aggregate columns in mv compensate plan, but it cannot determine
-        // required columns in the transparent rule.
-        int ruleBit = isMVRewrite ? OP_MV_UNION_REWRITE | OP_MV_AGG_PRUNE_COLUMNS : OP_MV_UNION_REWRITE;
-        newMvQueryPlan.getOp().setOpRuleBit(ruleBit);
+        newMvQueryPlan.getOp().setOpRuleBit(OP_MV_UNION_REWRITE);
+        if (isMVRewrite) {
+            // NOTE: mvScanPlan and mvCompensatePlan will output all columns of the mv's defined query,
+            // it may contain more columns than the requiredColumns.
+            // 1. For simple non-blocking operators(scan/join), it can be pruned by normal rules, but for
+            // aggregate operators, it should be handled in MVColumnPruner.
+            // 2. For mv rewrite, it's safe to prune aggregate columns in mv compensate plan, but it cannot determine
+            // required columns in the transparent rule.
+            List<LogicalAggregationOperator> list = Lists.newArrayList();
+            Utils.extractOperator(newMvQueryPlan, list, op -> op instanceof LogicalAggregationOperator);
+            list.stream().forEach(op -> op.setOpRuleBit(OP_MV_AGG_PRUNE_COLUMNS));
+        }
         // Adjust query output columns to mv's output columns to make sure the output columns are the same as
         // expectOutputColumns which are mv scan operator's output columns.
         return adjustOptExpressionOutputColumnType(mvContext.getQueryRefFactory(),
@@ -281,22 +290,19 @@ public static OptExpression getMvTransparentPlan(MaterializationContext mvContext
         Preconditions.checkArgument(originalOutputColumns != null);
         Preconditions.checkState(mvCompensation.getState().isCompensate());

-        final Pair<OptExpression, List<ColumnRefOperator>> mvScanPlan = getMvScanPlan(mvContext);
+        final Pair<OptExpression, List<ColumnRefOperator>> mvScanPlan = getMVScanPlan(mvContext);
         if (mvScanPlan == null) {
             logMVRewrite(mvContext, "Get mv scan transparent plan failed");
             return null;
         }

-        Pair<OptExpression, List<ColumnRefOperator>> mvCompensationPlan = getMVCompensationPlan(mvContext,
+        final Pair<OptExpression, List<ColumnRefOperator>> mvCompensationPlan = getMVCompensationPlan(mvContext,
                 mvCompensation, originalOutputColumns, isMVRewrite);
         if (mvCompensationPlan == null) {
             logMVRewrite(mvContext, "Get mv query transparent plan failed");
             return null;
         }
-        // NOTE: mvScanPlan and mvCompensatePlan will output all columns of the mv's defined query,
-        // it may contain more columns than the requiredColumns.
-        // For simple non-blocking operators(scan/join), it can be pruned by normal rules, but for
-        // aggregate operators, it should be handled in MVColumnPruner.
+
         LogicalUnionOperator unionOperator = new LogicalUnionOperator.Builder()
                 .setOutputColumnRefOp(originalOutputColumns)
                 .setChildOutputColumns(Lists.newArrayList(mvScanPlan.second, mvCompensationPlan.second))
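
Note on the second file: the change narrows where OP_MV_AGG_PRUNE_COLUMNS lands. Instead of folding it into the rule bit of the compensation plan's root operator, the bit is now set only on the LogicalAggregationOperator instances found inside the plan, and only when isMVRewrite holds, so MVColumnPruner sees exactly the nodes it is allowed to prune. Below is a hedged sketch of an extractOperator-style collector over a plan tree; the class shapes and the pre-order walk are assumptions for illustration, not the actual StarRocks Utils implementation:

import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical stand-ins for StarRocks' Operator/OptExpression; the real
// classes carry far more state than a single rule bit.
class Op {
    int opRuleBit;
    void setOpRuleBit(int bit) { opRuleBit |= bit; }
}
class AggOp extends Op { }
class Expr {
    final Op op;
    final List<Expr> inputs = new ArrayList<>();
    Expr(Op op) { this.op = op; }
}

public class ExtractOperatorSketch {
    static final int OP_MV_AGG_PRUNE_COLUMNS = 1 << 1; // illustrative value

    // Assumed shape of Utils.extractOperator: pre-order walk collecting every
    // operator in the plan tree that matches the predicate.
    static void extractOperator(Expr root, List<Op> result, Predicate<Op> pred) {
        if (pred.test(root.op)) {
            result.add(root.op);
        }
        for (Expr input : root.inputs) {
            extractOperator(input, result, pred);
        }
    }

    public static void main(String[] args) {
        Expr scan = new Expr(new Op());
        Expr agg = new Expr(new AggOp());
        agg.inputs.add(scan);

        List<Op> aggs = new ArrayList<>();
        extractOperator(agg, aggs, op -> op instanceof AggOp);
        // Mark only the aggregation operators, mirroring
        // list.stream().forEach(op -> op.setOpRuleBit(OP_MV_AGG_PRUNE_COLUMNS)).
        aggs.forEach(op -> op.setOpRuleBit(OP_MV_AGG_PRUNE_COLUMNS));
        System.out.println("marked " + aggs.size() + " aggregation operator(s)");
    }
}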
