[GLUTEN-8432][CH] Remove duplicate output attributes of aggregate's child #8450

Merged · 4 commits · Jan 17, 2025
@@ -136,6 +136,14 @@ case class CHHashAggregateExecTransformer(
}
}

// CH does not support duplicate columns in a block, so there should not be duplicate
// attributes in the child's output.
// There is one exception: when a shuffle result is reused, the child's output may contain
// duplicate columns, which mismatches the real output of CH.
protected lazy val childOutput: Seq[Attribute] = {
child.output
}

override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)
@@ -168,12 +176,12 @@
if (modes.isEmpty || modes.forall(_ == Complete)) {
// When there is no aggregate function or there is complete mode, it does not need
// to handle outputs according to the AggregateMode
for (attr <- child.output) {
for (attr <- childOutput) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}
(child.output, output)
(childOutput, output)
} else if (!modes.contains(Partial)) {
// non-partial mode
var resultAttrIndex = 0
@@ -193,13 +201,13 @@
(aggregateResultAttributes, output)
} else {
// partial mode
for (attr <- child.output) {
for (attr <- childOutput) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}

(child.output, aggregateResultAttributes)
(childOutput, aggregateResultAttributes)
}
}

@@ -238,7 +246,7 @@
// Use 'childOutput' as the base Seq[Attribute]; the originalInputAttributes
// may be different for each backend.
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(expr, child.output)
.replaceWithExpressionTransformer(expr, childOutput)
.doTransform(args)
groupingList.add(exprNode)
})
@@ -258,7 +266,7 @@
aggExpr => {
if (aggExpr.filter.isDefined) {
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(aggExpr.filter.get, child.output)
.replaceWithExpressionTransformer(aggExpr.filter.get, childOutput)
.doTransform(args)
aggFilterList.add(exprNode)
} else {
@@ -272,7 +280,7 @@
aggregateFunc.children.toList.map(
expr => {
ExpressionConverter
.replaceWithExpressionTransformer(expr, child.output)
.replaceWithExpressionTransformer(expr, childOutput)
.doTransform(args)
})
case PartialMerge if distinct_modes.contains(Partial) =>
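For context on why the new `childOutput` indirection matters, here is a minimal, hypothetical sketch (not taken from this PR; table and column names are illustrative) of how a plan's output can carry two attributes sharing one ExprId, which CH cannot hold in a single block:

// Hypothetical illustration, not part of this PR: selecting the same column
// twice yields a Project whose output contains two attributes with the same
// ExprId, e.g. [a#0L, a#0L]. CH cannot place such duplicates in one block.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

val df = spark.range(10).toDF("a")
val dup = df.select($"a", $"a")          // output: [a#0L, a#0L]
dup.queryExecution.analyzed.output       // shows the duplicated attribute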
@@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.CHColumnarToRowExec
import org.apache.spark.sql.execution.adaptive._
import org.apache.spark.sql.execution.exchange._

/*
 * CH doesn't support duplicate columns in a block.
@@ -51,8 +53,34 @@ case class RemoveDuplicatedColumns(session: SparkSession) extends Rule[SparkPlan
}
case hashAgg: CHHashAggregateExecTransformer =>
val newChildren = hashAgg.children.map(visitPlan)
val newHashAgg = uniqueHashAggregateColumns(hashAgg)
newHashAgg.withNewChildren(newChildren)
var newHashAgg = uniqueHashAggregateColumns(hashAgg)
newHashAgg =
newHashAgg.withNewChildren(newChildren).asInstanceOf[CHHashAggregateExecTransformer]
newHashAgg.child match {
case aqeShuffleRead @ AQEShuffleReadExec(
child @ ShuffleQueryStageExec(
id,
reusedShuffle @ ReusedExchangeExec(output, shuffle: ColumnarShuffleExchangeExec),
canonicalized),
partitionSpecs) =>
if (output.length != shuffle.output.length) {
// A reused exchange may retain duplicate columns in its output even after its
// child has removed them. By design, a reused exchange's output can differ from
// its child's, so we cannot directly use the child's output as the output of
// the reused exchange.
// TODO: we cannot build a UT for this case.
val uniqueOutput = uniqueExpressions(output.map(_.asInstanceOf[NamedExpression]))
.map(_.asInstanceOf[Attribute])
val newReusedShuffle = ReusedExchangeExec(uniqueOutput, shuffle)
val newChild = AQEShuffleReadExec(
ShuffleQueryStageExec(id, newReusedShuffle, canonicalized),
partitionSpecs)
newHashAgg.copy(child = newChild)
} else {
newHashAgg
}
case _ => newHashAgg
}
case _ =>
plan.withNewChildren(plan.children.map(visitPlan))
}
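The rewrite above relies on the pre-existing `uniqueExpressions` helper. As a hedged, standalone sketch of the behavior assumed at this call site (keep the first expression seen per ExprId; the actual helper in Gluten may differ, e.g. by aliasing duplicates instead of dropping them):

// Hedged sketch, not the actual Gluten helper: one plausible semantics for
// `uniqueExpressions` here is to keep only the first expression seen for each
// ExprId, so the reused exchange's output no longer carries duplicate columns.
import org.apache.spark.sql.catalyst.expressions.{ExprId, NamedExpression}
import scala.collection.mutable

def uniqueByExprId(exprs: Seq[NamedExpression]): Seq[NamedExpression] = {
  val seen = mutable.HashSet.empty[ExprId]
  exprs.filter(e => seen.add(e.exprId)) // keep first occurrence per ExprId
}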