[CORE] Rename TransformHint to FallbackTag #6254

Merged
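
This PR renames the tagging utility that marks SparkPlan nodes Gluten cannot offload: TransformHint becomes FallbackTag, the TransformHints object becomes FallbackTags, and AddTransformHintRule becomes AddFallbackTagRule. The new names say what the tag records (a fallback and its reason) rather than describing transform-time mechanics. Every hunk below is a mechanical call-site rename; behavior is unchanged. Renames visible in this diff:

    TransformHints.tagNotTransformable(plan, reason)  ->  FallbackTags.add(plan, reason)
    TransformHints.isNotTransformable(plan)           ->  FallbackTags.nonEmpty(plan)
    TransformHints.tagAllNotTransformable(plan, tag)  ->  FallbackTags.addRecursively(plan, tag)
    TransformHints.getHint(plan)                      ->  FallbackTags.getTag(plan)
    TransformHints.EncodeTransformableTagImplicits    ->  FallbackTags.EncodeFallbackTagImplicits

For orientation, here is a minimal sketch of the renamed API surface as it can be inferred from the call sites below. The signatures and the FallbackTag type name are assumptions (the latter following the PR title), not the real implementation in gluten-core:

    // Sketch only; bodies elided. SparkPlan and ValidationResult are the
    // Spark/Gluten types used at the call sites in this diff.
    object FallbackTags {
      // Tag `plan` so later rules keep it on vanilla Spark.
      def add(plan: SparkPlan, reason: String): Unit = ???
      // Overload used where a ValidationResult is already at hand.
      def add(plan: SparkPlan, validationResult: ValidationResult): Unit = ???
      // Tag `plan` and every descendant (whole-stage fallback).
      def addRecursively(plan: SparkPlan, tag: FallbackTag): Unit = ???
      // True if `plan` already carries a fallback tag.
      def nonEmpty(plan: SparkPlan): Boolean = ???
      // Raw tag access, used to propagate tags between plan trees.
      def getTag(plan: SparkPlan): FallbackTag = ???
      def tag(plan: SparkPlan, t: FallbackTag): Unit = ???
    }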
==== Changes in CHSparkPlanExecApi ====
@@ -22,7 +22,7 @@ import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution._
 import org.apache.gluten.expression._
 import org.apache.gluten.extension.{CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteToDateExpresstionRule}
-import org.apache.gluten.extension.columnar.AddTransformHintRule
+import org.apache.gluten.extension.columnar.AddFallbackTagRule
 import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
 import org.apache.gluten.extension.columnar.transition.Convention
 import org.apache.gluten.sql.shims.SparkShimLoader

@@ -146,7 +146,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
 
     child match {
       case scan: FileSourceScanExec if (checkMergeTreeFileFormat(scan.relation)) =>
-        // For the validation phase of the AddTransformHintRule
+        // For the validation phase of the AddFallbackTagRule
         CHFilterExecTransformer(condition, child)
       case scan: FileSourceScanExecTransformerBase if (checkMergeTreeFileFormat(scan.relation)) =>
         // For the transform phase, the FileSourceScanExec is already transformed

@@ -226,7 +226,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
     // FIXME: The operation happens inside ReplaceSingleNode().
     // Caller may not know it adds project on top of the shuffle.
     val project = TransformPreOverrides().apply(
-      AddTransformHintRule().apply(
+      AddFallbackTagRule().apply(
         ProjectExec(plan.child.output ++ projectExpressions, plan.child)))
     var newExprs = Seq[Expression]()
     for (i <- exprs.indices) {

@@ -251,7 +251,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
     // FIXME: The operation happens inside ReplaceSingleNode().
     // Caller may not know it adds project on top of the shuffle.
     val project = TransformPreOverrides().apply(
-      AddTransformHintRule().apply(
+      AddFallbackTagRule().apply(
         ProjectExec(plan.child.output ++ projectExpressions, plan.child)))
     var newOrderings = Seq[SortOrder]()
     for (i <- orderings.indices) {
==== Changes in FallbackBroadcastHashJoinPrepQueryStage / FallbackBroadcastHashJoin ====
@@ -19,7 +19,7 @@ package org.apache.gluten.extension
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.extension.columnar._
-import org.apache.gluten.extension.columnar.TransformHints.EncodeTransformableTagImplicits
+import org.apache.gluten.extension.columnar.FallbackTags.EncodeFallbackTagImplicits
 import org.apache.gluten.utils.PhysicalPlanSelector
 
 import org.apache.spark.sql.SparkSession

@@ -61,7 +61,7 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
               "columnar broadcast exchange is disabled or " +
               "columnar broadcast join is disabled")
           } else {
-            if (TransformHints.isNotTransformable(bhj)) {
+            if (FallbackTags.nonEmpty(bhj)) {
               ValidationResult.notOk("broadcast join is already tagged as not transformable")
             } else {
               val bhjTransformer = BackendsApiManager.getSparkPlanExecApiInstance

@@ -83,8 +83,8 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
              }
            }
          }
-          TransformHints.tagNotTransformable(bhj, isTransformable)
-          TransformHints.tagNotTransformable(exchange, isTransformable)
+          FallbackTags.add(bhj, isTransformable)
+          FallbackTags.add(exchange, isTransformable)
         case _ =>
           // Skip. This might be the case that the exchange was already
           // executed in earlier stage

@@ -116,7 +116,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl
         // Currently their doBroadcast() methods just propagate child's broadcast
         // payloads which is not right in speaking of columnar.
         if (!enableColumnarBroadcastJoin) {
-          TransformHints.tagNotTransformable(
+          FallbackTags.add(
             bhj,
             "columnar BroadcastJoin is not enabled in BroadcastHashJoinExec")
         } else {

@@ -149,7 +149,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl
           case Some(exchange @ BroadcastExchangeExec(mode, child)) =>
             isBhjTransformable.tagOnFallback(bhj)
             if (!isBhjTransformable.isValid) {
-              TransformHints.tagNotTransformable(exchange, isBhjTransformable)
+              FallbackTags.add(exchange, isBhjTransformable)
             }
           case None =>
             // we are in AQE, find the hidden exchange

@@ -182,7 +182,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl
             // to conform to the underlying exchange's type, columnar or vanilla
             exchange match {
               case BroadcastExchangeExec(mode, child) =>
-                TransformHints.tagNotTransformable(
+                FallbackTags.add(
                   bhj,
                   "it's a materialized broadcast exchange or reused broadcast exchange")
               case ColumnarBroadcastExchangeExec(mode, child) =>

@@ -199,7 +199,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl
        }
      } catch {
        case e: UnsupportedOperationException =>
-          TransformHints.tagNotTransformable(
+          FallbackTags.add(
           p,
           s"${e.getMessage}, original Spark plan is " +
             s"${p.getClass}(${p.children.toList.map(_.getClass)})")
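
Across these hunks, the reason passed to FallbackTags.add takes two shapes: a plain String (as in "columnar BroadcastJoin is not enabled in BroadcastHashJoinExec") and a ValidationResult (as with isTransformable and isBhjTransformable); ScanTransformerFactory below also extracts a String via validationResult.reason.get. A minimal usage sketch, where `bhj` and `exchange` stand in for the SparkPlan nodes seen above:

    // Hedged sketch; both overloads appear at the call sites in this diff.
    FallbackTags.add(bhj, "columnar BroadcastJoin is not enabled in BroadcastHashJoinExec")
    FallbackTags.add(exchange, isBhjTransformable) // ValidationResult overload
    // Later rules check the tag before trying to offload the node again:
    if (FallbackTags.nonEmpty(bhj)) {
      // already tagged; bhj stays on vanilla Spark
    }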
==== Changes in VeloxSparkPlanExecApi ====
@@ -25,7 +25,7 @@ import org.apache.gluten.expression._
 import org.apache.gluten.expression.ExpressionNames.{TRANSFORM_KEYS, TRANSFORM_VALUES}
 import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
 import org.apache.gluten.extension._
-import org.apache.gluten.extension.columnar.TransformHints
+import org.apache.gluten.extension.columnar.FallbackTags
 import org.apache.gluten.extension.columnar.transition.Convention
 import org.apache.gluten.extension.columnar.transition.ConventionFunc.BatchOverride
 import org.apache.gluten.sql.shims.SparkShimLoader

@@ -371,7 +371,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
           val newChild = maybeAddAppendBatchesExec(projectTransformer)
           ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output.drop(1))
         } else {
-          TransformHints.tagNotTransformable(shuffle, validationResult)
+          FallbackTags.add(shuffle, validationResult)
           shuffle.withNewChildren(child :: Nil)
         }
       case RoundRobinPartitioning(num) if SQLConf.get.sortBeforeRepartition && num > 1 =>

@@ -397,7 +397,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
             projectTransformer
           } else {
             val project = ProjectExec(projectList, child)
-            TransformHints.tagNotTransformable(project, projectBeforeSortValidationResult)
+            FallbackTags.add(project, projectBeforeSortValidationResult)
             project
           }
         val sortOrder = SortOrder(projectBeforeSort.output.head, Ascending)

@@ -410,7 +410,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
           val newChild = maybeAddAppendBatchesExec(dropSortColumnTransformer)
           ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output)
         } else {
-          TransformHints.tagNotTransformable(shuffle, validationResult)
+          FallbackTags.add(shuffle, validationResult)
           shuffle.withNewChildren(child :: Nil)
         }
     }
==== Changes in ScanTransformerFactory ====
@@ -17,7 +17,7 @@
 package org.apache.gluten.execution
 
 import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.extension.columnar.TransformHints
+import org.apache.gluten.extension.columnar.FallbackTags
 import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.sql.catalyst.expressions.Expression

@@ -99,7 +99,7 @@ object ScanTransformerFactory {
         transformer
       } else {
         val newSource = batchScan.copy(runtimeFilters = transformer.runtimeFilters)
-        TransformHints.tagNotTransformable(newSource, validationResult.reason.get)
+        FallbackTags.add(newSource, validationResult.reason.get)
         newSource
       }
     } else {

@@ -109,7 +109,7 @@
       if (validation) {
         throw new GlutenNotSupportException(s"Unsupported scan ${batchScan.scan}")
       }
-      TransformHints.tagNotTransformable(batchScan, "The scan in BatchScanExec is not supported.")
+      FallbackTags.add(batchScan, "The scan in BatchScanExec is not supported.")
       batchScan
     }
   }
==== Changes in EnsureLocalSortRequirements ====
@@ -37,7 +37,7 @@ object EnsureLocalSortRequirements extends Rule[SparkPlan] {
       requiredOrdering: Seq[SortOrder]): SparkPlan = {
     val newChild = SortExec(requiredOrdering, global = false, child = originalChild)
     if (!GlutenConfig.getConf.enableColumnarSort) {
-      TransformHints.tagNotTransformable(newChild, "columnar Sort is not enabled in SortExec")
+      FallbackTags.add(newChild, "columnar Sort is not enabled in SortExec")
       newChild
     } else {
       val newChildWithTransformer =

@@ -50,7 +50,7 @@ object EnsureLocalSortRequirements extends Rule[SparkPlan] {
       if (validationResult.isValid) {
         newChildWithTransformer
       } else {
-        TransformHints.tagNotTransformable(newChild, validationResult)
+        FallbackTags.add(newChild, validationResult)
         newChild
       }
     }
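
EnsureLocalSortRequirements shows the validate-then-tag pattern at its smallest: build the transformer, keep it if validation passes, otherwise tag the vanilla operator with the ValidationResult and keep that instead. A condensed sketch (doValidate is an assumption; the diff itself only shows validationResult and isValid):

    // Condensed from the two hunks above; not a drop-in replacement.
    val validationResult = newChildWithTransformer.doValidate()
    if (validationResult.isValid) {
      newChildWithTransformer // offload the local sort
    } else {
      FallbackTags.add(newChild, validationResult) // record why we fell back
      newChild // keep the vanilla SortExec
    }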
==== Changes in ExpandFallbackPolicy ====
@@ -239,11 +239,11 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
     // Propagate fallback reason to vanilla SparkPlan
     glutenPlan.foreach {
       case _: GlutenPlan =>
-      case p: SparkPlan if TransformHints.isNotTransformable(p) && p.logicalLink.isDefined =>
+      case p: SparkPlan if FallbackTags.nonEmpty(p) && p.logicalLink.isDefined =>
         originalPlan
           .find(_.logicalLink.exists(_.fastEquals(p.logicalLink.get)))
-          .filterNot(TransformHints.isNotTransformable)
-          .foreach(origin => TransformHints.tag(origin, TransformHints.getHint(p)))
+          .filterNot(FallbackTags.nonEmpty)
+          .foreach(origin => FallbackTags.tag(origin, FallbackTags.getTag(p)))
       case _ =>
     }
 

@@ -278,7 +278,7 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
     ) {
       plan
     } else {
-      TransformHints.tagAllNotTransformable(
+      FallbackTags.addRecursively(
         vanillaSparkPlan,
         TRANSFORM_UNSUPPORTED(fallbackInfo.reason, appendReasonIfExists = false))
       FallbackNode(vanillaSparkPlan)
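
The first hunk copies a tag from a query-stage plan back to the matching node of the original plan through logical links; the second tags an entire subtree for whole-stage fallback. The single construction site suggests the tag payload is a small value object; a speculative sketch (field types and the FallbackTag supertype are assumptions inferred from the PR title and this one call):

    // Speculative: shape inferred only from
    // TRANSFORM_UNSUPPORTED(fallbackInfo.reason, appendReasonIfExists = false).
    case class TRANSFORM_UNSUPPORTED(
        reason: Option[String], // why the subtree falls back (assumed Option[String])
        appendReasonIfExists: Boolean = true // merge with a pre-existing reason?
    ) extends FallbackTag // assumed supertype, following the PR title's rename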