Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
zml1206 committed Aug 17, 2023
1 parent b116319 commit 3b55e58
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -644,8 +644,7 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAnyPattern(CASE_WHEN, IF), ruleId) {
case a @ Aggregate(groupingExpressions, _, _)
if !groupingExpressions.forall(_.isInstanceOf[NamedExpression]) => a
case a: Aggregate if !a.groupingExpressions.forall(_.isInstanceOf[NamedExpression]) => a
case q: LogicalPlan => q.transformExpressionsUpWithPruning(
_.containsAnyPattern(CASE_WHEN, IF), ruleId) {
case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1175,3 +1175,24 @@ Sort [a#x ASC NULLS FIRST], true
+- Project [a#x, b#x]
+- SubqueryAlias testData
+- LocalRelation [a#x, b#x]


-- !query
SELECT c * 2 AS d
FROM (
SELECT if(b > 1, 1, b) AS c
FROM (
SELECT if(a < 0, 0, a) AS b
FROM VALUES (-1), (1), (2) AS v(a)
GROUP BY b
) t1
GROUP BY c
) t2
-- !query analysis
Project [(c#x * 2) AS d#x]
+- SubqueryAlias t2
+- Aggregate [if ((b#x > 1)) 1 else b#x], [if ((b#x > 1)) 1 else b#x AS c#x]
+- SubqueryAlias t1
+- Aggregate [if ((a#x < 0)) 0 else a#x], [if ((a#x < 0)) 0 else a#x AS b#x]
+- SubqueryAlias v
+- LocalRelation [a#x]
13 changes: 13 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/group-by.sql
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,16 @@ GROUP BY a;

SELECT mode(a), mode(b) FROM testData;
SELECT a, mode(b) FROM testData GROUP BY a ORDER BY a;


-- SPARK-44846: PushFoldableIntoBranches in complex grouping expressions cause bindReference error
SELECT c * 2 AS d
FROM (
SELECT if(b > 1, 1, b) AS c
FROM (
SELECT if(a < 0, 0, a) AS b
FROM VALUES (-1), (1), (2) AS v(a)
GROUP BY b
) t1
GROUP BY c
) t2;
18 changes: 18 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/group-by.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -1103,3 +1103,21 @@ NULL 1
1 1
2 1
3 1


-- !query
SELECT c * 2 AS d
FROM (
SELECT if(b > 1, 1, b) AS c
FROM (
SELECT if(a < 0, 0, a) AS b
FROM VALUES (-1), (1), (2) AS v(a)
GROUP BY b
) t1
GROUP BY c
) t2
-- !query schema
struct<d:int>
-- !query output
0
2
15 changes: 0 additions & 15 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3674,21 +3674,6 @@ class DataFrameSuite extends QueryTest
parameters = Map("viewName" -> "AUTHORIZATION"))
}
}

test("SPARK-44846: PushFoldableIntoBranches in complex grouping expressions " +
"cause bindReference error") {
withTempView("t") {
Seq(-1, 1, 2).toDF("a").createOrReplaceTempView("t")
val _sql =
"""
|select c*2 as d from
|(select if(b > 1, 1, b) as c from
|(select if(a < 0, 0 ,a) as b from t group by b) t1
|group by c) t2
|""".stripMargin
checkAnswer(sql(_sql), Seq(Row(0), Row(2)))
}
}
}

case class GroupByKey(a: Int, b: Int)
Expand Down

0 comments on commit 3b55e58

Please sign in to comment.