From 35fe839e27140a6f19b771e857fe6a45d3a88f72 Mon Sep 17 00:00:00 2001
From: lgbo-ustc
Date: Tue, 7 Jan 2025 11:50:39 +0800
Subject: [PATCH 1/4] Try to remove grouping keys in arguments of count(distinct)

---
 .../backendsapi/clickhouse/CHRuleApi.scala    |  1 +
 .../RemoveUselessAttributesInDstinct.scala    | 88 +++++++++++++++++++
 .../execution/GlutenClickHouseTPCHSuite.scala | 35 ++++++++
 .../GlutenClickHouseTPCHParquetAQESuite.scala | 35 ++++++++
 4 files changed, 159 insertions(+)
 create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveUselessAttributesInDstinct.scala
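
Note (editorial, illustration only): the rewrite this rule performs, sketched in
SQL. A grouping key that reappears inside count(distinct) is dropped, since it is
constant within each group; for non-null keys the distinct count is unchanged,
and the tests below exercise the null-key case:

    -- before
    select n_regionkey, count(distinct n_name, n_nationkey) as x
    from nation group by n_regionkey, n_nationkey;

    -- after
    select n_regionkey, count(distinct n_name) as x
    from nation group by n_regionkey, n_nationkey;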
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index 21ae342a2263..36b797fb9f45 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -63,6 +63,7 @@ object CHRuleApi {
     injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark))
     injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark))
     injector.injectResolutionRule(spark => new CollapseGetJsonObjectExpressionRule(spark))
+    injector.injectResolutionRule(spark => new RemoveUselessAttributesInDstinct(spark))
     injector.injectOptimizerRule(spark => new CommonSubexpressionEliminateRule(spark))
     injector.injectOptimizerRule(spark => new ExtendedColumnPruning(spark))
     injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveUselessAttributesInDstinct.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveUselessAttributesInDstinct.scala
new file mode 100644
index 000000000000..4b624a82e417
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveUselessAttributesInDstinct.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+
+// Simplifies `count(distinct a, b, ...) group by x, b, ...` into
+// `count(distinct a, ...) group by x, b, ...`: `b` inside the aggregate function
+// `count(distinct)` is redundant.
+// We do this because:
+// 1. There is no need to include grouping keys in the arguments of count(distinct).
+// 2. They introduce duplicate columns in CH, and CH doesn't support duplicate columns in a block.
+// 3. When `reusedExchange` is enabled, they cause a schema mismatch.
+class RemoveUselessAttributesInDstinct(spark: SparkSession) extends Rule[LogicalPlan] with Logging {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (plan.resolved) {
+      visitPlan(plan)
+    } else {
+      visitPlan(plan)
+    }
+  }
+
+  def visitPlan(plan: LogicalPlan): LogicalPlan = {
+    plan match {
+      case Aggregate(groupingExpressions, aggregateExpressions, child) =>
+        val newAggregateExpressions =
+          aggregateExpressions.map(
+            visitExpression(groupingExpressions, _).asInstanceOf[NamedExpression])
+        Aggregate(groupingExpressions, newAggregateExpressions, visitPlan(child))
+      case other =>
+        other.withNewChildren(other.children.map(visitPlan))
+    }
+  }
+
+  def visitExpression(groupingExpressions: Seq[Expression], expr: Expression): Expression = {
+    expr match {
+      case agg: AggregateExpression =>
+        if (agg.isDistinct && agg.aggregateFunction.isInstanceOf[Count]) {
+          val newChildren = agg.aggregateFunction.children.filterNot {
+            child => groupingExpressions.contains(child)
+          }
+          // Cannot remove all children in count(distinct).
+          if (newChildren.isEmpty) {
+            agg
+          } else {
+            val newCount = Count(newChildren)
+            agg.copy(aggregateFunction = newCount)
+          }
+        } else {
+          agg
+        }
+      case fun: UnresolvedFunction =>
+        if (fun.nameParts.mkString(".") == "count" && fun.isDistinct) {
+          val newArguments = fun.arguments.filterNot(arg => groupingExpressions.contains(arg))
+          if (newArguments.isEmpty) {
+            fun
+          } else {
+            fun.copy(arguments = newArguments)
+          }
+        } else {
+          fun
+        }
+      case other =>
+        other.withNewChildren(other.children.map(visitExpression(groupingExpressions, _)))
+    }
+  }
+}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
index 65a01dea3073..9d1ea8f2d0ad 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
@@ -590,5 +590,40 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
     )
     sql("drop table test_8142")
   }
+
+  test("GLUTEN-8432 count(distinct) contains grouping keys") {
+    compareResultsAgainstVanillaSpark(
+      s"""
+         |select n_regionkey, n_nationkey, count(distinct n_name, n_nationkey, n_comment) as x
+         |from (
+         |  select
+         |    n_regionkey,
+         |    n_nationkey,
+         |    if(n_nationkey = 0, null, n_name) as n_name,
+         |    if(n_nationkey = 0, null, n_comment) as n_comment
+         |  from nation
+         |)
+         |group by n_regionkey, n_nationkey
+         |order by n_regionkey, n_nationkey
+         |""".stripMargin,
+      true,
+      { df => }
+    )
+    compareResultsAgainstVanillaSpark(
+      s"""
+         |select n_regionkey, n_nationkey, count(distinct n_nationkey) as x
+         |from (
+         |  select
+         |    n_regionkey,
+         |    if (n_nationkey = 0, null, n_nationkey) as n_nationkey
+         |  from nation
+         |)
+         |group by n_regionkey, n_nationkey
+         |order by n_regionkey, n_nationkey
+         |""".stripMargin,
+      true,
+      { df => }
+    )
+  }
 }
 // scalastyle:off line.size.limit
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetAQESuite.scala
index 1c627140b694..98144e41eba7 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetAQESuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetAQESuite.scala
@@ -407,5 +407,40 @@ class GlutenClickHouseTPCHParquetAQESuite
       assert(result.length == 1)
     }
   }
+
+  test("GLUTEN-8432 count(distinct) contains grouping keys") {
+    compareResultsAgainstVanillaSpark(
+      s"""
+         |select n_regionkey, n_nationkey, count(distinct n_name, n_nationkey, n_comment) as x
+         |from (
+         |  select
+         |    n_regionkey,
+         |    n_nationkey,
+         |    if(n_nationkey = 0, null, n_name) as n_name,
+         |    if(n_nationkey = 0, null, n_comment) as n_comment
+         |  from nation
+         |)
+         |group by n_regionkey, n_nationkey
+         |order by n_regionkey, n_nationkey
+         |""".stripMargin,
+      true,
+      { df => }
+    )
+    compareResultsAgainstVanillaSpark(
+      s"""
+         |select n_regionkey, n_nationkey, count(distinct n_nationkey) as x
+         |from (
+         |  select
+         |    n_regionkey,
+         |    if (n_nationkey = 0, null, n_nationkey) as n_nationkey
+         |  from nation
+         |)
+         |group by n_regionkey, n_nationkey
+         |order by n_regionkey, n_nationkey
+         |""".stripMargin,
+      true,
+      { df => }
+    )
+  }
 }
 // scalastyle:off line.size.limit

From 07808ba2e78e1d5c9efbea68cfd458a45c9cf727 Mon Sep 17 00:00:00 2001
From: lgbo-ustc
Date: Tue, 7 Jan 2025 17:01:42 +0800
Subject: [PATCH 2/4] adjust child's output in CHHashAggregateExecTransformer

---
 .../backendsapi/clickhouse/CHRuleApi.scala    |  1 -
 .../CHHashAggregateExecTransformer.scala      | 26 ++++--
 .../RemoveUselessAttributesInDstinct.scala    | 88 -------------------
 .../execution/GlutenClickHouseTPCHSuite.scala | 35 --------
 .../GlutenClickHouseTPCHParquetAQESuite.scala | 35 --------
 5 files changed, 19 insertions(+), 166 deletions(-)
 delete mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveUselessAttributesInDstinct.scala
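
Note (editorial): instead of rewriting the logical plan, this patch deduplicates
the aggregate child's output on the Gluten side. A minimal Scala sketch of the
idea (hypothetical REPL snippet, not part of the diff): two occurrences of the
same Catalyst attribute are equal, so `distinct` collapses them.

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.LongType

    val a = AttributeReference("n_nationkey", LongType)()
    val b = AttributeReference("n_regionkey", LongType)()
    // A reused exchange can surface `a` twice in an output schema.
    val output = Seq(a, a, b)
    assert(output.distinct == Seq(a, b))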
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index 36b797fb9f45..21ae342a2263 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -63,7 +63,6 @@ object CHRuleApi {
     injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark))
     injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark))
     injector.injectResolutionRule(spark => new CollapseGetJsonObjectExpressionRule(spark))
-    injector.injectResolutionRule(spark => new RemoveUselessAttributesInDstinct(spark))
     injector.injectOptimizerRule(spark => new CommonSubexpressionEliminateRule(spark))
     injector.injectOptimizerRule(spark => new ExtendedColumnPruning(spark))
     injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index 48b0d7336103..eada5ae4381d 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -136,6 +136,18 @@ case class CHHashAggregateExecTransformer(
     }
   }
 
+  // CH does not support duplicate columns in a block, so there should be no duplicate attributes
+  // in the child's output.
+  // There is an exception: when a shuffle result is reused, the child's output may contain
+  // duplicate columns, which mismatches the real output of CH.
+  protected lazy val childOutput: Seq[Attribute] = {
+    val distinctChildOutput = child.output.distinct
+    if (distinctChildOutput.length != child.output.length) {
+      logWarning(s"Found duplicate columns in child's output: ${child.output}")
+    }
+    distinctChildOutput
+  }
+
   override protected def doTransform(context: SubstraitContext): TransformContext = {
     val childCtx = child.asInstanceOf[TransformSupport].transform(context)
     val operatorId = context.nextOperatorId(this.nodeName)
@@ -168,12 +180,12 @@ case class CHHashAggregateExecTransformer(
       if (modes.isEmpty || modes.forall(_ == Complete)) {
         // When there is no aggregate function or there is complete mode, it does not need
         // to handle outputs according to the AggregateMode
-        for (attr <- child.output) {
+        for (attr <- childOutput) {
           typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
           nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
           nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
         }
-        (child.output, output)
+        (childOutput, output)
       } else if (!modes.contains(Partial)) {
         // non-partial mode
         var resultAttrIndex = 0
@@ -193,13 +205,13 @@ case class CHHashAggregateExecTransformer(
         (aggregateResultAttributes, output)
       } else {
         // partial mode
-        for (attr <- child.output) {
+        for (attr <- childOutput) {
           typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
           nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
           nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
         }
 
-        (child.output, aggregateResultAttributes)
+        (childOutput, aggregateResultAttributes)
       }
     }
 
@@ -238,7 +250,7 @@ case class CHHashAggregateExecTransformer(
           // Use 'child.output' as based Seq[Attribute], the originalInputAttributes
           // may be different for each backend.
           val exprNode = ExpressionConverter
-            .replaceWithExpressionTransformer(expr, child.output)
+            .replaceWithExpressionTransformer(expr, childOutput)
             .doTransform(args)
           groupingList.add(exprNode)
       })
@@ -258,7 +270,7 @@ case class CHHashAggregateExecTransformer(
       aggExpr => {
         if (aggExpr.filter.isDefined) {
           val exprNode = ExpressionConverter
-            .replaceWithExpressionTransformer(aggExpr.filter.get, child.output)
+            .replaceWithExpressionTransformer(aggExpr.filter.get, childOutput)
             .doTransform(args)
           aggFilterList.add(exprNode)
         } else {
@@ -272,7 +284,7 @@ case class CHHashAggregateExecTransformer(
             aggregateFunc.children.toList.map(
               expr => {
                 ExpressionConverter
-                  .replaceWithExpressionTransformer(expr, child.output)
+                  .replaceWithExpressionTransformer(expr, childOutput)
                   .doTransform(args)
               })
           case PartialMerge if distinct_modes.contains(Partial) =>
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveUselessAttributesInDstinct.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveUselessAttributesInDstinct.scala
deleted file mode 100644
index 4b624a82e417..000000000000
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveUselessAttributesInDstinct.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.extension
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
-
-// Simplifies `count(distinct a, b, ...) group by x, b, ...` into
-// `count(distinct a, ...) group by x, b, ...`: `b` inside the aggregate function
-// `count(distinct)` is redundant.
-// We do this because:
-// 1. There is no need to include grouping keys in the arguments of count(distinct).
-// 2. They introduce duplicate columns in CH, and CH doesn't support duplicate columns in a block.
-// 3. When `reusedExchange` is enabled, they cause a schema mismatch.
-class RemoveUselessAttributesInDstinct(spark: SparkSession) extends Rule[LogicalPlan] with Logging {
-
-  override def apply(plan: LogicalPlan): LogicalPlan = {
-    if (plan.resolved) {
-      visitPlan(plan)
-    } else {
-      visitPlan(plan)
-    }
-  }
-
-  def visitPlan(plan: LogicalPlan): LogicalPlan = {
-    plan match {
-      case Aggregate(groupingExpressions, aggregateExpressions, child) =>
-        val newAggregateExpressions =
-          aggregateExpressions.map(
-            visitExpression(groupingExpressions, _).asInstanceOf[NamedExpression])
-        Aggregate(groupingExpressions, newAggregateExpressions, visitPlan(child))
-      case other =>
-        other.withNewChildren(other.children.map(visitPlan))
-    }
-  }
-
-  def visitExpression(groupingExpressions: Seq[Expression], expr: Expression): Expression = {
-    expr match {
-      case agg: AggregateExpression =>
-        if (agg.isDistinct && agg.aggregateFunction.isInstanceOf[Count]) {
-          val newChildren = agg.aggregateFunction.children.filterNot {
-            child => groupingExpressions.contains(child)
-          }
-          // Cannot remove all children in count(distinct).
-          if (newChildren.isEmpty) {
-            agg
-          } else {
-            val newCount = Count(newChildren)
-            agg.copy(aggregateFunction = newCount)
-          }
-        } else {
-          agg
-        }
-      case fun: UnresolvedFunction =>
-        if (fun.nameParts.mkString(".") == "count" && fun.isDistinct) {
-          val newArguments = fun.arguments.filterNot(arg => groupingExpressions.contains(arg))
-          if (newArguments.isEmpty) {
-            fun
-          } else {
-            fun.copy(arguments = newArguments)
-          }
-        } else {
-          fun
-        }
-      case other =>
-        other.withNewChildren(other.children.map(visitExpression(groupingExpressions, _)))
-    }
-  }
-}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
index 9d1ea8f2d0ad..65a01dea3073 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
@@ -590,40 +590,5 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
     )
     sql("drop table test_8142")
   }
-
-  test("GLUTEN-8432 count(distinct) contains grouping keys") {
-    compareResultsAgainstVanillaSpark(
-      s"""
-         |select n_regionkey, n_nationkey, count(distinct n_name, n_nationkey, n_comment) as x
-         |from (
-         |  select
-         |    n_regionkey,
-         |    n_nationkey,
-         |    if(n_nationkey = 0, null, n_name) as n_name,
-         |    if(n_nationkey = 0, null, n_comment) as n_comment
-         |  from nation
-         |)
-         |group by n_regionkey, n_nationkey
-         |order by n_regionkey, n_nationkey
-         |""".stripMargin,
-      true,
-      { df => }
-    )
-    compareResultsAgainstVanillaSpark(
-      s"""
-         |select n_regionkey, n_nationkey, count(distinct n_nationkey) as x
-         |from (
-         |  select
-         |    n_regionkey,
-         |    if (n_nationkey = 0, null, n_nationkey) as n_nationkey
-         |  from nation
-         |)
-         |group by n_regionkey, n_nationkey
-         |order by n_regionkey, n_nationkey
-         |""".stripMargin,
-      true,
-      { df => }
-    )
-  }
 }
 // scalastyle:off line.size.limit
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetAQESuite.scala
index 98144e41eba7..1c627140b694 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetAQESuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetAQESuite.scala
@@ -407,40 +407,5 @@ class GlutenClickHouseTPCHParquetAQESuite
       assert(result.length == 1)
     }
   }
-
-  test("GLUTEN-8432 count(distinct) contains grouping keys") {
-    compareResultsAgainstVanillaSpark(
-      s"""
-         |select n_regionkey, n_nationkey, count(distinct n_name, n_nationkey, n_comment) as x
-         |from (
-         |  select
-         |    n_regionkey,
-         |    n_nationkey,
-         |    if(n_nationkey = 0, null, n_name) as n_name,
-         |    if(n_nationkey = 0, null, n_comment) as n_comment
-         |  from nation
-         |)
-         |group by n_regionkey, n_nationkey
-         |order by n_regionkey, n_nationkey
-         |""".stripMargin,
-      true,
-      { df => }
-    )
-    compareResultsAgainstVanillaSpark(
-      s"""
-         |select n_regionkey, n_nationkey, count(distinct n_nationkey) as x
-         |from (
-         |  select
-         |    n_regionkey,
-         |    if (n_nationkey = 0, null, n_nationkey) as n_nationkey
-         |  from nation
-         |)
-         |group by n_regionkey, n_nationkey
-         |order by n_regionkey, n_nationkey
-         |""".stripMargin,
-      true,
-      { df => }
-    )
-  }
 }
 // scalastyle:off line.size.limit

From d83f15a49acc9ee78e0c1a2df6a8c830a489e230 Mon Sep 17 00:00:00 2001
From: lgbo-ustc
Date: Wed, 15 Jan 2025 10:12:04 +0800
Subject: [PATCH 3/4] update

---
 .../gluten/execution/CHHashAggregateExecTransformer.scala | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)
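
Note (editorial): this revision keeps the duplicate-column detection but only
logs it, returning `child.output` unchanged; the actual handling of duplicated
columns moves into RemoveDuplicatedColumns in PATCH 4/4.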
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index eada5ae4381d..40381a60567a 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -143,9 +143,10 @@ case class CHHashAggregateExecTransformer(
   protected lazy val childOutput: Seq[Attribute] = {
     val distinctChildOutput = child.output.distinct
     if (distinctChildOutput.length != child.output.length) {
-      logWarning(s"Found duplicate columns in child's output: ${child.output}")
+      logWarning(s"Found duplicate columns in child's output: ${child.output}\n$child")
     }
-    distinctChildOutput
+    // distinctChildOutput
+    child.output
   }
 
   override protected def doTransform(context: SubstraitContext): TransformContext = {

From 9b4c0af9152a0e749a212aa7e12bb34a1a793171 Mon Sep 17 00:00:00 2001
From: lgbo-ustc
Date: Thu, 16 Jan 2025 16:22:38 +0800
Subject: [PATCH 4/4] update

---
 .../CHHashAggregateExecTransformer.scala      |  5 ---
 .../extension/RemoveDuplicatedColumns.scala   | 32 +++++++++++++++++--
 2 files changed, 30 insertions(+), 7 deletions(-)
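
Note (editorial): the new match arm below targets a physical plan fragment of
roughly this shape (schematic; attribute ids are illustrative):

    CHHashAggregateExecTransformer
    +- AQEShuffleReadExec
       +- ShuffleQueryStageExec
          +- ReusedExchangeExec [n_nationkey#5, n_nationkey#5]  <- duplicated output
             +- ColumnarShuffleExchangeExec [n_nationkey#5]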
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index 40381a60567a..a3f97492927f 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -141,11 +141,6 @@ case class CHHashAggregateExecTransformer(
   // There is an exception: when a shuffle result is reused, the child's output may contain
   // duplicate columns, which mismatches the real output of CH.
   protected lazy val childOutput: Seq[Attribute] = {
-    val distinctChildOutput = child.output.distinct
-    if (distinctChildOutput.length != child.output.length) {
-      logWarning(s"Found duplicate columns in child's output: ${child.output}\n$child")
-    }
-    // distinctChildOutput
     child.output
   }
 
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveDuplicatedColumns.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveDuplicatedColumns.scala
index 7f378b5a41a0..b5ed6a861360 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveDuplicatedColumns.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveDuplicatedColumns.scala
@@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.CHColumnarToRowExec
+import org.apache.spark.sql.execution.adaptive._
+import org.apache.spark.sql.execution.exchange._
 
 /*
  * CH doesn't support duplicate columns in a block.
@@ -51,8 +53,34 @@ case class RemoveDuplicatedColumns(session: SparkSession) extends Rule[SparkPlan]
         }
       case hashAgg: CHHashAggregateExecTransformer =>
         val newChildren = hashAgg.children.map(visitPlan)
-        val newHashAgg = uniqueHashAggregateColumns(hashAgg)
-        newHashAgg.withNewChildren(newChildren)
+        var newHashAgg = uniqueHashAggregateColumns(hashAgg)
+        newHashAgg =
+          newHashAgg.withNewChildren(newChildren).asInstanceOf[CHHashAggregateExecTransformer]
+        newHashAgg.child match {
+          case aqeShuffleRead @ AQEShuffleReadExec(
+                child @ ShuffleQueryStageExec(
+                  id,
+                  reusedShuffle @ ReusedExchangeExec(output, shuffle: ColumnarShuffleExchangeExec),
+                  canonicalized),
+                partitionSpecs) =>
+            if (output.length != shuffle.output.length) {
+              // A reused exchange may keep duplicate columns in its output even after its child
+              // has removed them. By design, a reused exchange's output can differ from its
+              // child's, so we cannot directly use the child's output as the output of the
+              // reused exchange.
+              // TODO: we cannot build a UT for this case.
+              val uniqueOutput = uniqueExpressions(output.map(_.asInstanceOf[NamedExpression]))
+                .map(_.asInstanceOf[Attribute])
+              val newReusedShuffle = ReusedExchangeExec(uniqueOutput, shuffle)
+              val newChild = AQEShuffleReadExec(
+                ShuffleQueryStageExec(id, newReusedShuffle, canonicalized),
+                partitionSpecs)
+              newHashAgg.copy(child = newChild)
+            } else {
+              newHashAgg
+            }
+          case _ => newHashAgg
+        }
       case _ =>
         plan.withNewChildren(plan.children.map(visitPlan))
     }
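
Note (editorial): `uniqueExpressions` is used above but not shown in this series.
A minimal sketch of what such a helper could look like (an assumption for
illustration, not the actual Gluten implementation):

    import org.apache.spark.sql.catalyst.expressions.{ExprId, NamedExpression}
    import scala.collection.mutable

    // Keep the first occurrence of each ExprId, preserving the original order.
    def uniqueExpressions(exprs: Seq[NamedExpression]): Seq[NamedExpression] = {
      val seen = mutable.HashSet.empty[ExprId]
      exprs.filter(e => seen.add(e.exprId))
    }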