Remove local sort for TopNRowNumber
ulysses-you committed Jul 11, 2024 · 1 parent be57db8 · commit 5bc17d0
Showing 15 changed files with 142 additions and 87 deletions.
@@ -485,7 +485,11 @@ object VeloxBackendSettings extends BackendSettingsApi {
 
   override def alwaysFailOnMapExpression(): Boolean = true
 
-  override def requiredChildOrderingForWindow(): Boolean = true
+  override def requiredChildOrderingForWindow(): Boolean = {
+    GlutenConfig.getConf.veloxColumnarWindowType.equals("streaming")
+  }
 
   override def requiredChildOrderingForWindowGroupLimit(): Boolean = false
 
   override def staticPartitionWriteOnly(): Boolean = true
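
Context for the change above: with the Velox backend, only the streaming window implementation consumes pre-sorted input, so child ordering is now required only when the columnar window type is "streaming". A minimal sketch of toggling the type (the conf key is an assumption inferred from GlutenConfig.veloxColumnarWindowType; verify against GlutenConfig):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
// "streaming": the native StreamingWindow consumes pre-sorted input,
// so a local sort in front of it is still required.
spark.conf.set("spark.gluten.sql.columnar.backend.velox.window.type", "streaming")
// "sort" (assumed default): the native window sorts internally,
// so the local sort in front of it becomes removable.
spark.conf.set("spark.gluten.sql.columnar.backend.velox.window.type", "sort")
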
@@ -123,9 +123,9 @@ trait BackendSettingsApi {
 
   def alwaysFailOnMapExpression(): Boolean = false
 
-  def requiredChildOrderingForWindow(): Boolean = false
+  def requiredChildOrderingForWindow(): Boolean = true
 
-  def requiredChildOrderingForWindowGroupLimit(): Boolean = false
+  def requiredChildOrderingForWindowGroupLimit(): Boolean = true
 
   def staticPartitionWriteOnly(): Boolean = false
@@ -185,8 +185,7 @@ object HashAggregateExecBaseTransformer {
     case a: SortAggregateExec => a.initialInputBufferOffset
   }
 
-  def from(agg: BaseAggregateExec)(
-      childConverter: SparkPlan => SparkPlan = p => p): HashAggregateExecBaseTransformer = {
+  def from(agg: BaseAggregateExec): HashAggregateExecBaseTransformer = {
     BackendsApiManager.getSparkPlanExecApiInstance
       .genHashAggregateExecTransformer(
         agg.requiredChildDistributionExpressions,
@@ -195,7 +194,7 @@ object HashAggregateExecBaseTransformer {
         agg.aggregateAttributes,
         getInitialInputBufferOffset(agg),
         agg.resultExpressions,
-        childConverter(agg.child)
+        agg.child
       )
   }
 }

This file was deleted.

@@ -67,11 +67,7 @@ case class WindowExecTransformer(
   }
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
-    if (
-      BackendsApiManager.getSettings.requiredChildOrderingForWindow()
-      && GlutenConfig.getConf.veloxColumnarWindowType.equals("streaming")
-    ) {
-      // Velox StreamingWindow need to require child order.
+    if (BackendsApiManager.getSettings.requiredChildOrderingForWindow()) {
       Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
     } else {
       Seq(Nil)
@@ -66,14 +66,24 @@ case class WindowGroupLimitExecTransformer(
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     if (BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) {
-      // Velox StreamingTopNRowNumber need to require child order.
       Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
     } else {
       Seq(Nil)
     }
   }
 
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+  override def outputOrdering: Seq[SortOrder] = {
+    if (requiredChildOrdering.forall(_.isEmpty)) {
+      // The Velox backend `TopNRowNumber` does not require child ordering, because it
+      // uses a hash table to store partitions and a priority queue to track the top limit rows.
+      // Ideally, the output of `TopNRowNumber` is unordered, but it is grouped by partition keys.
+      // To be safe, we do not propagate the ordering here.
+      // TODO: Make the framework aware of grouped data distribution
+      Nil
+    } else {
+      child.outputOrdering
+    }
+  }
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
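
The comment above can be made concrete with a small standalone sketch (plain Scala, not Gluten or Velox code): rows are bucketed into a hash table by partition key, and each bucket tracks its top-N rows in a priority queue, so output rows for a key come out contiguous but in no particular order.

import scala.collection.mutable

// Keep the n smallest values per key. Output is grouped by key (hash-table
// iteration order) but values are not sorted within or across groups, which
// is why the transformer above no longer propagates the child's ordering.
def topNPerKey[K, V](rows: Seq[(K, V)], n: Int)(implicit ord: Ordering[V]): Seq[(K, V)] = {
  val groups = mutable.HashMap.empty[K, mutable.PriorityQueue[V]]
  for ((k, v) <- rows) {
    val heap = groups.getOrElseUpdate(k, mutable.PriorityQueue.empty[V](ord)) // max-heap
    heap.enqueue(v)
    if (heap.size > n) heap.dequeue() // evict the largest, keeping the n smallest
  }
  groups.iterator.flatMap { case (k, heap) => heap.toSeq.map(v => (k, v)) }.toSeq
}

// topNPerKey(Seq("a" -> 3, "b" -> 1, "a" -> 1, "a" -> 2), n = 2)
// => rows for "a" and "b" are contiguous, but unordered within each group
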
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar
+
+import org.apache.gluten.execution.{HashAggregateExecBaseTransformer, ProjectExecTransformer, ShuffledHashJoinExecTransformerBase, SortExecTransformer, WindowExecTransformer, WindowGroupLimitExecTransformer}
+
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, UnaryExecNode}
+
+/**
+ * This rule eliminates unnecessary local sorts. That can happen when we:
+ *   - convert a sort merge join to a shuffled hash join
+ *   - offload SortAggregate to native hash aggregate
+ *   - offload WindowGroupLimit to native TopNRowNumber
+ *   - use the `sort` columnar window type
+ */
+object EliminateLocalSort extends Rule[SparkPlan] {
+  private def canEliminateLocalSort(p: SparkPlan): Boolean = p match {
+    case _: HashAggregateExecBaseTransformer => true
+    case _: ShuffledHashJoinExecTransformerBase => true
+    case _: WindowGroupLimitExecTransformer => true
+    case _: WindowExecTransformer => true
+    case _ => false
+  }
+
+  private def canThrough(p: SparkPlan): Boolean = p match {
+    case _: ProjectExec => true
+    case _: ProjectExecTransformer => true
+    case _ => false
+  }
+
+  private def orderingSatisfies(gChild: SparkPlan, requiredOrdering: Seq[SortOrder]): Boolean = {
+    SortOrder.orderingSatisfies(gChild.outputOrdering, requiredOrdering)
+  }
+
+  override def apply(plan: SparkPlan): SparkPlan = {
+    plan.transformDown {
+      case p if canEliminateLocalSort(p) =>
+        val requiredChildOrdering = p.requiredChildOrdering
+        assert(requiredChildOrdering.size == p.children.size)
+        val newChildren = p.children.zipWithIndex.map {
+          case (SortWithChild(gChild), i) if orderingSatisfies(gChild, requiredChildOrdering(i)) =>
+            gChild
+          case (p: UnaryExecNode, i) if canThrough(p) =>
+            // There may be more than one project between the target operator and the sort,
+            // e.g., both hash aggregate and sort pull out a project.
+            p.child match {
+              case SortWithChild(gChild) if orderingSatisfies(gChild, requiredChildOrdering(i)) =>
+                p.withNewChildren(gChild :: Nil)
+              case _ => p
+            }
+          case p => p._1
+        }
+        p.withNewChildren(newChildren)
+    }
+  }
+}
+
+object SortWithChild {
+  def unapply(plan: SparkPlan): Option[SparkPlan] = {
+    plan match {
+      case p1 @ ProjectExec(_, SortExecTransformer(_, false, p2: ProjectExec, _))
+          if p1.outputSet == p2.child.outputSet =>
+        Some(p2.child)
+      case p1 @ ProjectExecTransformer(
+            _,
+            SortExecTransformer(_, false, p2: ProjectExecTransformer, _))
+          if p1.outputSet == p2.child.outputSet =>
+        Some(p2.child)
+      case SortExec(_, false, child, _) =>
+        Some(child)
+      case SortExecTransformer(_, false, child, _) =>
+        Some(child)
+      case _ => None
+    }
+  }
+}
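
For intuition, here is a self-contained toy model of the elimination (assumptions: a made-up plan ADT rather than Spark's SparkPlan; the real rule additionally verifies, via SortOrder.orderingSatisfies, that the remaining child still satisfies any required ordering):

// Toy model: a local sort directly under a hash join -- possibly behind one
// project, as in the `canThrough` cases above -- can be dropped once the
// operator no longer requires sorted input.
sealed trait Plan
case class HashJoin(left: Plan, right: Plan) extends Plan
case class Project(child: Plan) extends Plan
case class LocalSort(child: Plan) extends Plan
case class Scan(name: String) extends Plan

def dropSort(p: Plan): Plan = p match {
  case LocalSort(child) => child                   // sort feeds the operator directly
  case Project(LocalSort(child)) => Project(child) // sort hidden behind a pulled-out project
  case other => other
}

def eliminate(p: Plan): Plan = p match {
  case HashJoin(l, r) => HashJoin(dropSort(l), dropSort(r)) // the real rule recurses via transformDown
  case other => other
}

// eliminate(HashJoin(LocalSort(Scan("t1")), Project(LocalSort(Scan("t2")))))
// => HashJoin(Scan("t1"), Project(Scan("t2")))
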
@@ -17,7 +17,8 @@
 package org.apache.gluten.extension.columnar
 
 import org.apache.gluten.GlutenConfig
-import org.apache.gluten.execution.SortExecTransformer
+import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
+import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
 
 import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -32,6 +33,8 @@ import org.apache.spark.sql.execution.{SortExec, SparkPlan}
  * SortAggregate with the same key. So, this rule adds local sort back if necessary.
  */
 object EnsureLocalSortRequirements extends Rule[SparkPlan] {
+  private lazy val offload = TransformPreOverrides.apply()
+
   private def addLocalSort(
       originalChild: SparkPlan,
       requiredOrdering: Seq[SortOrder]): SparkPlan = {
@@ -40,18 +43,12 @@ object EnsureLocalSortRequirements extends Rule[SparkPlan] {
       FallbackTags.add(newChild, "columnar Sort is not enabled in SortExec")
       newChild
     } else {
-      val newChildWithTransformer =
-        SortExecTransformer(
-          newChild.sortOrder,
-          newChild.global,
-          newChild.child,
-          newChild.testSpillFrequency)
-      val validationResult = newChildWithTransformer.doValidate()
-      if (validationResult.isValid) {
-        newChildWithTransformer
+      val rewrittenPlan = RewriteSparkPlanRulesManager.apply().apply(newChild)
+      if (rewrittenPlan.eq(newChild) && FallbackTags.nonEmpty(rewrittenPlan)) {
+        // The sort cannot be offloaded
+        rewrittenPlan
       } else {
-        FallbackTags.add(newChild, validationResult)
-        newChild
+        offload.apply(rewrittenPlan)
       }
     }
   }
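
The new branch follows a "rewrite, then offload or fall back" pattern: run the rewrite rules over the freshly added SortExec, and keep the vanilla sort only if the rewrite was a no-op and the plan ended up fallback-tagged. A generic sketch of that flow (hypothetical helper, not part of this commit):

// Hypothetical helper illustrating the decision above.
// Assumption: `rewrite` returns the same instance when no rule applies.
def rewriteThenOffload[P <: AnyRef](
    plan: P,
    rewrite: P => P,
    hasFallbackTag: P => Boolean,
    offload: P => P): P = {
  val rewritten = rewrite(plan)
  if (rewritten.eq(plan) && hasFallbackTag(rewritten)) {
    rewritten // keep the vanilla sort; it cannot be offloaded
  } else {
    offload(rewritten)
  }
}
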
@@ -344,13 +344,13 @@ case class AddFallbackTagRule() extends Rule[SparkPlan] {
           .genFilterExecTransformer(plan.condition, plan.child)
         transformer.doValidate().tagOnFallback(plan)
       case plan: HashAggregateExec =>
-        val transformer = HashAggregateExecBaseTransformer.from(plan)()
+        val transformer = HashAggregateExecBaseTransformer.from(plan)
         transformer.doValidate().tagOnFallback(plan)
       case plan: SortAggregateExec =>
-        val transformer = HashAggregateExecBaseTransformer.from(plan)()
+        val transformer = HashAggregateExecBaseTransformer.from(plan)
         transformer.doValidate().tagOnFallback(plan)
      case plan: ObjectHashAggregateExec =>
-        val transformer = HashAggregateExecBaseTransformer.from(plan)()
+        val transformer = HashAggregateExecBaseTransformer.from(plan)
         transformer.doValidate().tagOnFallback(plan)
       case plan: UnionExec =>
         val transformer = ColumnarUnionExec(plan.children)
@@ -87,17 +87,17 @@ case class OffloadAggregate() extends OffloadSingleNode with LogLevelUtil {
       case _: TransformSupport =>
         // If the child is transformable, transform aggregation as well.
         logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-        HashAggregateExecBaseTransformer.from(plan)()
+        HashAggregateExecBaseTransformer.from(plan)
       case p: SparkPlan if PlanUtil.isGlutenTableCache(p) =>
-        HashAggregateExecBaseTransformer.from(plan)()
+        HashAggregateExecBaseTransformer.from(plan)
       case _ =>
        // If the child is not transformable, do not transform the agg.
         FallbackTags.add(plan, "child output schema is empty")
         plan
     }
   } else {
     logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-    HashAggregateExecBaseTransformer.from(plan)()
+    HashAggregateExecBaseTransformer.from(plan)
   }
 }
@@ -425,10 +425,10 @@ object OffloadOthers {
         ColumnarCoalesceExec(plan.numPartitions, plan.child)
       case plan: SortAggregateExec =>
         logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-        HashAggregateExecBaseTransformer.from(plan)(SortUtils.dropPartialSort)
+        HashAggregateExecBaseTransformer.from(plan)
       case plan: ObjectHashAggregateExec =>
         logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-        HashAggregateExecBaseTransformer.from(plan)()
+        HashAggregateExecBaseTransformer.from(plan)
       case plan: UnionExec =>
         val children = plan.children
         logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
@@ -102,6 +102,7 @@ class EnumeratedApplier(session: SparkSession)
       List(
         (_: SparkSession) => RemoveNativeWriteFilesSortAndProject(),
         (spark: SparkSession) => RewriteTransformer(spark),
+        (_: SparkSession) => EliminateLocalSort,
         (_: SparkSession) => EnsureLocalSortRequirements,
         (_: SparkSession) => CollapseProjectExecTransformer
       ) :::
@@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 object RasOffloadHashAggregate extends RasOffload {
   override def offload(node: SparkPlan): SparkPlan = node match {
     case agg: HashAggregateExec =>
-      val out = HashAggregateExecBaseTransformer.from(agg)()
+      val out = HashAggregateExecBaseTransformer.from(agg)
       out
     case other => other
   }
@@ -114,6 +114,7 @@ class HeuristicApplier(session: SparkSession)
       List(
         (_: SparkSession) => RemoveNativeWriteFilesSortAndProject(),
         (spark: SparkSession) => RewriteTransformer(spark),
+        (_: SparkSession) => EliminateLocalSort,
         (_: SparkSession) => EnsureLocalSortRequirements,
         (_: SparkSession) => CollapseProjectExecTransformer
       ) :::
@@ -17,7 +17,6 @@
 package org.apache.gluten.extension.columnar.rewrite
 
 import org.apache.gluten.GlutenConfig
-import org.apache.gluten.execution.SortUtils
 
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
 import org.apache.spark.sql.catalyst.plans.JoinType
@@ -52,8 +51,8 @@ object RewriteJoin extends RewriteSingleNode with JoinSelectionHelper {
         smj.joinType,
         buildSide,
         smj.condition,
-        SortUtils.dropPartialSort(smj.left),
-        SortUtils.dropPartialSort(smj.right),
+        smj.left,
+        smj.right,
         smj.isSkewJoin
       )
     case _ => plan
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.sql.execution
 
-import org.apache.gluten.execution.{WindowExecTransformer, WindowGroupLimitExecTransformer}
+import org.apache.gluten.execution.{SortExecTransformer, WindowExecTransformer, WindowGroupLimitExecTransformer}
 
 import org.apache.spark.sql.GlutenSQLTestsTrait
 import org.apache.spark.sql.Row
@@ -134,6 +134,9 @@ class GlutenSQLWindowFunctionSuite extends SQLWindowFunctionSuite with GlutenSQLTestsTrait {
         case _ => false
       }
     )
+    assert(
+      getExecutedPlan(df).collect { case s: SortExecTransformer if !s.global => s }.size == 1
+    )
   }
 }
