diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/CustomShuffleReaderExecParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/CustomShuffleReaderExecParser.scala
deleted file mode 100644
index b70258b5a..000000000
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/CustomShuffleReaderExecParser.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Copyright (c) 2022-2024, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.nvidia.spark.rapids.tool.planparser
-
-import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker
-
-import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
-
-case class CustomShuffleReaderExecParser(
-    override val node: SparkPlanGraphNode,
-    override val checker: PluginTypeChecker,
-    override val sqlID: Long) extends GenericExecParser(node, checker, sqlID) {
-
-  // note this is called either AQEShuffleRead and CustomShuffleReader depending
-  // on the Spark version, our supported ops list it as CustomShuffleReader
-  override val fullExecName = "CustomShuffleReaderExec"
-}
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/GenericExecParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/GenericExecParser.scala
index 5fb59630e..59c440fad 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/GenericExecParser.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/GenericExecParser.scala
@@ -25,11 +25,12 @@ class GenericExecParser(
     val node: SparkPlanGraphNode,
     val checker: PluginTypeChecker,
     val sqlID: Long,
+    val execName: Option[String] = None,
     val expressionFunction: Option[String => Array[String]] = None,
     val app: Option[AppBase] = None
 ) extends ExecParser {
 
-  val fullExecName: String = node.name + "Exec"
+  val fullExecName: String = execName.getOrElse(node.name + "Exec")
 
   override def parse: ExecInfo = {
     val duration = computeDuration
@@ -99,9 +100,11 @@ object GenericExecParser {
       node: SparkPlanGraphNode,
       checker: PluginTypeChecker,
       sqlID: Long,
+      execName: Option[String] = None,
       expressionFunction: Option[String => Array[String]] = None,
       app: Option[AppBase] = None
   ): GenericExecParser = {
-    new GenericExecParser(node, checker, sqlID, expressionFunction, app)
+    val fullExecName = execName.getOrElse(node.name + "Exec")
+    new GenericExecParser(node, checker, sqlID, Some(fullExecName), expressionFunction, app)
   }
 }
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala
index eae0564e3..daf92921c 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala
@@ -27,7 +27,7 @@ case class ObjectHashAggregateExecParser(
     override val checker: PluginTypeChecker,
     override val sqlID: Long,
     appParam: AppBase) extends
-    GenericExecParser(node, checker, sqlID) with Logging {
+    GenericExecParser(node, checker, sqlID, app = Some(appParam)) with Logging {
   override def computeDuration: Option[Long] = {
     val accumId = node.metrics.find(_.name == "time in aggregation build").map(_.accumulatorId)
     SQLPlanParser.getTotalDuration(accumId, appParam)
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
index fbd7c182e..44c72c86a 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
@@ -463,10 +463,12 @@
       app: AppBase): ExecInfo = {
     val normalizedNodeName = node.name.stripSuffix("$")
     normalizedNodeName match {
-      case "AggregateInPandas" =>
-        GenericExecParser(node, checker, sqlID).parse
-      case "ArrowEvalPython" =>
-        GenericExecParser(node, checker, sqlID).parse
+      // Generalize all the execs that call GenericExecParser into one case
+      case "AggregateInPandas" | "ArrowEvalPython" | "CartesianProduct" | "Coalesce"
+           | "CollectLimit" | "FlatMapGroupsInPandas" | "GlobalLimit" | "LocalLimit"
+           | "InMemoryTableScan" | "MapInPandas" | "PythonMapInArrow" | "MapInArrow" | "Range"
+           | "Sample" | "Union" | "WindowInPandas" =>
+        GenericExecParser(node, checker, sqlID, app = Some(app)).parse
       case "BatchScan" =>
         BatchScanExecParser(node, checker, sqlID, app).parse
       case "BroadcastExchange" =>
@@ -475,54 +477,41 @@
         BroadcastHashJoinExecParser(node, checker, sqlID).parse
       case "BroadcastNestedLoopJoin" =>
         BroadcastNestedLoopJoinExecParser(node, checker, sqlID).parse
-      case "CartesianProduct" =>
-        GenericExecParser(node, checker, sqlID).parse
-      case "Coalesce" =>
-        GenericExecParser(node, checker, sqlID).parse
-      case "CollectLimit" =>
-        GenericExecParser(node, checker, sqlID).parse
+      // This is called either AQEShuffleRead or CustomShuffleReader depending
+      // on the Spark version; our supported ops list has it as CustomShuffleReader
       case "CustomShuffleReader" | "AQEShuffleRead" =>
-        CustomShuffleReaderExecParser(node, checker, sqlID).parse
+        GenericExecParser(
+          node, checker, sqlID, execName = Some("CustomShuffleReaderExec")).parse
       case "Exchange" =>
         ShuffleExchangeExecParser(node, checker, sqlID, app).parse
       case "Expand" =>
-        GenericExecParser(node, checker, sqlID, Some(parseExpandExpressions)).parse
+        GenericExecParser(
+          node, checker, sqlID, expressionFunction = Some(parseExpandExpressions)).parse
       case "Filter" =>
-        GenericExecParser(node, checker, sqlID, Some(parseFilterExpressions)).parse
-      case "FlatMapGroupsInPandas" =>
-        GenericExecParser(node, checker, sqlID).parse
+        GenericExecParser(
+          node, checker, sqlID, expressionFunction = Some(parseFilterExpressions)).parse
       case "Generate" =>
-        GenericExecParser(node, checker, sqlID, Some(parseGenerateExpressions)).parse
-      case "GlobalLimit" =>
-        GenericExecParser(node, checker, sqlID).parse
+        GenericExecParser(
+          node, checker, sqlID, expressionFunction = Some(parseGenerateExpressions)).parse
       case "HashAggregate" =>
         HashAggregateExecParser(node, checker, sqlID, app).parse
-      case "LocalLimit" =>
-        GenericExecParser(node, checker, sqlID).parse
-      case "InMemoryTableScan" =>
-        GenericExecParser(node, checker, sqlID).parse
       case i if DataWritingCommandExecParser.isWritingCmdExec(i) =>
         DataWritingCommandExecParser.parseNode(node, checker, sqlID)
-      case "MapInPandas" =>
-        GenericExecParser(node, checker, sqlID).parse
       case "ObjectHashAggregate" =>
-        ObjectHashAggregateExecParser(node, checker, sqlID, app).parse
+        ObjectHashAggregateExecParser(node, checker, sqlID, appParam = app).parse
       case "Project" =>
-        GenericExecParser(node, checker, sqlID, Some(parseProjectExpressions)).parse
-      case "PythonMapInArrow" | "MapInArrow" =>
-        GenericExecParser(node, checker, sqlID).parse
-      case "Range" =>
-        GenericExecParser(node, checker, sqlID).parse
-      case "Sample" =>
-        GenericExecParser(node, checker, sqlID).parse
+        GenericExecParser(
+          node, checker, sqlID, expressionFunction = Some(parseProjectExpressions)).parse
       case "ShuffledHashJoin" =>
         ShuffledHashJoinExecParser(node, checker, sqlID, app).parse
       case "Sort" =>
-        GenericExecParser(node, checker, sqlID, Some(parseSortExpressions)).parse
+        GenericExecParser(
+          node, checker, sqlID, expressionFunction = Some(parseSortExpressions)).parse
       case s if ReadParser.isScanNode(s) =>
         FileSourceScanExecParser(node, checker, sqlID, app).parse
       case "SortAggregate" =>
-        GenericExecParser(node, checker, sqlID, Some(parseAggregateExpressions)).parse
+        GenericExecParser(
+          node, checker, sqlID, expressionFunction = Some(parseAggregateExpressions)).parse
       case smj if SortMergeJoinExecParser.accepts(smj) =>
         SortMergeJoinExecParser(node, checker, sqlID).parse
       case "SubqueryBroadcast" =>
@@ -530,13 +519,11 @@
       case sqe if SubqueryExecParser.accepts(sqe) =>
         SubqueryExecParser.parseNode(node, checker, sqlID, app)
       case "TakeOrderedAndProject" =>
-        GenericExecParser(node, checker, sqlID, Some(parseTakeOrderedExpressions)).parse
-      case "Union" =>
-        GenericExecParser(node, checker, sqlID).parse
+        GenericExecParser(
+          node, checker, sqlID, expressionFunction = Some(parseTakeOrderedExpressions)).parse
       case "Window" =>
-        GenericExecParser(node, checker, sqlID, Some(parseWindowExpressions)).parse
-      case "WindowInPandas" =>
-        GenericExecParser(node, checker, sqlID).parse
+        GenericExecParser(
+          node, checker, sqlID, expressionFunction = Some(parseWindowExpressions)).parse
       case "WindowGroupLimit" =>
         WindowGroupLimitParser(node, checker, sqlID).parse
       case wfe if WriteFilesExecParser.accepts(wfe) =>