diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala similarity index 55% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala index 45b8437bad05..5ba64360ffc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LogicalPlan, Project, Subquery, WithCTE} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CTERelationDef, CTERelationRef, LeafNode, LogicalPlan, OneRowRelation, Project, Subquery, WithCTE} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, NO_GROUPING_AGGREGATE_REFERENCE, SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType @@ -100,7 +100,7 @@ import org.apache.spark.sql.types.DataType * : +- ReusedSubquery Subquery scalar-subquery#242, [id=#125] * +- *(1) Scan OneRowRelation[] */ -object MergeScalarSubqueries extends Rule[LogicalPlan] { +object MergeSubplans extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { plan match { // Subquery reuse needs to be enabled for this optimization. @@ -117,26 +117,24 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { } private def extractCommonScalarSubqueries(plan: LogicalPlan) = { - // Collect `ScalarSubquery` plans by level into `PlanMerger`s and insert references in place of - // `ScalarSubquery`s. + // Collect subplans by level into `PlanMerger`s and insert references in place of them. val planMergers = ArrayBuffer.empty[PlanMerger] - val planWithReferences = insertReferences(plan, planMergers)._1 + val planWithReferences = insertReferences(plan, true, planMergers)._1 // Traverse level by level and convert merged plans to `CTERelationDef`s and keep non-merged // ones. While traversing replace references in plans back to `CTERelationRef`s or to original - // `ScalarSubquery`s. This is safe as a subquery plan at a level can reference only lower level - // other subqueries. - val subqueryPlansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]] + // plans. This is safe as a subplan at a level can reference only lower level ot other subplans. + val subplansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]] planMergers.foreach { planMerger => val mergedPlans = planMerger.mergedPlans() - subqueryPlansByLevel += mergedPlans.map { mergedPlan => - val planWithoutReferences = if (subqueryPlansByLevel.isEmpty) { + subplansByLevel += mergedPlans.map { mergedPlan => + val planWithoutReferences = if (subplansByLevel.isEmpty) { // Level 0 plans can't contain references mergedPlan.plan } else { - removeReferences(mergedPlan.plan, subqueryPlansByLevel) + removeReferences(mergedPlan.plan, subplansByLevel) } - if (mergedPlan.merged && mergedPlan.plan.output.size > 1) { + if (mergedPlan.merged) { CTERelationDef( Project( Seq(Alias( @@ -151,38 +149,42 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { } } - // Replace references back to `CTERelationRef`s or to original `ScalarSubquery`s in the main - // plan. - val newPlan = removeReferences(planWithReferences, subqueryPlansByLevel) + // Replace references back to `CTERelationRef`s or to original subplans. + val newPlan = removeReferences(planWithReferences, subplansByLevel) // Add `CTERelationDef`s to the plan. - val subqueryCTEs = subqueryPlansByLevel.flatMap(_.collect { case cte: CTERelationDef => cte }) - if (subqueryCTEs.nonEmpty) { - WithCTE(newPlan, subqueryCTEs.toSeq) + val subplanCTEs = subplansByLevel.flatMap(_.collect { case cte: CTERelationDef => cte }) + if (subplanCTEs.nonEmpty) { + WithCTE(newPlan, subplanCTEs.toSeq) } else { newPlan } } - // First traversal inserts `ScalarSubqueryReference`s to the plan and tries to merge subquery - // plans by each level. + // First traversal inserts `ScalarSubqueryReference`s and `NoGroupingAggregateReference`s to the + // plan and tries to merge subplans by each level. Levels are separated eiter by scalar subqueries + // or by non-grouping aggregate nodes. Nodes with the same level make sense to try merging. private def insertReferences( plan: LogicalPlan, + root: Boolean, planMergers: ArrayBuffer[PlanMerger]): (LogicalPlan, Int) = { - // The level of a subquery plan is maximum level of its inner subqueries + 1 or 0 if it has no - // inner subqueries. - var maxLevel = 0 - val planWithReferences = - plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) { + if (!plan.containsAnyPattern(AGGREGATE, SCALAR_SUBQUERY)) { + return (plan, 0) + } + + // Calculate the level propagated from subquery plans, which is the maximum level of the + // subqueries of the node + 1 or 0 if the node has no subqueries. + var levelFromSubqueries = 0 + val nodeSubqueriesWithReferences = + plan.transformExpressionsWithPruning(_.containsPattern(SCALAR_SUBQUERY)) { case s: ScalarSubquery if !s.isCorrelated && s.deterministic => - val (planWithReferences, level) = insertReferences(s.plan, planMergers) + val (planWithReferences, level) = insertReferences(s.plan, true, planMergers) - while (level >= planMergers.size) planMergers += new PlanMerger() // The subquery could contain a hint that is not propagated once we merge it, but as a // non-correlated scalar subquery won't be turned into a Join the loss of hints is fine. - val mergeResult = planMergers(level).merge(planWithReferences) + val mergeResult = getPlanMerger(planMergers, level).merge(planWithReferences, true) - maxLevel = maxLevel.max(level + 1) + levelFromSubqueries = levelFromSubqueries.max(level + 1) val mergedOutput = mergeResult.outputMap(planWithReferences.output.head) val outputIndex = @@ -195,26 +197,96 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { s.exprId) case o => o } - (planWithReferences, maxLevel) + + // Calculate the level of the node, which is the maximum of the above calculated level + // propagated from subqueries and the level propagated from child nodes. + val (planWithReferences, level) = nodeSubqueriesWithReferences match { + case a: Aggregate if !root && a.groupingExpressions.isEmpty => + val (childWithReferences, levelFromChild) = insertReferences(a.child, false, planMergers) + val aggregateWithReferences = a.withNewChildren(Seq(childWithReferences)) + + // Level is the maximum of the level from subqueries and the level from child. + val level = levelFromChild.max(levelFromSubqueries) + + val mergeResult = getPlanMerger(planMergers, level).merge(aggregateWithReferences, false) + + val mergedOutput = aggregateWithReferences.output.map(mergeResult.outputMap) + val outputIndices = + mergedOutput.map(a => mergeResult.mergedPlan.plan.output.indexWhere(_.exprId == a.exprId)) + val aggregateReference = NonGroupingAggregateReference( + level, + mergeResult.mergedPlanIndex, + outputIndices, + a.output + ) + + // This is a non-grouping aggregate node so propagate the level of the node + 1 to its + // parent + (aggregateReference, level + 1) + case o => + val (newChildren, levels) = o.children.map(insertReferences(_, false, planMergers)).unzip + // Level is the maximum of the level from subqueries and the level from the children. + (o.withNewChildren(newChildren), (levelFromSubqueries +: levels).max) + } + + (planWithReferences, level) + } + + private def getPlanMerger(planMergers: ArrayBuffer[PlanMerger], level: Int) = { + while (level >= planMergers.size) planMergers += new PlanMerger() + planMergers(level) } - // Second traversal replaces `ScalarSubqueryReference`s to either - // `GetStructField(ScalarSubquery(CTERelationRef to the merged plan)` if the plan is merged from - // multiple subqueries or `ScalarSubquery(original plan)` if it isn't. + // Second traversal replaces: + // - a `ScalarSubqueryReference` either to + // `GetStructField(ScalarSubquery(CTERelationRef to the merged plan), merged output index)` if + // the plan is merged from multiple subqueries or to `ScalarSubquery(original plan)` if it + // isn't. + // - a `NoGroupingAggregateReference` either to + // ``` + // Project( + // Seq( + // GetStructField( + // ScalarSubquery(CTERelationRef to the merged plan), + // merged output index 1), + // GetStructField( + // ScalarSubquery(CTERelationRef to the merged plan), + // merged output index 2), + // ...), + // OneRowRelation) + // ``` + // if the plan is merged from multiple subqueries or to `original plan` if it isn't. private def removeReferences( plan: LogicalPlan, - subqueryPlansByLevel: ArrayBuffer[IndexedSeq[LogicalPlan]]) = { - plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) { - case ssr: ScalarSubqueryReference => - subqueryPlansByLevel(ssr.level)(ssr.mergedPlanIndex) match { + subplansByLevel: ArrayBuffer[IndexedSeq[LogicalPlan]]) = { + plan.transformUpWithPruning( + _.containsAnyPattern(NO_GROUPING_AGGREGATE_REFERENCE, SCALAR_SUBQUERY_REFERENCE)) { + case ngar: NonGroupingAggregateReference => + subplansByLevel(ngar.level)(ngar.mergedPlanIndex) match { case cte: CTERelationDef => - GetStructField( - ScalarSubquery( - CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming), - exprId = ssr.exprId), - ssr.outputIndex) - case o => ScalarSubquery(o, exprId = ssr.exprId) + val projectList = ngar.outputIndices.zip(ngar.output).map { case (i, a) => + Alias( + GetStructField( + ScalarSubquery( + CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming)), + i), + a.name)(a.exprId) + } + Project(projectList, OneRowRelation()) + case o => o } + case o => o.transformExpressionsUpWithPruning(_.containsPattern(SCALAR_SUBQUERY_REFERENCE)) { + case ssr: ScalarSubqueryReference => + subplansByLevel(ssr.level)(ssr.mergedPlanIndex) match { + case cte: CTERelationDef => + GetStructField( + ScalarSubquery( + CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming), + exprId = ssr.exprId), + ssr.outputIndex) + case o => ScalarSubquery(o, exprId = ssr.exprId) + } + } } } } @@ -233,9 +305,26 @@ case class ScalarSubqueryReference( level: Int, mergedPlanIndex: Int, outputIndex: Int, - dataType: DataType, + override val dataType: DataType, exprId: ExprId) extends LeafExpression with Unevaluable { override def nullable: Boolean = true final override val nodePatterns: Seq[TreePattern] = Seq(SCALAR_SUBQUERY_REFERENCE) } + +/** + * Temporal reference to a non-grouping aggregate which is added to a `PlanMerger`. + * + * @param level The level of the replaced aggregate. It defines the `PlanMerger` instance into which + * the aggregate is merged. + * @param mergedPlanIndex The index of the merged plan in the `PlanMerger`. + * @param outputIndices The indices of the output attributes of the merged plan. + * @param output The output of original aggregate. + */ +case class NonGroupingAggregateReference( + level: Int, + mergedPlanIndex: Int, + outputIndices: Seq[Int], + override val output: Seq[Attribute]) extends LeafNode { + final override val nodePatterns: Seq[TreePattern] = Seq(NO_GROUPING_AGGREGATE_REFERENCE) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala index 37982d163927..1623166e0a65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala @@ -58,8 +58,8 @@ case class MergedPlan(plan: LogicalPlan, merged: Boolean) * 2. Merge a new plan with a cached plan by combining their outputs * * The merging process preserves semantic equivalence while combining outputs from multiple - * plans into a single plan. This is primarily used by [[MergeScalarSubqueries]] to deduplicate - * scalar subquery execution. + * plans into a single plan. This is primarily used by [[MergeSubplans]] to deduplicate subplan + * execution. * * Supported plan types for merging: * - [[Project]]: Merges project lists @@ -88,16 +88,21 @@ class PlanMerger { * 3. If no merge is possible, add as a new cache entry * * @param plan The logical plan to merge or cache. + * @param subqueryPlan If the logical plan is a subquery plan. * @return A [[MergeResult]] containing: * - The merged/cached plan to use * - Its index in the cache * - An attribute mapping for rewriting expressions */ - def merge(plan: LogicalPlan): MergeResult = { + def merge(plan: LogicalPlan, subqueryPlan: Boolean): MergeResult = { cache.zipWithIndex.collectFirst(Function.unlift { case (mp, i) => checkIdenticalPlans(plan, mp.plan).map { outputMap => - val newMergePlan = MergedPlan(mp.plan, true) + // Identical subquery expression plans are not marked as `merged` as the + // `ReusedSubqueryExec` rule can handle them without extracting the plans to CTEs. + // But, when a non-subquery subplan is identical to a cached plan we need to mark the plan + // `merged` and so extract it to a CTE later. + val newMergePlan = MergedPlan(mp.plan, cache(i).merged || !subqueryPlan) cache(i) = newMergePlan MergeResult(newMergePlan, i, outputMap) }.orElse { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index ba4e801ed0a6..5ea93e74c5d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -150,6 +150,7 @@ object TreePattern extends Enumeration { val LOCAL_RELATION: Value = Value val LOGICAL_QUERY_STAGE: Value = Value val NATURAL_LIKE_JOIN: Value = Value + val NO_GROUPING_AGGREGATE_REFERENCE: Value = Value val OFFSET: Value = Value val OUTER_JOIN: Value = Value val PARAMETERIZED_QUERY: Value = Value diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala similarity index 82% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala index 008b4a89ce60..b368035e278e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala @@ -25,14 +25,14 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -class MergeScalarSubqueriesSuite extends PlanTest { +class MergeSubplansSuite extends PlanTest { override def beforeEach(): Unit = { CTERelationDef.curId.set(0) } private object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("MergeScalarSubqueries", Once, MergeScalarSubqueries) :: Nil + val batches = Batch("MergeSubplans", Once, MergeSubplans) :: Nil } val testRelation = LocalRelation($"a".int, $"b".int, $"c".string) @@ -590,4 +590,135 @@ class MergeScalarSubqueriesSuite extends PlanTest { comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) } + + test("Merge aggregates") { + val agg1 = testRelation.groupBy()(min($"a").as("min_a")) + val agg2 = testRelation.groupBy()(max($"a").as("max_a")) + val originalQuery = agg1.join(agg2) + + val mergedSubquery = testRelation + .groupBy()( + min($"a").as("min_a"), + max($"a").as("max_a") + ) + .select( + CreateNamedStruct(Seq( + Literal("min_a"), $"min_a", + Literal("max_a"), $"max_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + OneRowRelation().select(extractorExpression(0, analyzedMergedSubquery.output, 0, "min_a")) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 1, "max_a"))), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("Merge non-siblig aggregates") { + val agg1 = testRelation.groupBy()(min($"a").as("min_a")) + val agg2 = testRelation.groupBy()(max($"a").as("max_a")) + val originalQuery = agg1.join(testRelation).join(agg2) + + val mergedSubquery = testRelation + .groupBy()( + min($"a").as("min_a"), + max($"a").as("max_a") + ) + .select( + CreateNamedStruct(Seq( + Literal("min_a"), $"min_a", + Literal("max_a"), $"max_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + OneRowRelation().select(extractorExpression(0, analyzedMergedSubquery.output, 0, "min_a")) + .join(testRelation) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 1, "max_a"))), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("Merge subqueries and aggregates") { + val subquery1 = ScalarSubquery(testRelation.groupBy()(min($"a").as("min_a"))) + val subquery2 = ScalarSubquery(testRelation.groupBy()(max($"a").as("max_a"))) + val agg1 = testRelation.groupBy()(sum($"a").as("sum_a")) + val agg2 = testRelation.groupBy()(avg($"a").as("avg_a")) + val originalQuery = + testRelation + .select( + subquery1, + subquery2) + .join(agg1) + .join(agg2) + + val mergedSubquery = testRelation + .groupBy()( + min($"a").as("min_a"), + max($"a").as("max_a"), + sum($"a").as("sum_a"), + avg($"a").as("avg_a") + ) + .select( + CreateNamedStruct(Seq( + Literal("min_a"), $"min_a", + Literal("max_a"), $"max_a", + Literal("sum_a"), $"sum_a", + Literal("avg_a"), $"avg_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation + .select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 2, "sum_a"))) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 3, "avg_a"))), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("Merge identical subqueries and aggregates") { + val subquery1 = ScalarSubquery(testRelation.groupBy()(min($"a").as("min_a"))) + val subquery2 = ScalarSubquery(testRelation.groupBy()(min($"a").as("min_a_2"))) + val agg1 = testRelation.groupBy()(min($"a").as("min_a_3")) + val agg2 = testRelation.groupBy()(min($"a").as("min_a_4")) + val originalQuery = + testRelation + .select( + subquery1, + subquery2) + .join(agg1) + .join(agg2) + + val mergedSubquery = testRelation + .groupBy()(min($"a").as("min_a")) + .select( + CreateNamedStruct(Seq(Literal("min_a"), $"min_a")).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation + .select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 0)) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 0, "min_a_3"))) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 0, "min_a_4"))), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 8edb59f49282..7f3b8383f0f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -61,8 +61,8 @@ class SparkOptimizer( new RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)), Batch("InjectRuntimeFilter", FixedPoint(1), InjectRuntimeFilter), - Batch("MergeScalarSubqueries", Once, - MergeScalarSubqueries, + Batch("MergeSubplans", Once, + MergeSubplans, RewriteDistinctAggregates), Batch("Pushdown Filters from PartitionPruning", fixedPoint, PushDownPredicates), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala index 7d7185ae6c13..603ec183bfb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate} -import org.apache.spark.sql.catalyst.optimizer.MergeScalarSubqueries +import org.apache.spark.sql.catalyst.optimizer.MergeSubplans import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan} import org.apache.spark.sql.execution.{ReusedSubqueryExec, SubqueryExec} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation} @@ -207,7 +207,7 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp // `MergeScalarSubqueries` can duplicate subqueries in the optimized plan and would make testing // complicated. - conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, MergeScalarSubqueries.ruleName) + conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, MergeSubplans.ruleName) } protected override def afterAll(): Unit = try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala new file mode 100644 index 000000000000..b7557b42702e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class PlanMergeSuite extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + import testImplicits._ + + setupTestData() + + test("Merge non-correlated scalar subqueries") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | (SELECT avg(key) FROM testData), + | (SELECT sum(key) FROM testData), + | (SELECT count(distinct key) FROM testData) + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-correlated scalar subqueries in a subquery") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT ( + | SELECT + | SUM( + | (SELECT avg(key) FROM testData) + + | (SELECT sum(key) FROM testData) + + | (SELECT count(distinct key) FROM testData)) + | FROM testData + |) + """.stripMargin) + + checkAnswer(df, Row(520050.0) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 5, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-correlated scalar subqueries from different levels") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | (SELECT avg(key) FROM testData), + | ( + | SELECT + | SUM( + | (SELECT sum(key) FROM testData) + | ) + | FROM testData + | ) + """.stripMargin) + + checkAnswer(df, Row(50.5, 505000) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-correlated scalar subqueries from different parent plans") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | ( + | SELECT + | SUM( + | (SELECT avg(key) FROM testData) + | ) + | FROM testData + | ), + | ( + | SELECT + | SUM( + | (SELECT sum(key) FROM testData) + | ) + | FROM testData + | ) + """.stripMargin) + + checkAnswer(df, Row(5050.0, 505000) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 4, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-correlated scalar subqueries with conflicting names") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | (SELECT avg(key) AS key FROM testData), + | (SELECT sum(key) AS key FROM testData), + | (SELECT count(distinct key) AS key FROM testData) + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-grouping aggregates") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT * + |FROM (SELECT avg(key) FROM testData) + |JOIN (SELECT sum(key) FROM testData) + |JOIN (SELECT count(distinct key) FROM testData) + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-grouping aggregates from different levels") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | first(avg_key), + | ( + | -- Using `testData2` makes the whole subquery plan non-mergeable to the + | -- non-grouping aggregate subplan in the main plan, which uses `testData`, but its + | -- aggregate subplan with `sum(key)` is mergeable + | SELECT first(sum_key) + | FROM (SELECT sum(key) AS sum_key FROM testData) + | JOIN testData2 + | ), + | first(count_key) + |FROM (SELECT avg(key) AS avg_key, count(distinct key) as count_key FROM testData) + |JOIN testData3 + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-grouping aggregate and subquery") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | first(avg_key), + | ( + | -- In this case the whole scalar subquery plan is mergeable to the non-grouping + | -- aggregate subplan in the main plan. + | SELECT sum(key) AS sum_key FROM testData + | ), + | first(count_key) + |FROM (SELECT avg(key) AS avg_key, count(distinct key) as count_key FROM testData) + |JOIN testData3 + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("SPARK-40618: Regression test for merging subquery bug with nested subqueries") { + // This test contains a subquery expression with another subquery expression nested inside. + // It acts as a regression test to ensure that the MergeScalarSubqueries rule does not attempt + // to merge them together. + withTable("t1", "t2") { + sql("create table t1(col int) using csv") + checkAnswer(sql("select(select sum((select sum(col) from t1)) from t1)"), Row(null)) + + checkAnswer(sql( + """ + |select + | (select sum( + | (select sum( + | (select sum(col) from t1)) + | from t1)) + | from t1) + |""".stripMargin), + Row(null)) + + sql("create table t2(col int) using csv") + checkAnswer(sql( + """ + |select + | (select sum( + | (select sum( + | (select sum(col) from t1)) + | from t2)) + | from t1) + |""".stripMargin), + Row(null)) + } + } + + test("SPARK-42346: Rewrite distinct aggregates after merging subqueries") { + withTempView("t1") { + Seq((1, 2), (3, 4)).toDF("c1", "c2").createOrReplaceTempView("t1") + + checkAnswer(sql( + """ + |SELECT + | (SELECT count(distinct c1) FROM t1), + | (SELECT count(distinct c2) FROM t1) + |""".stripMargin), + Row(2, 2)) + + // In this case we don't merge the subqueries as `RewriteDistinctAggregates` kicks off for the + // 2 subqueries first but `MergeScalarSubqueries` is not prepared for the `Expand` nodes that + // are inserted by the rewrite. + checkAnswer(sql( + """ + |SELECT + | (SELECT count(distinct c1) + sum(distinct c2) FROM t1), + | (SELECT count(distinct c2) + sum(distinct c1) FROM t1) + |""".stripMargin), + Row(8, 6)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 3ba48da0e327..b53610761d04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2234,161 +2234,6 @@ class SubquerySuite extends QueryTest } } - test("Merge non-correlated scalar subqueries") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT - | (SELECT avg(key) FROM testData), - | (SELECT sum(key) FROM testData), - | (SELECT count(distinct key) FROM testData) - """.stripMargin) - - checkAnswer(df, Row(50.5, 5050, 100) :: Nil) - - val plan = df.queryExecution.executedPlan - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 2, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - - test("Merge non-correlated scalar subqueries in a subquery") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT ( - | SELECT - | SUM( - | (SELECT avg(key) FROM testData) + - | (SELECT sum(key) FROM testData) + - | (SELECT count(distinct key) FROM testData)) - | FROM testData - |) - """.stripMargin) - - checkAnswer(df, Row(520050.0) :: Nil) - - val plan = df.queryExecution.executedPlan - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 5, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - - test("Merge non-correlated scalar subqueries from different levels") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT - | (SELECT avg(key) FROM testData), - | ( - | SELECT - | SUM( - | (SELECT sum(key) FROM testData) - | ) - | FROM testData - | ) - """.stripMargin) - - checkAnswer(df, Row(50.5, 505000) :: Nil) - - val plan = df.queryExecution.executedPlan - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 2, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - - test("Merge non-correlated scalar subqueries from different parent plans") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT - | ( - | SELECT - | SUM( - | (SELECT avg(key) FROM testData) - | ) - | FROM testData - | ), - | ( - | SELECT - | SUM( - | (SELECT sum(key) FROM testData) - | ) - | FROM testData - | ) - """.stripMargin) - - checkAnswer(df, Row(5050.0, 505000) :: Nil) - - val plan = df.queryExecution.executedPlan - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 4, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - - test("Merge non-correlated scalar subqueries with conflicting names") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT - | (SELECT avg(key) AS key FROM testData), - | (SELECT sum(key) AS key FROM testData), - | (SELECT count(distinct key) AS key FROM testData) - """.stripMargin) - - checkAnswer(df, Row(50.5, 5050, 100) :: Nil) - - val plan = df.queryExecution.executedPlan - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 2, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - test("SPARK-39355: Single column uses quoted to construct UnresolvedAttribute") { checkAnswer( sql(""" @@ -2489,39 +2334,6 @@ class SubquerySuite extends QueryTest } } - test("SPARK-40618: Regression test for merging subquery bug with nested subqueries") { - // This test contains a subquery expression with another subquery expression nested inside. - // It acts as a regression test to ensure that the MergeScalarSubqueries rule does not attempt - // to merge them together. - withTable("t1", "t2") { - sql("create table t1(col int) using csv") - checkAnswer(sql("select(select sum((select sum(col) from t1)) from t1)"), Row(null)) - - checkAnswer(sql( - """ - |select - | (select sum( - | (select sum( - | (select sum(col) from t1)) - | from t1)) - | from t1) - |""".stripMargin), - Row(null)) - - sql("create table t2(col int) using csv") - checkAnswer(sql( - """ - |select - | (select sum( - | (select sum( - | (select sum(col) from t1)) - | from t2)) - | from t1) - |""".stripMargin), - Row(null)) - } - } - test("SPARK-40615: Check unsupported data type when decorrelating subqueries") { withTempView("v1", "v2") { sql( @@ -2616,31 +2428,6 @@ class SubquerySuite extends QueryTest } } - test("SPARK-42346: Rewrite distinct aggregates after merging subqueries") { - withTempView("t1") { - Seq((1, 2), (3, 4)).toDF("c1", "c2").createOrReplaceTempView("t1") - - checkAnswer(sql( - """ - |SELECT - | (SELECT count(distinct c1) FROM t1), - | (SELECT count(distinct c2) FROM t1) - |""".stripMargin), - Row(2, 2)) - - // In this case we don't merge the subqueries as `RewriteDistinctAggregates` kicks off for the - // 2 subqueries first but `MergeScalarSubqueries` is not prepared for the `Expand` nodes that - // are inserted by the rewrite. - checkAnswer(sql( - """ - |SELECT - | (SELECT count(distinct c1) + sum(distinct c2) FROM t1), - | (SELECT count(distinct c2) + sum(distinct c1) FROM t1) - |""".stripMargin), - Row(8, 6)) - } - } - test("SPARK-42745: Improved AliasAwareOutputExpression works with DSv2") { withSQLConf( SQLConf.USE_V1_SOURCE_LIST.key -> "") {