[GLUTEN-8432][CH] Remove duplicate output attributes of aggregate's child (#8450)

* Try to remove grouping keys in arguments of count(distinct)

* adjust child's output in CHHashAggregateExecTransformer

* update

* update
lgbo-ustc authored Jan 17, 2025
1 parent ac8e03a commit 17d017e
Showing 2 changed files with 45 additions and 9 deletions.
@@ -136,6 +136,14 @@ case class CHHashAggregateExecTransformer(
}
}

+// CH does not support duplicate columns in a block, so there should not be duplicate
+// attributes in the child's output.
+// There is an exception: when a shuffle result is reused, the child's output may
+// contain duplicate columns, which mismatches the real output of CH.
+protected lazy val childOutput: Seq[Attribute] = {
+child.output
+}
+
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)
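
The new childOutput hook gives the operator a single override point for the attributes it hands to CH. A minimal sketch of how a dedup could plug in there, keeping the first attribute per exprId — this helper is hypothetical and not part of the commit:

import org.apache.spark.sql.catalyst.expressions.Attribute

// Keep only the first attribute per exprId. CH rejects blocks that carry the
// same column twice, so a deduped childOutput is what would be handed to CH.
def dedupByExprId(attrs: Seq[Attribute]): Seq[Attribute] = {
  val seen = scala.collection.mutable.HashSet.empty[Long]
  attrs.filter(attr => seen.add(attr.exprId.id))
}
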
@@ -168,12 +176,12 @@
if (modes.isEmpty || modes.forall(_ == Complete)) {
// When there is no aggregate function, or all modes are Complete, outputs do not
// need to be adjusted according to the AggregateMode
-for (attr <- child.output) {
+for (attr <- childOutput) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}
-(child.output, output)
+(childOutput, output)
} else if (!modes.contains(Partial)) {
// non-partial mode
var resultAttrIndex = 0
@@ -193,13 +201,13 @@
(aggregateResultAttributes, output)
} else {
// partial mode
-for (attr <- child.output) {
+for (attr <- childOutput) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}

-(child.output, aggregateResultAttributes)
+(childOutput, aggregateResultAttributes)
}
}

@@ -238,7 +246,7 @@
// Use 'child.output' as the base Seq[Attribute]; the originalInputAttributes
// may differ across backends.
val exprNode = ExpressionConverter
-.replaceWithExpressionTransformer(expr, child.output)
+.replaceWithExpressionTransformer(expr, childOutput)
.doTransform(args)
groupingList.add(exprNode)
})
@@ -258,7 +266,7 @@
aggExpr => {
if (aggExpr.filter.isDefined) {
val exprNode = ExpressionConverter
-.replaceWithExpressionTransformer(aggExpr.filter.get, child.output)
+.replaceWithExpressionTransformer(aggExpr.filter.get, childOutput)
.doTransform(args)
aggFilterList.add(exprNode)
} else {
@@ -272,7 +280,7 @@
aggregateFunc.children.toList.map(
expr => {
ExpressionConverter
-.replaceWithExpressionTransformer(expr, child.output)
+.replaceWithExpressionTransformer(expr, childOutput)
.doTransform(args)
})
case PartialMerge if distinct_modes.contains(Partial) =>
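
Every conversion site above now binds expressions against the same childOutput sequence. Consistency matters because expressions are resolved into column ordinals within the attribute list they are given; two call sites binding against different lists would read from different slots at execution time. The same idea, illustrated with Catalyst's stock BindReferences rather than Gluten's ExpressionConverter:

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, BindReferences, Expression}

// Rewrite attribute references in `expr` into ordinals within `childOutput`.
// If grouping, filter, and argument expressions bound against different
// attribute lists, their ordinals would disagree when columns are read.
def bindToChildOutput(expr: Expression, childOutput: Seq[Attribute]): Expression =
  BindReferences.bindReference(expr, AttributeSeq(childOutput))
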
@@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.CHColumnarToRowExec
+import org.apache.spark.sql.execution.adaptive._
+import org.apache.spark.sql.execution.exchange._

/*
* CH doesn't support duplicate columns in a block.
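
The commit message above mentions removing grouping keys from the arguments of count(distinct). A query shape like the following — an assumed repro, not one shipped with the PR — can expose the same attribute twice in the aggregate child's output, which is exactly what CH blocks reject:

// Hypothetical repro: the grouping key `key` also appears inside COUNT(DISTINCT),
// so the partial aggregate's child may carry the `key` column twice.
val df = spark.sql("SELECT key, COUNT(DISTINCT key) AS cnt FROM t GROUP BY key")
df.collect()
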
@@ -51,8 +53,34 @@ case class RemoveDuplicatedColumns(session: SparkSession) extends Rule[SparkPlan]
}
case hashAgg: CHHashAggregateExecTransformer =>
val newChildren = hashAgg.children.map(visitPlan)
-val newHashAgg = uniqueHashAggregateColumns(hashAgg)
-newHashAgg.withNewChildren(newChildren)
+var newHashAgg = uniqueHashAggregateColumns(hashAgg)
+newHashAgg =
+newHashAgg.withNewChildren(newChildren).asInstanceOf[CHHashAggregateExecTransformer]
+newHashAgg.child match {
+case aqeShuffleRead @ AQEShuffleReadExec(
+child @ ShuffleQueryStageExec(
+id,
+reusedShuffle @ ReusedExchangeExec(output, shuffle: ColumnarShuffleExchangeExec),
+canonicalized),
+partitionSpecs) =>
+if (output.length != shuffle.output.length) {
+// A reused exchange may retain duplicate columns in its output even after its
+// child has removed them. By design, a reused exchange's output can differ from
+// its child's, so we cannot directly use the child's output as the output of the
+// reused exchange.
+// TODO: we cannot build a UT for this case.
+val uniqueOutput = uniqueExpressions(output.map(_.asInstanceOf[NamedExpression]))
+.map(_.asInstanceOf[Attribute])
+val newReusedShuffle = ReusedExchangeExec(uniqueOutput, shuffle)
+val newChild = AQEShuffleReadExec(
+ShuffleQueryStageExec(id, newReusedShuffle, canonicalized),
+partitionSpecs)
+newHashAgg.copy(child = newChild)
+} else {
+newHashAgg
+}
+case _ => newHashAgg
+}
case _ =>
plan.withNewChildren(plan.children.map(visitPlan))
}
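
The match above fires only when the reused exchange's output length differs from that of the underlying shuffle, i.e. when stale duplicates survived the dedup applied to the original exchange. A self-contained sketch of just the detection, simplified from the rule and assuming the stock Spark AQE node shapes used in the diff:

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec

// True when an AQE shuffle read sits on a reused exchange whose remapped output
// lists a different number of attributes than the exchange it reuses produces.
def hasStaleReusedOutput(plan: SparkPlan): Boolean = plan match {
  case AQEShuffleReadExec(ShuffleQueryStageExec(_, ReusedExchangeExec(output, exchange), _), _) =>
    output.length != exchange.output.length
  case _ => false
}
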
