From 17d017ea454ac3740a79fef2d542e7a77166929d Mon Sep 17 00:00:00 2001 From: lgbo Date: Fri, 17 Jan 2025 08:54:53 +0800 Subject: [PATCH] [GLUTEN-8432][CH]Remove duplicate output attributes of aggregate's child (#8450) * Try to remove grouping keys in arguments of count(distinct) * adjust child's output in CHHashAggregateExecTransformer * update * update --- .../CHHashAggregateExecTransformer.scala | 22 +++++++++---- .../extension/RemoveDuplicatedColumns.scala | 32 +++++++++++++++++-- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala index 48b0d7336103..a3f97492927f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala @@ -136,6 +136,14 @@ case class CHHashAggregateExecTransformer( } } + // CH does not support duplicate columns in a block. So there should not be duplicate attributes + // in child's output. + // There is an exception case, when a shuffle result is reused, the child's output may contain + // duplicate columns. It's mismatched with the real output of CH. 
+ protected lazy val childOutput: Seq[Attribute] = { + child.output + } + override protected def doTransform(context: SubstraitContext): TransformContext = { val childCtx = child.asInstanceOf[TransformSupport].transform(context) val operatorId = context.nextOperatorId(this.nodeName) @@ -168,12 +176,12 @@ case class CHHashAggregateExecTransformer( if (modes.isEmpty || modes.forall(_ == Complete)) { // When there is no aggregate function or there is complete mode, it does not need // to handle outputs according to the AggregateMode - for (attr <- child.output) { + for (attr <- childOutput) { typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) nameList.add(ConverterUtils.genColumnNameWithExprId(attr)) nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType)) } - (child.output, output) + (childOutput, output) } else if (!modes.contains(Partial)) { // non-partial mode var resultAttrIndex = 0 @@ -193,13 +201,13 @@ case class CHHashAggregateExecTransformer( (aggregateResultAttributes, output) } else { // partial mode - for (attr <- child.output) { + for (attr <- childOutput) { typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) nameList.add(ConverterUtils.genColumnNameWithExprId(attr)) nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType)) } - (child.output, aggregateResultAttributes) + (childOutput, aggregateResultAttributes) } } @@ -238,7 +246,7 @@ case class CHHashAggregateExecTransformer( // Use 'child.output' as based Seq[Attribute], the originalInputAttributes // may be different for each backend. 
val exprNode = ExpressionConverter - .replaceWithExpressionTransformer(expr, child.output) + .replaceWithExpressionTransformer(expr, childOutput) .doTransform(args) groupingList.add(exprNode) }) @@ -258,7 +266,7 @@ case class CHHashAggregateExecTransformer( aggExpr => { if (aggExpr.filter.isDefined) { val exprNode = ExpressionConverter - .replaceWithExpressionTransformer(aggExpr.filter.get, child.output) + .replaceWithExpressionTransformer(aggExpr.filter.get, childOutput) .doTransform(args) aggFilterList.add(exprNode) } else { @@ -272,7 +280,7 @@ case class CHHashAggregateExecTransformer( aggregateFunc.children.toList.map( expr => { ExpressionConverter - .replaceWithExpressionTransformer(expr, child.output) + .replaceWithExpressionTransformer(expr, childOutput) .doTransform(args) }) case PartialMerge if distinct_modes.contains(Partial) => diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveDuplicatedColumns.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveDuplicatedColumns.scala index 7f378b5a41a0..b5ed6a861360 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveDuplicatedColumns.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveDuplicatedColumns.scala @@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.CHColumnarToRowExec +import org.apache.spark.sql.execution.adaptive._ +import org.apache.spark.sql.execution.exchange._ /* * CH doesn't support will for duplicate columns in the a block. 
@@ -51,8 +53,34 @@ case class RemoveDuplicatedColumns(session: SparkSession) extends Rule[SparkPlan } case hashAgg: CHHashAggregateExecTransformer => val newChildren = hashAgg.children.map(visitPlan) - val newHashAgg = uniqueHashAggregateColumns(hashAgg) - newHashAgg.withNewChildren(newChildren) + var newHashAgg = uniqueHashAggregateColumns(hashAgg) + newHashAgg = + newHashAgg.withNewChildren(newChildren).asInstanceOf[CHHashAggregateExecTransformer] + newHashAgg.child match { + case aqeShuffleRead @ AQEShuffleReadExec( + child @ ShuffleQueryStageExec( + id, + reusedShuffle @ ReusedExchangeExec(output, shuffle: ColumnarShuffleExchangeExec), + canonicalized), + partitionSpecs) => + if (output.length != shuffle.output.length) { + // reused exchange may retain duplicate columns in the output, even though its child has + // removed the duplicate columns. By design, reused exchange's output could be + // different from its child, so we cannot use the child's output as the output of the + // reused exchange directly. + // TODO: we cannot build a UT for this case. + val uniqueOutput = uniqueExpressions(output.map(_.asInstanceOf[NamedExpression])) + .map(_.asInstanceOf[Attribute]) + val newReusedShuffle = ReusedExchangeExec(uniqueOutput, shuffle) + val newChild = AQEShuffleReadExec( + ShuffleQueryStageExec(id, newReusedShuffle, canonicalized), + partitionSpecs) + newHashAgg.copy(child = newChild) + } else { + newHashAgg + } + case _ => newHashAgg + } case _ => plan.withNewChildren(plan.children.map(visitPlan)) }