Skip to content

Commit

Permalink
adjust child's output in CHHashAggregateExecTransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Jan 7, 2025
1 parent 85a73c4 commit 6577104
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ object CHRuleApi {
injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark))
injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark))
injector.injectResolutionRule(spark => new CollapseGetJsonObjectExpressionRule(spark))
injector.injectResolutionRule(spark => new RemoveUselessAttributesInDstinct(spark))
injector.injectOptimizerRule(spark => new CommonSubexpressionEliminateRule(spark))
injector.injectOptimizerRule(spark => new ExtendedColumnPruning(spark))
injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ case class CHHashAggregateExecTransformer(
}
}

// CH does not support duplicate columns in a block, so the child's output must not
// contain duplicate attributes.
// There is one exception: when a shuffle result is reused, the child's output may
// carry duplicate columns, which mismatches the real output produced by CH.
protected lazy val childOutput: Seq[Attribute] = {
  val deduplicated = child.output.distinct
  if (deduplicated.size != child.output.size) {
    logWarning(s"Found duplicate columns in child's output: ${child.output}")
  }
  deduplicated
}

override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)
Expand Down Expand Up @@ -168,12 +180,12 @@ case class CHHashAggregateExecTransformer(
if (modes.isEmpty || modes.forall(_ == Complete)) {
// When there is no aggregate function or there is complete mode, it does not need
// to handle outputs according to the AggregateMode
for (attr <- child.output) {
for (attr <- childOutput) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}
(child.output, output)
(childOutput, output)
} else if (!modes.contains(Partial)) {
// non-partial mode
var resultAttrIndex = 0
Expand All @@ -193,13 +205,13 @@ case class CHHashAggregateExecTransformer(
(aggregateResultAttributes, output)
} else {
// partial mode
for (attr <- child.output) {
for (attr <- childOutput) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}

(child.output, aggregateResultAttributes)
(childOutput, aggregateResultAttributes)
}
}

Expand Down Expand Up @@ -238,7 +250,7 @@ case class CHHashAggregateExecTransformer(
// Use 'childOutput' as the base Seq[Attribute]; the originalInputAttributes
// may be different for each backend.
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(expr, child.output)
.replaceWithExpressionTransformer(expr, childOutput)
.doTransform(args)
groupingList.add(exprNode)
})
Expand All @@ -258,7 +270,7 @@ case class CHHashAggregateExecTransformer(
aggExpr => {
if (aggExpr.filter.isDefined) {
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(aggExpr.filter.get, child.output)
.replaceWithExpressionTransformer(aggExpr.filter.get, childOutput)
.doTransform(args)
aggFilterList.add(exprNode)
} else {
Expand All @@ -272,7 +284,7 @@ case class CHHashAggregateExecTransformer(
aggregateFunc.children.toList.map(
expr => {
ExpressionConverter
.replaceWithExpressionTransformer(expr, child.output)
.replaceWithExpressionTransformer(expr, childOutput)
.doTransform(args)
})
case PartialMerge if distinct_modes.contains(Partial) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,40 +590,5 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
)
sql("drop table test_8142")
}

test("GLUTEN-8432 count(distinct) contains grouping keys") {
  // Grouping keys that also appear inside count(distinct ...) must yield the same
  // results as vanilla Spark; both queries below exercise that case.
  val queries = Seq(
    s"""
       |select n_regionkey, n_nationkey, count(distinct n_name, n_nationkey, n_comment) as x
       |from (
       |  select
       |    n_regionkey,
       |    n_nationkey,
       |    if(n_nationkey = 0, null, n_name) as n_name,
       |    if(n_nationkey = 0, null, n_comment) as n_comment
       |  from nation
       |)
       |group by n_regionkey, n_nationkey
       |order by n_regionkey, n_nationkey
       |""".stripMargin,
    s"""
       |select n_regionkey, n_nationkey, count(distinct n_nationkey) as x
       |from (
       |  select
       |    n_regionkey,
       |    if (n_nationkey = 0, null, n_nationkey) as n_nationkey
       |  from nation
       |)
       |group by n_regionkey, n_nationkey
       |order by n_regionkey, n_nationkey
       |""".stripMargin
  )
  queries.foreach(query => compareResultsAgainstVanillaSpark(query, true, { df => }))
}
}
// scalastyle:off line.size.limit
Original file line number Diff line number Diff line change
Expand Up @@ -407,40 +407,5 @@ class GlutenClickHouseTPCHParquetAQESuite
assert(result.length == 1)
}
}

test("GLUTEN-8432 count(distinct) contains grouping keys") {
  // Verifies that grouping keys which also appear inside count(distinct ...)
  // produce the same results as vanilla Spark (with AQE enabled in this suite).
  // First query: multi-column distinct that includes a grouping key.
  compareResultsAgainstVanillaSpark(
    s"""
       |select n_regionkey, n_nationkey, count(distinct n_name, n_nationkey, n_comment) as x
       |from (
       |  select
       |    n_regionkey,
       |    n_nationkey,
       |    if(n_nationkey = 0, null, n_name) as n_name,
       |    if(n_nationkey = 0, null, n_comment) as n_comment
       |  from nation
       |)
       |group by n_regionkey, n_nationkey
       |order by n_regionkey, n_nationkey
       |""".stripMargin,
    true,
    { df => }
  )
  // Second query: single-column distinct that is itself a grouping key.
  compareResultsAgainstVanillaSpark(
    s"""
       |select n_regionkey, n_nationkey, count(distinct n_nationkey) as x
       |from (
       |  select
       |    n_regionkey,
       |    if (n_nationkey = 0, null, n_nationkey) as n_nationkey
       |  from nation
       |)
       |group by n_regionkey, n_nationkey
       |order by n_regionkey, n_nationkey
       |""".stripMargin,
    true,
    { df => }
  )
}
}
// scalastyle:off line.size.limit

0 comments on commit 6577104

Please sign in to comment.