[GLUTEN-8497][CORE] A unified CallerInfo API to replace AdaptiveContext (
zhztheplayer authored and baibaichen committed Feb 1, 2025
1 parent 6f4d2a1 commit 9783a88
Showing 15 changed files with 355 additions and 376 deletions.
@@ -120,11 +120,10 @@ object CHRuleApi {
     injector.injectPostTransform(c => AddPreProjectionForHashJoin.apply(c.session))
 
     // Gluten columnar: Fallback policies.
-    injector.injectFallbackPolicy(
-      c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan()))
+    injector.injectFallbackPolicy(c => p => ExpandFallbackPolicy(c.caller.isAqe(), p))
 
     // Gluten columnar: Post rules.
-    injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()))
+    injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.caller.isAqe()))
     SparkShimLoader.getSparkShims
       .getExtendedColumnarPostRules()
       .foreach(each => injector.injectPost(c => intercept(each(c.session))))
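Note the new shape of the fallback-policy builder: it is curried, so the outer function receives the ColumnarRuleCall and the inner function receives the original query plan that previously had to be stashed in AdaptiveContext via setOriginalPlan(). A minimal sketch of a custom policy written against this shape (the policy logic below is hypothetical, not part of this commit):

    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan

    // Hypothetical policy: outside AQE, reject the offloaded plan by wrapping
    // the original plan in FallbackNode; during AQE, keep the suggestion.
    injector.injectFallbackPolicy(
      c =>
        originalPlan =>
          new Rule[SparkPlan] {
            override def apply(suggestedPlan: SparkPlan): SparkPlan =
              if (c.caller.isAqe()) suggestedPlan else FallbackNode(originalPlan)
          })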
@@ -101,11 +101,10 @@ object VeloxRuleApi {
     injector.injectPostTransform(c => InsertTransitions.create(c.outputsColumnar, VeloxBatch))
 
     // Gluten columnar: Fallback policies.
-    injector.injectFallbackPolicy(
-      c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan()))
+    injector.injectFallbackPolicy(c => p => ExpandFallbackPolicy(c.caller.isAqe(), p))
 
     // Gluten columnar: Post rules.
-    injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()))
+    injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.caller.isAqe()))
     SparkShimLoader.getSparkShims
       .getExtendedColumnarPostRules()
       .foreach(each => injector.injectPost(c => each(c.session)))

@@ -180,8 +179,7 @@ object VeloxRuleApi {
     injector.injectPostTransform(_ => CollapseProjectExecTransformer)
     injector.injectPostTransform(c => FlushableHashAggregateRule.apply(c.session))
     injector.injectPostTransform(c => InsertTransitions.create(c.outputsColumnar, VeloxBatch))
-    injector.injectPostTransform(
-      c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()))
+    injector.injectPostTransform(c => RemoveTopmostColumnarToRow(c.session, c.caller.isAqe()))
     SparkShimLoader.getSparkShims
       .getExtendedColumnarPostRules()
       .foreach(each => injector.injectPostTransform(c => each(c.session)))
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.caller

import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.columnar.InMemoryRelation

/**
 * Helper API that stores information about the call site of the columnar rule. Specific columnar
 * rules could call the API to check whether the current rule invocation was initiated for a
 * certain purpose. For example, a rule call could be initiated for AQE optimization, for cached
 * plan optimization, or for regular executed plan optimization.
 */
trait CallerInfo {
  def isAqe(): Boolean
  def isCache(): Boolean
}

object CallerInfo {
  private val localStorage: ThreadLocal[Option[CallerInfo]] =
    new ThreadLocal[Option[CallerInfo]]() {
      override def initialValue(): Option[CallerInfo] = None
    }

  private class Impl(override val isAqe: Boolean, override val isCache: Boolean) extends CallerInfo

  /** Find the information about the caller that initiated the rule call. */
  def create(): CallerInfo = {
    if (localStorage.get().nonEmpty) {
      return localStorage.get().get
    }
    val stack = Thread.currentThread.getStackTrace
    new Impl(isAqe = inAqeCall(stack), isCache = inCacheCall(stack))
  }

  private def inAqeCall(stack: Seq[StackTraceElement]): Boolean = {
    stack.exists(_.getClassName.equals(AdaptiveSparkPlanExec.getClass.getName))
  }

  private def inCacheCall(stack: Seq[StackTraceElement]): Boolean = {
    stack.exists(_.getClassName.equals(InMemoryRelation.getClass.getName))
  }

  /** For testing only. */
  def withLocalValue[T](isAqe: Boolean, isCache: Boolean)(body: => T): T = {
    val prevValue = localStorage.get()
    val newValue = new Impl(isAqe, isCache)
    localStorage.set(Some(newValue))
    try {
      body
    } finally {
      localStorage.set(prevValue)
    }
  }
}
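A quick sketch of how the new API is meant to be consumed and tested; CallerInfo.create() and withLocalValue are from the file above, while the surrounding setup is assumed:

    import org.apache.gluten.extension.caller.CallerInfo

    // Inside a columnar rule call: detect how this invocation was initiated.
    val info = CallerInfo.create()
    if (info.isAqe()) {
      // Initiated from AdaptiveSparkPlanExec's re-optimization.
    } else if (info.isCache()) {
      // Initiated while optimizing an InMemoryRelation's cached plan.
    }

    // In unit tests, pin the detected values without a real AQE call stack:
    CallerInfo.withLocalValue(isAqe = true, isCache = false) {
      assert(CallerInfo.create().isAqe())
    }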
@@ -17,7 +17,7 @@
 package org.apache.gluten.extension.columnar
 
 import org.apache.gluten.config.GlutenConfig
-import org.apache.gluten.extension.util.AdaptiveContext
+import org.apache.gluten.extension.caller.CallerInfo
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.SparkPlan

@@ -29,7 +29,7 @@ trait ColumnarRuleApplier {
 object ColumnarRuleApplier {
   class ColumnarRuleCall(
       val session: SparkSession,
-      val ac: AdaptiveContext,
+      val caller: CallerInfo,
       val outputsColumnar: Boolean) {
     val glutenConf: GlutenConfig = {
       new GlutenConfig(session.sessionState.conf)
@@ -21,11 +21,15 @@ import org.apache.gluten.component.Component
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.util.SparkReflectionUtil
 
+// format: off
 /**
  * The cost model API of Gluten. Used by:
- * 1. RAS planner for cost-based optimization; 2. Transition graph for choosing transition paths.
+ * <p>
+ * 1. RAS planner for cost-based optimization;
+ * <p>
+ * 2. Transition graph for choosing transition paths.
  */
+// format: on
 trait GlutenCostModel {
   def costOf(node: SparkPlan): GlutenCost
   def costComparator(): Ordering[GlutenCost]

@@ -38,14 +42,18 @@ trait GlutenCostModel {
 }
 
 object GlutenCostModel extends Logging {
-  def find(aliasOrClass: String): GlutenCostModel = {
-    val costModelRegistry = LongCostModel.registry()
+  private val costModelRegistry = {
+    val r = LongCostModel.registry()
     // Components should override Backend's costers. Hence, reversed registration order is applied.
     Component
       .sorted()
       .reverse
       .flatMap(_.costers())
-      .foreach(coster => costModelRegistry.register(coster))
+      .foreach(coster => r.register(coster))
+    r
+  }
+
+  def find(aliasOrClass: String): GlutenCostModel = {
     val costModel = find(costModelRegistry, aliasOrClass)
     costModel
   }
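Besides reformatting the scaladoc, this change hoists registry construction out of find(), so component costers are registered exactly once when GlutenCostModel is first loaded instead of on every lookup. A minimal stand-alone sketch of that initialize-once pattern (names hypothetical):

    object CosterRegistryDemo {
      // The block runs once, on first access to `entries`.
      private val entries: Map[String, Long] = {
        println("registering costers once")
        Map("legacy" -> 1L, "rough" -> 2L)
      }

      def find(alias: String): Long =
        entries.getOrElse(alias, sys.error(s"No coster registered under '$alias'"))
    }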
@@ -16,9 +16,9 @@
  */
 package org.apache.gluten.extension.columnar.enumerated
 
+import org.apache.gluten.extension.caller.CallerInfo
 import org.apache.gluten.extension.columnar.{ColumnarRuleApplier, ColumnarRuleExecutor}
 import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
-import org.apache.gluten.extension.util.AdaptiveContext
 import org.apache.gluten.logging.LogLevelUtil
 
 import org.apache.spark.internal.Logging

@@ -39,27 +39,14 @@ class EnumeratedApplier(
   extends ColumnarRuleApplier
   with Logging
   with LogLevelUtil {
-  private val adaptiveContext = AdaptiveContext(session)
-
   override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = {
-    val call = new ColumnarRuleCall(session, adaptiveContext, outputsColumnar)
-    val finalPlan = maybeAqe {
-      apply0(ruleBuilders.map(b => b(call)), plan)
-    }
+    val call = new ColumnarRuleCall(session, CallerInfo.create(), outputsColumnar)
+    val finalPlan = apply0(ruleBuilders.map(b => b(call)), plan)
     finalPlan
   }
 
   private def apply0(rules: Seq[Rule[SparkPlan]], plan: SparkPlan): SparkPlan =
     new ColumnarRuleExecutor("ras", rules).execute(plan)
 
-  private def maybeAqe[T](f: => T): T = {
-    adaptiveContext.setAdaptiveContext()
-    try {
-      f
-    } finally {
-      adaptiveContext.resetAdaptiveContext()
-    }
-  }
 }
 
 object EnumeratedApplier {}
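The setAdaptiveContext()/resetAdaptiveContext() bracketing is gone because CallerInfo.create() derives the same information on demand by walking the current thread's stack. A self-contained sketch of that detection trick (the marker class here is hypothetical):

    object StackProbe {
      // True if the named class appears anywhere on the current call stack -
      // the same check CallerInfo.inAqeCall performs for AdaptiveSparkPlanExec.
      def calledFrom(className: String): Boolean =
        Thread.currentThread.getStackTrace.exists(_.getClassName == className)
    }

    class Marker {
      def run(): Boolean = StackProbe.calledFrom(classOf[Marker].getName)
    }

    // new Marker().run() returns true, since Marker.run is on the stack;
    // calling StackProbe.calledFrom(classOf[Marker].getName) directly
    // from elsewhere returns false.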
@@ -18,12 +18,12 @@ package org.apache.gluten.extension.columnar.enumerated
 
 import org.apache.gluten.component.Component
 import org.apache.gluten.exception.GlutenException
+import org.apache.gluten.extension.caller.CallerInfo
 import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
 import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel}
 import org.apache.gluten.extension.columnar.enumerated.planner.GlutenOptimization
 import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
 import org.apache.gluten.extension.injector.Injector
-import org.apache.gluten.extension.util.AdaptiveContext
 import org.apache.gluten.logging.LogLevelUtil
 import org.apache.gluten.ras.{Cost, CostModel}
 import org.apache.gluten.ras.property.PropertySet

@@ -81,7 +81,7 @@ object EnumeratedTransform {
     val session = SparkSession.getActiveSession.getOrElse(
       throw new GlutenException(
         "HeuristicTransform#static can only be called when an active Spark session exists"))
-    val call = new ColumnarRuleCall(session, AdaptiveContext(session), false)
+    val call = new ColumnarRuleCall(session, CallerInfo.create(), false)
     dummyInjector.gluten.ras.createEnumeratedTransform(call)
   }
 
@@ -16,9 +16,9 @@
  */
 package org.apache.gluten.extension.columnar.heuristic
 
+import org.apache.gluten.extension.caller.CallerInfo
 import org.apache.gluten.extension.columnar.{ColumnarRuleApplier, ColumnarRuleExecutor}
 import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
-import org.apache.gluten.extension.util.AdaptiveContext
 import org.apache.gluten.logging.LogLevelUtil
 
 import org.apache.spark.internal.Logging
@@ -33,35 +33,33 @@ import org.apache.spark.sql.execution.SparkPlan
 class HeuristicApplier(
     session: SparkSession,
     transformBuilders: Seq[ColumnarRuleCall => Rule[SparkPlan]],
-    fallbackPolicyBuilders: Seq[ColumnarRuleCall => Rule[SparkPlan]],
+    fallbackPolicyBuilders: Seq[ColumnarRuleCall => SparkPlan => Rule[SparkPlan]],
     postBuilders: Seq[ColumnarRuleCall => Rule[SparkPlan]],
     finalBuilders: Seq[ColumnarRuleCall => Rule[SparkPlan]])
   extends ColumnarRuleApplier
   with Logging
   with LogLevelUtil {
-  private val adaptiveContext = AdaptiveContext(session)
 
   override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = {
-    val call = new ColumnarRuleCall(session, adaptiveContext, outputsColumnar)
+    val call = new ColumnarRuleCall(session, CallerInfo.create(), outputsColumnar)
     makeRule(call).apply(plan)
   }
 
   private def makeRule(call: ColumnarRuleCall): Rule[SparkPlan] = {
-    plan =>
-      prepareFallback(plan) {
-        p =>
-          val suggestedPlan = transformPlan("transform", transformRules(call), p)
-          val finalPlan = transformPlan("fallback", fallbackPolicies(call), suggestedPlan) match {
-            case FallbackNode(fallbackPlan) =>
-              // we should use vanilla c2r rather than native c2r,
-              // and there should be no `GlutenPlan` any more,
-              // so skip the `postRules()`.
-              fallbackPlan
-            case plan =>
-              transformPlan("post", postRules(call), plan)
-          }
-          transformPlan("final", finalRules(call), finalPlan)
+    originalPlan =>
+      val suggestedPlan = transformPlan("transform", transformRules(call), originalPlan)
+      val finalPlan = transformPlan(
+        "fallback",
+        fallbackPolicies(call).map(_(originalPlan)),
+        suggestedPlan) match {
+        case FallbackNode(fallbackPlan) =>
+          // we should use vanilla c2r rather than native c2r,
+          // and there should be no `GlutenPlan` anymore,
+          // so skip the `postRules()`.
+          fallbackPlan
+        case plan =>
+          transformPlan("post", postRules(call), plan)
+      }
+      transformPlan("final", finalRules(call), finalPlan)
   }
 
   private def transformPlan(
@@ -70,17 +68,6 @@ class HeuristicApplier(
       plan: SparkPlan): SparkPlan =
     new ColumnarRuleExecutor(phase, rules).execute(plan)
 
-  private def prepareFallback[T](p: SparkPlan)(f: SparkPlan => T): T = {
-    adaptiveContext.setAdaptiveContext()
-    adaptiveContext.setOriginalPlan(p)
-    try {
-      f(p)
-    } finally {
-      adaptiveContext.resetOriginalPlan()
-      adaptiveContext.resetAdaptiveContext()
-    }
-  }
-
   /**
    * Rules to let planner create a suggested Gluten plan being sent to `fallbackPolicies` in which
    * the plan will be breakdown and decided to be fallen back or not.
@@ -93,7 +80,7 @@
    * Rules to add wrapper `FallbackNode`s on top of the input plan, as hints to make planner fall
    * back the whole input plan to the original vanilla Spark plan.
    */
-  private def fallbackPolicies(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = {
+  private def fallbackPolicies(call: ColumnarRuleCall): Seq[SparkPlan => Rule[SparkPlan]] = {
     fallbackPolicyBuilders.map(b => b.apply(call))
   }
 
@@ -112,12 +99,6 @@
   private def finalRules(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = {
     finalBuilders.map(b => b.apply(call))
   }
-
-  // Just for test use.
-  def enableAdaptiveContext(): HeuristicApplier = {
-    adaptiveContext.enableAdaptiveContext()
-    this
-  }
 }
 
 object HeuristicApplier {}
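Condensing the new makeRule flow for reference — each policy builder is specialized with the original plan before judging the transformed one. This is a paraphrase of the method above with the rule-executor plumbing elided, assuming Rule, SparkPlan, and FallbackNode are in scope as in that file:

    def runPipeline(
        originalPlan: SparkPlan,
        transform: SparkPlan => SparkPlan,
        policies: Seq[SparkPlan => Rule[SparkPlan]],
        post: SparkPlan => SparkPlan,
        last: SparkPlan => SparkPlan): SparkPlan = {
      val suggested = transform(originalPlan)
      val judged = policies.map(p => p(originalPlan)).foldLeft(suggested) {
        (plan, rule) => rule.apply(plan)
      }
      val resolved = judged match {
        case FallbackNode(fallbackPlan) => fallbackPlan // vanilla plan; skip post rules
        case other => post(other)
      }
      last(resolved)
    }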
@@ -18,12 +18,12 @@ package org.apache.gluten.extension.columnar.heuristic
 
 import org.apache.gluten.component.Component
 import org.apache.gluten.exception.GlutenException
+import org.apache.gluten.extension.caller.CallerInfo
 import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
 import org.apache.gluten.extension.columnar.offload.OffloadSingleNode
 import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
 import org.apache.gluten.extension.columnar.validator.Validator
 import org.apache.gluten.extension.injector.Injector
-import org.apache.gluten.extension.util.AdaptiveContext
 import org.apache.gluten.logging.LogLevelUtil
 
 import org.apache.spark.internal.Logging

@@ -126,7 +126,7 @@ object HeuristicTransform {
     val session = SparkSession.getActiveSession.getOrElse(
       throw new GlutenException(
         "HeuristicTransform#static can only be called when an active Spark session exists"))
-    val call = new ColumnarRuleCall(session, AdaptiveContext(session), false)
+    val call = new ColumnarRuleCall(session, CallerInfo.create(), false)
     dummyInjector.gluten.legacy.createHeuristicTransform(call)
   }
 }
@@ -57,7 +57,8 @@ object GlutenInjector {
     private val preTransformBuilders = mutable.Buffer.empty[ColumnarRuleCall => Rule[SparkPlan]]
     private val transformBuilders = mutable.Buffer.empty[ColumnarRuleCall => Rule[SparkPlan]]
     private val postTransformBuilders = mutable.Buffer.empty[ColumnarRuleCall => Rule[SparkPlan]]
-    private val fallbackPolicyBuilders = mutable.Buffer.empty[ColumnarRuleCall => Rule[SparkPlan]]
+    private val fallbackPolicyBuilders =
+      mutable.Buffer.empty[ColumnarRuleCall => SparkPlan => Rule[SparkPlan]]
     private val postBuilders = mutable.Buffer.empty[ColumnarRuleCall => Rule[SparkPlan]]
     private val finalBuilders = mutable.Buffer.empty[ColumnarRuleCall => Rule[SparkPlan]]

@@ -73,7 +74,7 @@
       postTransformBuilders += builder
     }
 
-    def injectFallbackPolicy(builder: ColumnarRuleCall => Rule[SparkPlan]): Unit = {
+    def injectFallbackPolicy(builder: ColumnarRuleCall => SparkPlan => Rule[SparkPlan]): Unit = {
       fallbackPolicyBuilders += builder
    }
}
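The curried builder is applied in two stages: once with the ColumnarRuleCall when the rule pipeline is assembled, and once per run with the original plan inside HeuristicApplier.makeRule. Spelled out as a sketch, assuming call, originalPlan, and suggestedPlan are in scope:

    // Stage the builder exactly as the injector and applier do.
    val builder: ColumnarRuleCall => SparkPlan => Rule[SparkPlan] =
      c => p => ExpandFallbackPolicy(c.caller.isAqe(), p)

    val perCall: SparkPlan => Rule[SparkPlan] = builder(call) // pipeline build time
    val rule: Rule[SparkPlan] = perCall(originalPlan)         // rule-run time
    val judged: SparkPlan = rule.apply(suggestedPlan)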

Expand Down
Loading

0 comments on commit 9783a88

Please sign in to comment.