Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LogicalPlan, Project, Subquery, WithCTE}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CTERelationDef, CTERelationRef, LeafNode, LogicalPlan, OneRowRelation, Project, Subquery, WithCTE}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern}
import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, NO_GROUPING_AGGREGATE_REFERENCE, SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType

Expand Down Expand Up @@ -100,7 +100,7 @@ import org.apache.spark.sql.types.DataType
* : +- ReusedSubquery Subquery scalar-subquery#242, [id=#125]
* +- *(1) Scan OneRowRelation[]
*/
object MergeScalarSubqueries extends Rule[LogicalPlan] {
object MergeSubplans extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan match {
// Subquery reuse needs to be enabled for this optimization.
Expand All @@ -117,26 +117,24 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
}

private def extractCommonScalarSubqueries(plan: LogicalPlan) = {
// Collect `ScalarSubquery` plans by level into `PlanMerger`s and insert references in place of
// `ScalarSubquery`s.
// Collect subplans by level into `PlanMerger`s and insert references in place of them.
val planMergers = ArrayBuffer.empty[PlanMerger]
val planWithReferences = insertReferences(plan, planMergers)._1
val planWithReferences = insertReferences(plan, true, planMergers)._1

// Traverse level by level and convert merged plans to `CTERelationDef`s and keep non-merged
// ones. While traversing replace references in plans back to `CTERelationRef`s or to original
// `ScalarSubquery`s. This is safe as a subquery plan at a level can reference only lower level
// other subqueries.
val subqueryPlansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]]
// plans. This is safe as a subplan at a level can reference only lower level ot other subplans.
val subplansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]]
planMergers.foreach { planMerger =>
val mergedPlans = planMerger.mergedPlans()
subqueryPlansByLevel += mergedPlans.map { mergedPlan =>
val planWithoutReferences = if (subqueryPlansByLevel.isEmpty) {
subplansByLevel += mergedPlans.map { mergedPlan =>
val planWithoutReferences = if (subplansByLevel.isEmpty) {
// Level 0 plans can't contain references
mergedPlan.plan
} else {
removeReferences(mergedPlan.plan, subqueryPlansByLevel)
removeReferences(mergedPlan.plan, subplansByLevel)
}
if (mergedPlan.merged && mergedPlan.plan.output.size > 1) {
if (mergedPlan.merged) {
CTERelationDef(
Project(
Seq(Alias(
Expand All @@ -151,38 +149,42 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
}
}

// Replace references back to `CTERelationRef`s or to original `ScalarSubquery`s in the main
// plan.
val newPlan = removeReferences(planWithReferences, subqueryPlansByLevel)
// Replace references back to `CTERelationRef`s or to original subplans.
val newPlan = removeReferences(planWithReferences, subplansByLevel)

// Add `CTERelationDef`s to the plan.
val subqueryCTEs = subqueryPlansByLevel.flatMap(_.collect { case cte: CTERelationDef => cte })
if (subqueryCTEs.nonEmpty) {
WithCTE(newPlan, subqueryCTEs.toSeq)
val subplanCTEs = subplansByLevel.flatMap(_.collect { case cte: CTERelationDef => cte })
if (subplanCTEs.nonEmpty) {
WithCTE(newPlan, subplanCTEs.toSeq)
} else {
newPlan
}
}

// First traversal inserts `ScalarSubqueryReference`s to the plan and tries to merge subquery
// plans by each level.
// First traversal inserts `ScalarSubqueryReference`s and `NoGroupingAggregateReference`s to the
// plan and tries to merge subplans by each level. Levels are separated eiter by scalar subqueries
// or by non-grouping aggregate nodes. Nodes with the same level make sense to try merging.
private def insertReferences(
plan: LogicalPlan,
root: Boolean,
planMergers: ArrayBuffer[PlanMerger]): (LogicalPlan, Int) = {
// The level of a subquery plan is maximum level of its inner subqueries + 1 or 0 if it has no
// inner subqueries.
var maxLevel = 0
val planWithReferences =
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
if (!plan.containsAnyPattern(AGGREGATE, SCALAR_SUBQUERY)) {
return (plan, 0)
}

// Calculate the level propagated from subquery plans, which is the maximum level of the
// subqueries of the node + 1 or 0 if the node has no subqueries.
var levelFromSubqueries = 0
val nodeSubqueriesWithReferences =
plan.transformExpressionsWithPruning(_.containsPattern(SCALAR_SUBQUERY)) {
case s: ScalarSubquery if !s.isCorrelated && s.deterministic =>
val (planWithReferences, level) = insertReferences(s.plan, planMergers)
val (planWithReferences, level) = insertReferences(s.plan, true, planMergers)

while (level >= planMergers.size) planMergers += new PlanMerger()
// The subquery could contain a hint that is not propagated once we merge it, but as a
// non-correlated scalar subquery won't be turned into a Join the loss of hints is fine.
val mergeResult = planMergers(level).merge(planWithReferences)
val mergeResult = getPlanMerger(planMergers, level).merge(planWithReferences, true)

maxLevel = maxLevel.max(level + 1)
levelFromSubqueries = levelFromSubqueries.max(level + 1)

val mergedOutput = mergeResult.outputMap(planWithReferences.output.head)
val outputIndex =
Expand All @@ -195,26 +197,96 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
s.exprId)
case o => o
}
(planWithReferences, maxLevel)

// Calculate the level of the node, which is the maximum of the above calculated level
// propagated from subqueries and the level propagated from child nodes.
val (planWithReferences, level) = nodeSubqueriesWithReferences match {
case a: Aggregate if !root && a.groupingExpressions.isEmpty =>
val (childWithReferences, levelFromChild) = insertReferences(a.child, false, planMergers)
val aggregateWithReferences = a.withNewChildren(Seq(childWithReferences))

// Level is the maximum of the level from subqueries and the level from child.
val level = levelFromChild.max(levelFromSubqueries)

val mergeResult = getPlanMerger(planMergers, level).merge(aggregateWithReferences, false)

val mergedOutput = aggregateWithReferences.output.map(mergeResult.outputMap)
val outputIndices =
mergedOutput.map(a => mergeResult.mergedPlan.plan.output.indexWhere(_.exprId == a.exprId))
val aggregateReference = NonGroupingAggregateReference(
level,
mergeResult.mergedPlanIndex,
outputIndices,
a.output
)

// This is a non-grouping aggregate node so propagate the level of the node + 1 to its
// parent
(aggregateReference, level + 1)
case o =>
val (newChildren, levels) = o.children.map(insertReferences(_, false, planMergers)).unzip
// Level is the maximum of the level from subqueries and the level from the children.
(o.withNewChildren(newChildren), (levelFromSubqueries +: levels).max)
}

(planWithReferences, level)
}

private def getPlanMerger(planMergers: ArrayBuffer[PlanMerger], level: Int) = {
while (level >= planMergers.size) planMergers += new PlanMerger()
planMergers(level)
}

// Second traversal replaces `ScalarSubqueryReference`s to either
// `GetStructField(ScalarSubquery(CTERelationRef to the merged plan)` if the plan is merged from
// multiple subqueries or `ScalarSubquery(original plan)` if it isn't.
// Second traversal replaces:
// - a `ScalarSubqueryReference` either to
// `GetStructField(ScalarSubquery(CTERelationRef to the merged plan), merged output index)` if
// the plan is merged from multiple subqueries or to `ScalarSubquery(original plan)` if it
// isn't.
// - a `NoGroupingAggregateReference` either to
// ```
// Project(
// Seq(
// GetStructField(
// ScalarSubquery(CTERelationRef to the merged plan),
// merged output index 1),
// GetStructField(
// ScalarSubquery(CTERelationRef to the merged plan),
// merged output index 2),
// ...),
// OneRowRelation)
// ```
// if the plan is merged from multiple subqueries or to `original plan` if it isn't.
private def removeReferences(
plan: LogicalPlan,
subqueryPlansByLevel: ArrayBuffer[IndexedSeq[LogicalPlan]]) = {
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) {
case ssr: ScalarSubqueryReference =>
subqueryPlansByLevel(ssr.level)(ssr.mergedPlanIndex) match {
subplansByLevel: ArrayBuffer[IndexedSeq[LogicalPlan]]) = {
plan.transformUpWithPruning(
_.containsAnyPattern(NO_GROUPING_AGGREGATE_REFERENCE, SCALAR_SUBQUERY_REFERENCE)) {
case ngar: NonGroupingAggregateReference =>
subplansByLevel(ngar.level)(ngar.mergedPlanIndex) match {
case cte: CTERelationDef =>
GetStructField(
ScalarSubquery(
CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming),
exprId = ssr.exprId),
ssr.outputIndex)
case o => ScalarSubquery(o, exprId = ssr.exprId)
val projectList = ngar.outputIndices.zip(ngar.output).map { case (i, a) =>
Alias(
GetStructField(
ScalarSubquery(
CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming)),
i),
a.name)(a.exprId)
}
Project(projectList, OneRowRelation())
case o => o
}
case o => o.transformExpressionsUpWithPruning(_.containsPattern(SCALAR_SUBQUERY_REFERENCE)) {
case ssr: ScalarSubqueryReference =>
subplansByLevel(ssr.level)(ssr.mergedPlanIndex) match {
case cte: CTERelationDef =>
GetStructField(
ScalarSubquery(
CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming),
exprId = ssr.exprId),
ssr.outputIndex)
case o => ScalarSubquery(o, exprId = ssr.exprId)
}
}
}
}
}
Expand All @@ -233,9 +305,26 @@ case class ScalarSubqueryReference(
level: Int,
mergedPlanIndex: Int,
outputIndex: Int,
dataType: DataType,
override val dataType: DataType,
exprId: ExprId) extends LeafExpression with Unevaluable {
override def nullable: Boolean = true

final override val nodePatterns: Seq[TreePattern] = Seq(SCALAR_SUBQUERY_REFERENCE)
}

/**
* Temporal reference to a non-grouping aggregate which is added to a `PlanMerger`.
*
* @param level The level of the replaced aggregate. It defines the `PlanMerger` instance into which
* the aggregate is merged.
* @param mergedPlanIndex The index of the merged plan in the `PlanMerger`.
* @param outputIndices The indices of the output attributes of the merged plan.
* @param output The output of original aggregate.
*/
case class NonGroupingAggregateReference(
level: Int,
mergedPlanIndex: Int,
outputIndices: Seq[Int],
override val output: Seq[Attribute]) extends LeafNode {
final override val nodePatterns: Seq[TreePattern] = Seq(NO_GROUPING_AGGREGATE_REFERENCE)
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ case class MergedPlan(plan: LogicalPlan, merged: Boolean)
* 2. Merge a new plan with a cached plan by combining their outputs
*
* The merging process preserves semantic equivalence while combining outputs from multiple
* plans into a single plan. This is primarily used by [[MergeScalarSubqueries]] to deduplicate
* scalar subquery execution.
* plans into a single plan. This is primarily used by [[MergeSubplans]] to deduplicate subplan
* execution.
*
* Supported plan types for merging:
* - [[Project]]: Merges project lists
Expand Down Expand Up @@ -88,16 +88,21 @@ class PlanMerger {
* 3. If no merge is possible, add as a new cache entry
*
* @param plan The logical plan to merge or cache.
* @param subqueryPlan If the logical plan is a subquery plan.
* @return A [[MergeResult]] containing:
* - The merged/cached plan to use
* - Its index in the cache
* - An attribute mapping for rewriting expressions
*/
def merge(plan: LogicalPlan): MergeResult = {
def merge(plan: LogicalPlan, subqueryPlan: Boolean): MergeResult = {
cache.zipWithIndex.collectFirst(Function.unlift {
case (mp, i) =>
checkIdenticalPlans(plan, mp.plan).map { outputMap =>
val newMergePlan = MergedPlan(mp.plan, true)
// Identical subquery expression plans are not marked as `merged` as the
// `ReusedSubqueryExec` rule can handle them without extracting the plans to CTEs.
// But, when a non-subquery subplan is identical to a cached plan we need to mark the plan
// `merged` and so extract it to a CTE later.
val newMergePlan = MergedPlan(mp.plan, cache(i).merged || !subqueryPlan)
cache(i) = newMergePlan
MergeResult(newMergePlan, i, outputMap)
}.orElse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ object TreePattern extends Enumeration {
val LOCAL_RELATION: Value = Value
val LOGICAL_QUERY_STAGE: Value = Value
val NATURAL_LIKE_JOIN: Value = Value
val NO_GROUPING_AGGREGATE_REFERENCE: Value = Value
val OFFSET: Value = Value
val OUTER_JOIN: Value = Value
val PARAMETERIZED_QUERY: Value = Value
Expand Down
Loading