Skip to content

Commit

Permalink
[CORE][VL] Cost model code refactors (#8541)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Jan 16, 2025
1 parent 67ebbc8 commit 1520458
Show file tree
Hide file tree
Showing 26 changed files with 222 additions and 161 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.execution.WriteFilesExecTransformer
import org.apache.gluten.expression.WindowFunctionsBuilder
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.extension.columnar.cost.{LegacyCoster, LongCoster}
import org.apache.gluten.extension.columnar.transition.{Convention, ConventionFunc}
import org.apache.gluten.substrait.rel.LocalFilesNode
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
Expand Down Expand Up @@ -54,7 +55,6 @@ class CHBackend extends SubstraitBackend {
override def name(): String = CHConf.BACKEND_NAME
override def buildInfo(): BuildInfo =
BuildInfo("ClickHouse", CH_BRANCH, CH_COMMIT, "UNKNOWN")
override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
override def iteratorApi(): IteratorApi = new CHIteratorApi
override def sparkPlanExecApi(): SparkPlanExecApi = new CHSparkPlanExecApi
override def transformerApi(): TransformerApi = new CHTransformerApi
Expand All @@ -63,6 +63,8 @@ class CHBackend extends SubstraitBackend {
override def listenerApi(): ListenerApi = new CHListenerApi
override def ruleApi(): RuleApi = new CHRuleApi
override def settings(): BackendSettingsApi = CHBackendSettings
override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
override def costers(): Seq[LongCoster] = Seq(LegacyCoster)
}

object CHBackend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.extension._
import org.apache.gluten.extension.columnar._
import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
import org.apache.gluten.extension.columnar.enumerated.planner.cost.LegacyCoster
import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy, HeuristicTransform}
import org.apache.gluten.extension.columnar.offload.{OffloadExchange, OffloadJoin, OffloadOthers}
import org.apache.gluten.extension.columnar.rewrite._
Expand Down Expand Up @@ -143,9 +142,6 @@ object CHRuleApi {
}

private def injectRas(injector: RasInjector): Unit = {
// Register legacy coster for transition planner.
injector.injectCoster(_ => LegacyCoster)

// CH backend doesn't work with RAS at the moment. Inject a rule that aborts any
// execution calls.
injector.injectPreTransform(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution.WriteFilesExecTransformer
import org.apache.gluten.expression.WindowFunctionsBuilder
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.extension.columnar.cost.{LegacyCoster, LongCoster, RoughCoster}
import org.apache.gluten.extension.columnar.transition.{Convention, ConventionFunc}
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.rel.LocalFilesNode
Expand Down Expand Up @@ -61,7 +62,6 @@ class VeloxBackend extends SubstraitBackend {
override def name(): String = VeloxBackend.BACKEND_NAME
override def buildInfo(): BuildInfo =
BuildInfo("Velox", VELOX_BRANCH, VELOX_REVISION, VELOX_REVISION_TIME)
override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
override def iteratorApi(): IteratorApi = new VeloxIteratorApi
override def sparkPlanExecApi(): SparkPlanExecApi = new VeloxSparkPlanExecApi
override def transformerApi(): TransformerApi = new VeloxTransformerApi
Expand All @@ -70,6 +70,8 @@ class VeloxBackend extends SubstraitBackend {
override def listenerApi(): ListenerApi = new VeloxListenerApi
override def ruleApi(): RuleApi = new VeloxRuleApi
override def settings(): BackendSettingsApi = VeloxBackendSettings
override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
override def costers(): Seq[LongCoster] = Seq(LegacyCoster, RoughCoster)
}

object VeloxBackend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import org.apache.gluten.extension._
import org.apache.gluten.extension.columnar._
import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
import org.apache.gluten.extension.columnar.enumerated.{RasOffload, RemoveSort}
import org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster, RoughCoster}
import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy, HeuristicTransform}
import org.apache.gluten.extension.columnar.offload.{OffloadExchange, OffloadJoin, OffloadOthers}
import org.apache.gluten.extension.columnar.rewrite._
Expand Down Expand Up @@ -120,10 +119,6 @@ object VeloxRuleApi {
}

private def injectRas(injector: RasInjector): Unit = {
// Gluten RAS: Costers.
injector.injectCoster(_ => LegacyCoster)
injector.injectCoster(_ => RoughCoster)

// Gluten RAS: Pre rules.
injector.injectPreTransform(_ => RemoveTransitions)
injector.injectPreTransform(_ => PushDownInputFileExpression.PreOffload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters

class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {

protected val rootPath: String = getClass.getResource("/").getPath
override protected val resourcePath: String = "/tpch-data-parquet"
override protected val fileFormat: String = "parquet"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
package org.apache.gluten.extension.columnar.enumerated.planner

import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel, LegacyCoster, LongCostModel}
import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform
import org.apache.gluten.extension.columnar.enumerated.planner.cost.{GlutenCostModel, LegacyCoster, LongCostModel}
import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq}
import org.apache.gluten.ras.{Cost, Ras}
import org.apache.gluten.ras.Ras
import org.apache.gluten.ras.RasSuiteBase._
import org.apache.gluten.ras.path.RasPath
import org.apache.gluten.ras.property.PropertySet
Expand Down Expand Up @@ -152,7 +152,7 @@ object VeloxRasSuite {
def newRas(rasRules: Seq[RasRule[SparkPlan]]): Ras[SparkPlan] = {
GlutenOptimization
.builder()
.costModel(sessionCostModel())
.costModel(EnumeratedTransform.asRasCostModel(sessionCostModel()))
.addRules(rasRules)
.create()
.asInstanceOf[Ras[SparkPlan]]
Expand Down Expand Up @@ -205,27 +205,27 @@ object VeloxRasSuite {

class UserCostModel1 extends GlutenCostModel {
private val base = legacyCostModel()
override def costOf(node: SparkPlan): Cost = node match {
override def costOf(node: SparkPlan): GlutenCost = node match {
case _: RowUnary => base.makeInfCost()
case other => base.costOf(other)
}
override def costComparator(): Ordering[Cost] = base.costComparator()
override def makeInfCost(): Cost = base.makeInfCost()
override def sum(one: Cost, other: Cost): Cost = base.sum(one, other)
override def diff(one: Cost, other: Cost): Cost = base.diff(one, other)
override def makeZeroCost(): Cost = base.makeZeroCost()
override def costComparator(): Ordering[GlutenCost] = base.costComparator()
override def makeInfCost(): GlutenCost = base.makeInfCost()
override def sum(one: GlutenCost, other: GlutenCost): GlutenCost = base.sum(one, other)
override def diff(one: GlutenCost, other: GlutenCost): GlutenCost = base.diff(one, other)
override def makeZeroCost(): GlutenCost = base.makeZeroCost()
}

class UserCostModel2 extends GlutenCostModel {
private val base = legacyCostModel()
override def costOf(node: SparkPlan): Cost = node match {
override def costOf(node: SparkPlan): GlutenCost = node match {
case _: ColumnarUnary => base.makeInfCost()
case other => base.costOf(other)
}
override def costComparator(): Ordering[Cost] = base.costComparator()
override def makeInfCost(): Cost = base.makeInfCost()
override def sum(one: Cost, other: Cost): Cost = base.sum(one, other)
override def diff(one: Cost, other: Cost): Cost = base.diff(one, other)
override def makeZeroCost(): Cost = base.makeZeroCost()
override def costComparator(): Ordering[GlutenCost] = base.costComparator()
override def makeInfCost(): GlutenCost = base.makeInfCost()
override def sum(one: GlutenCost, other: GlutenCost): GlutenCost = base.sum(one, other)
override def diff(one: GlutenCost, other: GlutenCost): GlutenCost = base.diff(one, other)
override def makeZeroCost(): GlutenCost = base.makeZeroCost()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.component

import org.apache.gluten.extension.columnar.cost.LongCoster
import org.apache.gluten.extension.columnar.transition.ConventionFunc
import org.apache.gluten.extension.injector.Injector

Expand Down Expand Up @@ -69,6 +70,12 @@ trait Component {
*/
def convFuncOverride(): ConventionFunc.Override = ConventionFunc.Override.Empty

/**
* A sequence of [[org.apache.gluten.extension.columnar.cost.LongCoster]] Gluten is using for cost
* evaluation.
*/
def costers(): Seq[LongCoster] = Nil

/** Query planner rules. */
def injectRules(injector: Injector): Unit
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar.enumerated.planner.cost
package org.apache.gluten.extension.columnar.cost

import org.apache.gluten.ras.{Cost, CostModel}

import org.apache.spark.sql.execution.SparkPlan

trait GlutenCostModel extends CostModel[SparkPlan] {
// Returns cost value of one + other.
def sum(one: Cost, other: Cost): Cost
// Returns cost value of one - other.
def diff(one: Cost, other: Cost): Cost

def makeZeroCost(): Cost
}
trait GlutenCost
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar.cost

import org.apache.gluten.component.Component

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.util.SparkReflectionUtil

/**
* The cost model API of Gluten. Used by:
* 1. RAS planner for cost-based optimization; 2. Transition graph for choosing transition paths.
*/
trait GlutenCostModel {
def costOf(node: SparkPlan): GlutenCost
def costComparator(): Ordering[GlutenCost]
def makeZeroCost(): GlutenCost
def makeInfCost(): GlutenCost
// Returns cost value of one + other.
def sum(one: GlutenCost, other: GlutenCost): GlutenCost
// Returns cost value of one - other.
def diff(one: GlutenCost, other: GlutenCost): GlutenCost
}

object GlutenCostModel extends Logging {
def find(aliasOrClass: String): GlutenCostModel = {
val costModelRegistry = LongCostModel.registry()
// Components should override Backend's costers. Hence, reversed registration order is applied.
Component
.sorted()
.reverse
.flatMap(_.costers())
.foreach(coster => costModelRegistry.register(coster))
val costModel = find(costModelRegistry, aliasOrClass)
costModel
}

private def find(registry: LongCostModel.Registry, aliasOrClass: String): GlutenCostModel = {
if (LongCostModel.Kind.values().contains(aliasOrClass)) {
val kind = LongCostModel.Kind.values()(aliasOrClass)
val model = registry.get(kind)
return model
}
val clazz = SparkReflectionUtil.classForName(aliasOrClass)
logInfo(s"Using user cost model: $aliasOrClass")
val ctor = clazz.getDeclaredConstructor()
ctor.setAccessible(true)
val model: GlutenCostModel = ctor.newInstance().asInstanceOf[GlutenCostModel]
model
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar.enumerated.planner.cost
package org.apache.gluten.extension.columnar.cost

import org.apache.gluten.ras.Cost

case class LongCost(value: Long) extends Cost
case class LongCost(value: Long) extends GlutenCost
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar.enumerated.planner.cost
package org.apache.gluten.extension.columnar.cost

import org.apache.gluten.exception.GlutenException
import org.apache.gluten.extension.columnar.enumerated.planner.plan.GlutenPlanModel.GroupLeafExec
import org.apache.gluten.ras.Cost

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
Expand All @@ -39,15 +38,15 @@ abstract class LongCostModel extends GlutenCostModel {
assert(a >= 0)
assert(b >= 0)
val sum = a + b
if (sum < a || sum < b) Long.MaxValue else sum
if (sum < a || sum < b) infLongCost else sum
}

override def sum(one: Cost, other: Cost): LongCost = (one, other) match {
override def sum(one: GlutenCost, other: GlutenCost): LongCost = (one, other) match {
case (LongCost(value), LongCost(otherValue)) => LongCost(safeSum(value, otherValue))
}

// Returns cost value of one - other.
override def diff(one: Cost, other: Cost): Cost = (one, other) match {
override def diff(one: GlutenCost, other: GlutenCost): GlutenCost = (one, other) match {
case (LongCost(value), LongCost(otherValue)) =>
val d = Math.subtractExact(value, otherValue)
require(d >= zeroLongCost, s"Difference between cost $one and $other should not be negative")
Expand All @@ -62,13 +61,13 @@ abstract class LongCostModel extends GlutenCostModel {

def selfLongCostOf(node: SparkPlan): Long

override def costComparator(): Ordering[Cost] = Ordering.Long.on {
override def costComparator(): Ordering[GlutenCost] = Ordering.Long.on {
case LongCost(value) => value
case _ => throw new IllegalStateException("Unexpected cost type")
}

override def makeInfCost(): Cost = LongCost(infLongCost)
override def makeZeroCost(): Cost = LongCost(zeroLongCost)
override def makeInfCost(): GlutenCost = LongCost(infLongCost)
override def makeZeroCost(): GlutenCost = LongCost(zeroLongCost)
}

object LongCostModel extends Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar.enumerated.planner.cost
package org.apache.gluten.extension.columnar.cost

import org.apache.spark.sql.execution.SparkPlan

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar.enumerated.planner.cost
package org.apache.gluten.extension.columnar.cost

import org.apache.gluten.exception.GlutenException

import org.apache.spark.sql.execution.SparkPlan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ package org.apache.gluten.extension.columnar.enumerated
import org.apache.gluten.component.Component
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel}
import org.apache.gluten.extension.columnar.enumerated.planner.GlutenOptimization
import org.apache.gluten.extension.columnar.enumerated.planner.cost.GlutenCostModel
import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
import org.apache.gluten.extension.injector.Injector
import org.apache.gluten.extension.util.AdaptiveContext
import org.apache.gluten.logging.LogLevelUtil
import org.apache.gluten.ras.{Cost, CostModel}
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.rule.RasRule

Expand All @@ -47,11 +48,12 @@ import org.apache.spark.sql.execution._
case class EnumeratedTransform(costModel: GlutenCostModel, rules: Seq[RasRule[SparkPlan]])
extends Rule[SparkPlan]
with LogLevelUtil {
import EnumeratedTransform._

private val optimization = {
GlutenOptimization
.builder()
.costModel(costModel)
.costModel(asRasCostModel(costModel))
.addRules(rules)
.create()
}
Expand Down Expand Up @@ -82,4 +84,18 @@ object EnumeratedTransform {
val call = new ColumnarRuleCall(session, AdaptiveContext(session), false)
dummyInjector.gluten.ras.createEnumeratedTransform(call)
}

def asRasCostModel(gcm: GlutenCostModel): CostModel[SparkPlan] = {
new CostModelAdapter(gcm)
}

/** The adapter to make GlutenCostModel comply with RAS cost model. */
private class CostModelAdapter(gcm: GlutenCostModel) extends CostModel[SparkPlan] {
override def costOf(node: SparkPlan): Cost = CostAdapter(gcm.costOf(node))
override def costComparator(): Ordering[Cost] =
gcm.costComparator().on[Cost] { case CostAdapter(gc) => gc }
override def makeInfCost(): Cost = CostAdapter(gcm.makeInfCost())
}

private case class CostAdapter(gc: GlutenCost) extends Cost
}
Loading

0 comments on commit 1520458

Please sign in to comment.