diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala index 7e17767b9..9580aa470 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala @@ -18,7 +18,6 @@ package com.nvidia.spark.rapids.tool.analysis import scala.collection.mutable.{AbstractSet, ArrayBuffer, HashMap, LinkedHashSet} -import com.nvidia.spark.rapids.tool.planparser.SQLPlanParser import com.nvidia.spark.rapids.tool.profiling.{AccumProfileResults, SQLAccumProfileResults, SQLMetricInfoCase, SQLStageInfoProfileResult, UnsupportedSQLPlan, WholeStageCodeGenResults} import com.nvidia.spark.rapids.tool.qualification.QualSQLPlanAnalyzer @@ -88,7 +87,8 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap // Maps stages to operators by checking for non-zero intersection // between nodeMetrics and stageAccumulateIDs val nodeIdToStage = planGraph.allNodes.map { node => - val mappedStages = SQLPlanParser.getStagesInSQLNode(node, app) + val nodeAccums = node.metrics.map(_.accumulatorId) + val mappedStages = app.getStageIDsFromAccumIds(nodeAccums) ((sqlId, node.id), mappedStages) }.toMap sqlPlanNodeIdToStageIds ++= nodeIdToStage diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala index 1b8d63a5f..ad2634c41 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala @@ -26,7 +26,6 @@ import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlanInfo -//import org.apache.spark.sql.execution.joins.CartesianProductExec import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphNode} import org.apache.spark.sql.rapids.tool.{AppBase, BuildSide, ExecHelper, JoinType, RDDCheckHelper, ToolUtils, UnsupportedExpr} import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph @@ -433,22 +432,19 @@ object SQLPlanParser extends Logging { sqlDesc: String, checker: PluginTypeChecker, app: AppBase): PlanInfo = { - val planGraph = ToolsPlanGraph(planInfo) + val toolsGraph = ToolsPlanGraph.createGraphWithStageClusters(planInfo, app) + // Find all the node graphs that should be excluded and send it to the parsePlanNode - val excludedNodes = buildSkippedReusedNodesForPlan(planGraph) + val excludedNodes = buildSkippedReusedNodesForPlan(toolsGraph.sparkGraph) // we want the sub-graph nodes to be inside of the wholeStageCodeGen so use nodes // vs allNodes - val execInfos = planGraph.nodes.flatMap { node => - parsePlanNode(node, sqlID, checker, app, reusedNodeIds = excludedNodes) + val execInfos = toolsGraph.nodes.flatMap { node => + parsePlanNode(node, sqlID, checker, app, reusedNodeIds = excludedNodes, + nodeIdToStagesFunc = toolsGraph.getNodeStageClusters) } PlanInfo(appID, sqlID, sqlDesc, execInfos) } - def getStagesInSQLNode(node: SparkPlanGraphNode, app: AppBase): Set[Int] = { - val nodeAccums = node.metrics.map(_.accumulatorId) - nodeAccums.flatMap(app.accumManager.getAccStageIds).toSet - } - // Set containing execs that refers to other expressions. We need this to be a list to allow // appending more execs in teh future as necessary. // Note that Spark graph may create duplicate nodes when any of the following execs exists. @@ -541,7 +537,8 @@ object SQLPlanParser extends Logging { sqlID: Long, checker: PluginTypeChecker, app: AppBase, - reusedNodeIds: Set[Long] + reusedNodeIds: Set[Long], + nodeIdToStagesFunc: Long => Set[Int] ): Seq[ExecInfo] = { // Avoid counting duplicate nodes. We mark them as shouldRemove to neutralize their impact on // speedups. @@ -560,9 +557,11 @@ object SQLPlanParser extends Logging { // For WholeStageCodegen clusters, use PhotonStageExecParser if the cluster is of Photon type. // Else, fall back to WholeStageExecParser to parse the cluster. case photonCluster: PhotonSparkPlanGraphCluster => - PhotonStageExecParser(photonCluster, checker, sqlID, app, reusedNodeIds).parse + PhotonStageExecParser(photonCluster, checker, sqlID, app, reusedNodeIds, + nodeIdToStagesFunc = nodeIdToStagesFunc).parse case cluster: SparkPlanGraphCluster => - WholeStageExecParser(cluster, checker, sqlID, app, reusedNodeIds).parse + WholeStageExecParser(cluster, checker, sqlID, app, reusedNodeIds, + nodeIdToStagesFunc = nodeIdToStagesFunc).parse case _ => // For individual nodes, use PhotonPlanParser if the node is of Photon type. // Else, fall back to the Spark node parsing logic to parse the node. @@ -587,7 +586,7 @@ object SQLPlanParser extends Logging { ExecInfo(node, sqlID, normalizedNodeName, expr = "", 1, duration = None, node.id, isSupported = false, None) } - val stagesInNode = getStagesInSQLNode(node, app) + val stagesInNode = nodeIdToStagesFunc(node.id) execInfo.setStages(stagesInNode) // shouldRemove is set to true if the exec is a member of "execsToBeRemoved" or if the node // is a duplicate diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala index 19eb46a52..cdef3b225 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala @@ -27,7 +27,8 @@ abstract class WholeStageExecParserBase( checker: PluginTypeChecker, sqlID: Long, app: AppBase, - reusedNodeIds: Set[Long]) extends Logging { + reusedNodeIds: Set[Long], + nodeIdToStagesFunc: Long => Set[Int]) extends Logging { val fullExecName = "WholeStageCodegenExec" @@ -38,16 +39,15 @@ abstract class WholeStageExecParserBase( // Perhaps take the max of those in Stage? val accumId = node.metrics.find(_.name == "duration").map(_.accumulatorId) val maxDuration = SQLPlanParser.getTotalDuration(accumId, app) - val stagesInNode = SQLPlanParser.getStagesInSQLNode(node, app) + val stagesInNode = nodeIdToStagesFunc.apply(node.id) // We could skip the entire wholeStage if it is duplicate; but we will lose the information of // the children nodes. val isDupNode = reusedNodeIds.contains(node.id) val childNodes = node.nodes.flatMap { c => - SQLPlanParser.parsePlanNode(c, sqlID, checker, app, reusedNodeIds) + // Pass the nodeToStagesFunc to the child nodes so they can get the stages. + SQLPlanParser.parsePlanNode(c, sqlID, checker, app, reusedNodeIds, + nodeIdToStagesFunc = nodeIdToStagesFunc) } - // For the childNodes, we need to append the stages. Otherwise, nodes without metrics won't be - // assigned to stage - childNodes.foreach(_.appendToStages(stagesInNode)) // if any of the execs in WholeStageCodegen supported mark this entire thing as supported val anySupported = childNodes.exists(_.isSupported == true) val unSupportedExprsArray = @@ -55,14 +55,11 @@ abstract class WholeStageExecParserBase( // average speedup across the execs in the WholeStageCodegen for now val supportedChildren = childNodes.filterNot(_.shouldRemove) val avSpeedupFactor = SQLPlanParser.averageSpeedup(supportedChildren.map(_.speedupFactor)) - // can't rely on the wholeStagecodeGen having a stage if children do so aggregate them together - // for now - val allStagesIncludingChildren = childNodes.flatMap(_.stages).toSet ++ stagesInNode.toSet - // Finally, the node should be marked as shouldRemove when all the children of the + // The node should be marked as shouldRemove when all the children of the // wholeStageCodeGen are marked as shouldRemove. val removeNode = isDupNode || childNodes.forall(_.shouldRemove) val execInfo = ExecInfo(node, sqlID, node.name, node.name, avSpeedupFactor, maxDuration, - node.id, anySupported, Some(childNodes), allStagesIncludingChildren, + node.id, anySupported, Some(childNodes), stagesInNode, shouldRemove = removeNode, unsupportedExprs = unSupportedExprsArray) Seq(execInfo) } @@ -73,5 +70,6 @@ case class WholeStageExecParser( checker: PluginTypeChecker, sqlID: Long, app: AppBase, - reusedNodeIds: Set[Long]) - extends WholeStageExecParserBase(node, checker, sqlID, app, reusedNodeIds) + reusedNodeIds: Set[Long], + nodeIdToStagesFunc: Long => Set[Int]) + extends WholeStageExecParserBase(node, checker, sqlID, app, reusedNodeIds, nodeIdToStagesFunc) diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/photon/PhotonStageExecParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/photon/PhotonStageExecParser.scala index d49192d27..ceb59e712 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/photon/PhotonStageExecParser.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/photon/PhotonStageExecParser.scala @@ -35,5 +35,6 @@ case class PhotonStageExecParser( checker: PluginTypeChecker, sqlID: Long, app: AppBase, - reusedNodeIds: Set[Long]) - extends WholeStageExecParserBase(node, checker, sqlID, app, reusedNodeIds) + reusedNodeIds: Set[Long], + nodeIdToStagesFunc: Long => Set[Int]) + extends WholeStageExecParserBase(node, checker, sqlID, app, reusedNodeIds, nodeIdToStagesFunc) diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AccumToStageRetriever.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AccumToStageRetriever.scala new file mode 100644 index 000000000..7da118a0e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AccumToStageRetriever.scala @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.tool + +/** + * Trait that defines the interface for retrieving stage IDs from accumulables. + * This is used to map accumulables to stages. We use it as interface in order to allow to separate + * the logic and use dummy different implementations and mocks for testing when needed. + */ +trait AccumToStageRetriever { + /** + * Given a sequence of accumIds, return a set of stage IDs that are associated with the + * accumIds. Note that this method can only be called after the accumulables have been fully + * processed. + */ + def getStageIDsFromAccumIds(accumIds: Seq[Long]): Set[Int] +} diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala index e3313b832..f8d4d7703 100644 --- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala @@ -42,7 +42,9 @@ import org.apache.spark.util.Utils abstract class AppBase( val eventLogInfo: Option[EventLogInfo], - val hadoopConf: Option[Configuration]) extends Logging with ClusterTagPropHandler { + val hadoopConf: Option[Configuration]) extends Logging + with ClusterTagPropHandler + with AccumToStageRetriever { var appMetaData: Option[AppMetaData] = None @@ -105,6 +107,10 @@ abstract class AppBase( def sqlPlans: immutable.Map[Long, SparkPlanInfo] = sqlManager.getPlanInfos + def getStageIDsFromAccumIds(accumIds: Seq[Long]): Set[Int] = { + accumIds.flatMap(accumManager.getAccStageIds).toSet + } + // Returns the String value of the eventlog or empty if it is not defined. Note that the eventlog // won't be defined for running applications def getEventLogPath: String = { diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/ToolsPlanGraph.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/ToolsPlanGraph.scala index 6d6187d80..c9a1f1d3f 100644 --- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/ToolsPlanGraph.scala +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/ToolsPlanGraph.scala @@ -16,16 +16,375 @@ package org.apache.spark.sql.rapids.tool.util -import com.nvidia.spark.rapids.tool.planparser.DatabricksParseHelper import java.util.concurrent.atomic.AtomicLong + import scala.collection.mutable +import com.nvidia.spark.rapids.tool.planparser.DatabricksParseHelper + import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphEdge, SparkPlanGraphNode, SQLPlanMetric} +import org.apache.spark.sql.execution.ui._ +import org.apache.spark.sql.rapids.tool.AccumToStageRetriever import org.apache.spark.sql.rapids.tool.store.AccumNameRef import org.apache.spark.sql.rapids.tool.util.plangraph.PlanGraphTransformer import org.apache.spark.sql.rapids.tool.util.stubs.{GraphReflectionAPI, GraphReflectionAPIHelper} +/** + * A wrapper of the original SparkPlanGraph with additional information about the + * node-to-stage mapping. + * 1- The graph is constructed by visiting PlanInfos and creating GraphNodes and Edges. + * Although it is more efficient to assign stages during the construction of the nodes, + * the design is intentionally keeping those two phases separate to make the code more modular + * and easier to maintain. + * 2- Traverse the nodes and assign them to stages based on the metrics. + * 3- Nodes that belong to a graph cluster (childs of WholeStageCodeGen) while missing + * metrics, are assigned same as their WholeStageCodeGen node. + * 4- Iterate on all the orphanNodes and assign them to stages based on their adjacents nodes. + * 5- The iterative process is repeated until no assignment can be made. + * + * @param sparkGraph the original SparkPlanGraph to wrap + * @param accumToStageRetriever The object that eventually can retrieve StageIDs from AccumIds. + */ +class ToolsPlanGraph(val sparkGraph: SparkPlanGraph, + accumToStageRetriever: AccumToStageRetriever) { + // A map between SQLNode Id and the clusterIds that the node belongs to. + // Here, a clusterId means a stageId. + // Note: It is possible to represent the clusters as map [clusterId, Set[SQLNodeIds]]. + // While this is more memory efficient, it is more time-consuming to find the clusters a + // node belongs to since we have to iterate through all the keys. + private val nodeToStageCluster: mutable.Map[Long, Set[Int]] = mutable.HashMap[Long, Set[Int]]() + // shortcut to the nodes + def nodes: collection.Seq[SparkPlanGraphNode] = sparkGraph.nodes + // shortcut to the edges + def edges: collection.Seq[SparkPlanGraphEdge] = sparkGraph.edges + // delegate the call to the original graph + def allNodes: collection.Seq[SparkPlanGraphNode] = sparkGraph.allNodes + + /** + * Get stages that are associated with the accumulators of the node. + * Use this method if the purpose is to get raw information about the node-stage relationship + * based on the AccumIds without applying any logic. + * @param node the node to get the stages for + * @return a set of stageIds or empty if None + */ + private def getNodeStagesByAccum(node: SparkPlanGraphNode): Set[Int] = { + val nodeAccums = node.metrics.map(_.accumulatorId) + accumToStageRetriever.getStageIDsFromAccumIds(nodeAccums) + } + + /** + * Get the stages that the node belongs to. This function is used to get all the stages that can + * be assigned to a node. For example, if we want to get the "Exchange" node stages, then we call + * that method. + * @param node the node to get the stages for + * @return a set of stageIds or empty if None + */ + def getAllNodeStages(node: SparkPlanGraphNode): Set[Int] = { + val stageIdsByAccum = getNodeStagesByAccum(node) + nodeToStageCluster.get(node.id) match { + case Some(stageId) => stageIdsByAccum ++ stageId + case _ => stageIdsByAccum + } + } + + /** + * Check if a node exec is an epilogue. A.k.a, the exec has to be the tail of a stage. + * @param nodeName normalized node name (i.e., no GPU prefix) + * @return true if the node is an epilogue exec + */ + private def isEpilogueExec(nodeName: String): Boolean = { + nodeName match { + case "Exchange" | "BroadcastQueryStage" | "ShuffleQueryStage" | "TableCacheQueryStage" + | "ResultQueryStage" | "BroadcastExchange" => + true + case _ => false + } + } + + /** + * Check if a node exec is a prologue. A.k.a, the exec has to be the head of a stage. + * @param nodeName normalized node name (i.e., no GPU prefix) + * @return true if the node is a prologue exec + */ + private def isPrologueExec(nodeName: String): Boolean = { + nodeName match { + case nName if nName.contains("ShuffleRead") => + true + case _ => false + } + } + + /** + * Given a nodeName, this method returns a code that represents the node type. + * For example, an exchange node has to be at the end of a stage. ShuffleRead has to be at the + * beginning of a stage and so. + * + * @param nodeName the normalized name of the sparkNode (i.e., no GPU prefix). + * @return a code representing thenode type: + * (1) if the node can be assigned based on incoming edges. i.e., all nodes except the + * head of stage like shuffleRead. + * (2) if the node can be assigned based on outgoing edges. i.e., + * all nodes except the tail of stage like shuffleWrite/exchange. + * (3) if the node can be assigned based on both incoming and outgoing edges. + */ + private def multiplexCases(nodeName: String): Int = { + // nodes like shuffleRead should not be assigned to incoming edges + var result = 0 + if (!isPrologueExec(nodeName)) { + // Those are the nodes that can be assigned based on incoming edges. + result |= 1 + } + if (!isEpilogueExec(nodeName)) { + // Those are the nodes that can be assigned based on outgoing edges. + result |= 2 + } + result + } + + /** + * This method is used to assign a node to clusterID during the first walk of the graph. + * A cluster is used to wrap nodes together this could be a stageId. + * @param node the sparkNode to assign + * @return the clusterId that the node belongs to + */ + private def populateNodeClusters(node: SparkPlanGraphNode): Set[Int] = { + // First normalize the name. + val normalizedName = ToolsPlanGraph.processPlanInfo(node.name) + val stageIds = getNodeStagesByAccum(node) + normalizedName match { + case nName if isEpilogueExec(nName) => + // Cases that are tail of the stage cluster. + if (stageIds.size <= 1) { + stageIds + } else { + // Only use the smallest StageId because this would represent the stage that writes + // the data. + Set[Int](stageIds.min) + } + case nName if isPrologueExec(nName) => + // Cases that are head of a new stage. + if (stageIds.size <= 1) { + ToolsPlanGraph.EMPTY_CLUSTERS + } else { + // We should pick the stages associated with the reading metrics. This is likely to be + // the stage with the highest ID value. + Set[Int](stageIds.max) + } + case _ => + // Everything else goes here. + // It is possible to have multiple stages for a given node. + stageIds + } + } + + /** + * Updates the data structure that keeps track of the nodes cluster assignment. + * It adds the node to the map and remove the node from the orphans list if it exists. + * @param node the node to be assigned. + * @param orphanNodes the list of nodes that are not assigned to any cluster. + * @param clusters the clusterIds to assign the node to + */ + private def removeNodeFromOrphans(node: SparkPlanGraphNode, + orphanNodes: mutable.ArrayBuffer[SparkPlanGraphNode], + clusters: Set[Int]): Unit = { + nodeToStageCluster.put(node.id, clusters) + orphanNodes -= node + } + + /** + * Commits a wholeStageNode to a cluster. + * A WholeStageNode is visited after its children are. If any of the children is not assigned to + * a cluster, the wNode will transfer its assignment to the child. + * @param wNode the wholeStageCodeGen node to be visited + * @param orphanNodes the list of nodes that are not assigned to any cluster. + * @param clusters the clusterId to assign the node to. + * @return true if a change is made. + */ + private def commitNodeToStageCluster( + wNode: SparkPlanGraphCluster, + orphanNodes: mutable.ArrayBuffer[SparkPlanGraphNode], + clusters: Set[Int]): Boolean = { + if (nodeToStageCluster.contains(wNode.id) && clusters.subsetOf(nodeToStageCluster(wNode.id))) { + // Nothing to do since the node is assigned to the same cluster before. + false + } else { + val newClusterIds = + clusters ++ nodeToStageCluster.getOrElse(wNode.id, ToolsPlanGraph.EMPTY_CLUSTERS) + // Remove the wNode from orphanNodes if it exists + removeNodeFromOrphans(wNode, orphanNodes, newClusterIds) + // Assign the children to the same clusters if any of them is not assigned already. + wNode.nodes.foreach { childNode => + if (!nodeToStageCluster.contains(childNode.id)) { + // Assign the child node to the same stage of wNode and remove it from orphans + removeNodeFromOrphans(childNode, orphanNodes, newClusterIds) + } + } + true + } + } + + /** + * Assign a node to a clusterId. This method is used to assign a node to a clusterId during the + * first visit. + * @param node sparkNode to be assigned + * @param orphanNodes the list of nodes that are not assigned to any cluster + * @param clusters the clusterIds to assign the node to + * @return true if a change is made. + */ + private def commitNodeToStageCluster( + node: SparkPlanGraphNode, + orphanNodes: mutable.ArrayBuffer[SparkPlanGraphNode], + clusters: Set[Int]): Boolean = { + node match { + case cluster: SparkPlanGraphCluster => + // WholeCodeGen represents a special case because it propagates its assignment to + // children nodes. + commitNodeToStageCluster(cluster, orphanNodes, clusters) + case _ => + removeNodeFromOrphans(node, orphanNodes, clusters) + true + } + } + + /** + * Walk through the graph nodes and assign them to the correct stage cluster. + */ + protected def assignNodesToStageClusters(): Unit = { + // Keep track of nodes that have no assignment to any cluster. + val orphanNodes = mutable.ArrayBuffer[SparkPlanGraphNode]() + // Step(1): Visit all the nodes and assign them to the correct cluster based on AccumIDs. + // In the process, WholeStageCodeGens propagate their assignment to the child nodes if + // they are orphans. + allNodes.foreach { node => + if (!nodeToStageCluster.contains(node.id)) { + // Get clusterIDs based on AccumIds + val clusterIds = populateNodeClusters(node) + if (clusterIds.nonEmpty) { + // Found assignment + commitNodeToStageCluster(node, orphanNodes, clusterIds) + } else { + // This node has no assignment. Add it to the orphanNodes + orphanNodes += node + } + } + } + // Step(2): At this point, we made a quick visit handling all the straightforward cases. + // Iterate on the orphanNodes and try to assign them based on the adjacent nodes. + var changeFlag = orphanNodes.nonEmpty + while (changeFlag) { + // Iterate on the orphanNodes and try to assign them based on the adjacent nodes until no + // changes can be done in a single iteration. + changeFlag = false + // P.S: Copy the orphanNodes because we cannot remove objects inside the loop. + val orphanNodesCopy = orphanNodes.clone() + orphanNodesCopy.foreach { currNode => + if (orphanNodes.contains(currNode)) { // Avoid dup processing caused by wholeStageCodeGen + val currNodeName = ToolsPlanGraph.processPlanInfo(currNode.name) + val updatedFlag = currNode match { + case wNode: SparkPlanGraphCluster => + // WholeStageCodeGen is a corner case because it is not connected by edges. + // The only way to set the clusterID is to get it from the children if any. + wNode.nodes.find { childNode => nodeToStageCluster.contains(childNode.id) } match { + case Some(childNode) => + val clusterIDs = nodeToStageCluster(childNode.id) + commitNodeToStageCluster(wNode, orphanNodes, clusterIDs) + case _ => // do nothing if we could not find a child node with a clusterId + false + } + case _ => + // Handle all other nodes. + // Set the node type to determine the restrictions (i.e., exchange is + // positioned at the tail of a stage and shuffleRead should be the head of a stage). + val nodeCase = multiplexCases(currNodeName) + var clusterIDs = ToolsPlanGraph.EMPTY_CLUSTERS + if ((nodeCase & 1) > 0) { + // Assign cluster based on incoming edges. + val inEdgesWithIds = + edges.filter(e => e.toId == currNode.id && nodeToStageCluster.contains(e.fromId)) + if (inEdgesWithIds.nonEmpty) { + // For simplicity, assign the node based on the first incoming adjacent node. + clusterIDs = nodeToStageCluster(inEdgesWithIds.head.fromId) + } + } + if (clusterIDs.isEmpty && (nodeCase & 2) > 0) { + // Assign cluster based on outgoing edges (i.e., ShuffleRead). + // Corner case: TPC-DS Like Bench q2 (sqlID 24). + // A shuffleReader is reading on driver followed by an exchange without + // metrics. + // The metrics will not have a valid accumID. + // In that case, it is not feasible to match it to a cluster without + // considering the incoming node (exchange in that case). This corner + // case is handled later as a last-ditch effort. + val outEdgesWithIds = + edges.filter(e => e.fromId == currNode.id && nodeToStageCluster.contains(e.toId)) + if (outEdgesWithIds.nonEmpty) { + // For simplicity, assign the node based on the first outgoing adjacent node. + clusterIDs = nodeToStageCluster(outEdgesWithIds.head.toId) + } + } + if (clusterIDs.nonEmpty) { + // There is a possible assignment. Commit it. + commitNodeToStageCluster(currNode, orphanNodes, clusterIDs) + } else { + // nothing has changed + false + } + } // End of setting the UpdatedFlag variable. + changeFlag |= updatedFlag + } // End of if orphanNodes.contains(currNode). + } // End of iteration on orphanNodes. + // Corner case for shuffleRead when it is reading from the driver followed by an exchange that + // has no metrics. + if (!changeFlag && orphanNodes.nonEmpty) { + // This is to handle the special case of a shuffleRead that is reading from the driver. + // We could not assign any node to a cluster. This means that we have a cycle in the graph, + // and we need to break it. + // This is done by breaking the rule, allowing the shuffleRead to pick the highest stage + // order of the ancestor node. + changeFlag |= orphanNodes.filter( + n => isPrologueExec(ToolsPlanGraph.processPlanInfo(n.name))).exists { // Picks shuffleRead + orphanNode => + // Get adjacent nodes to the shuffleRead that have cluster assignment. + val inEdgesWithIds = + edges.filter(e => e.toId == orphanNode.id && nodeToStageCluster.contains(e.fromId)) + if (inEdgesWithIds.nonEmpty) { + // At this point, we need to get all the possible stageIDs that can be assigned to the + // adjacent nodes because and not only the logical ones. + val possibleIds = inEdgesWithIds.map { e => + val adjacentNode = allNodes.find(eN => eN.id == e.fromId).get + getAllNodeStages(adjacentNode) + }.reduce(_ ++ _) + // Assign the maximum value clusterId to the node. + val newIDs = Set[Int](possibleIds.max) + commitNodeToStageCluster(orphanNode, orphanNodes, newIDs) + } else { + false + } + } + } // end of corner case handling + } // end of changeFlag loop + } // end of assignNodesToStageClusters + // Start the construction of the graph + assignNodesToStageClusters() + + // Define public interface methods + /** + * Get the stage clusters that the node belongs to. + * Use this method if this logical representation of the node-to-stage relationship. + * For example, an "Exchange" node returns only a single stageID which is the stage that writes + * the data. + * @param node the node to get the stages for + * @return a set of stageIds or empty if None + */ + def getNodeStageClusters(node: SparkPlanGraphNode): Set[Int] = { + nodeToStageCluster.getOrElse(node.id, ToolsPlanGraph.EMPTY_CLUSTERS) + } + + def getNodeStageClusters(nodeId: Long): Set[Int] = { + nodeToStageCluster.getOrElse(nodeId, ToolsPlanGraph.EMPTY_CLUSTERS) + } +} // end of class ToolsPlanGraph + /** * This code is mostly copied from org.apache.spark.sql.execution.ui.SparkPlanGraph * with changes to handle GPU nodes. Without this special handle, the default SparkPlanGraph @@ -34,6 +393,8 @@ import org.apache.spark.sql.rapids.tool.util.stubs.{GraphReflectionAPI, GraphRef * Build a SparkPlanGraph from the root of a SparkPlan tree. */ object ToolsPlanGraph { + // Empty cluster set used to represent a node that is not assigned to any cluster. + private val EMPTY_CLUSTERS: Set[Int] = Set.empty // Captures the API loaded at runtime if any. var api: GraphReflectionAPI = _ @@ -89,6 +450,12 @@ object ToolsPlanGraph { } } + def createGraphWithStageClusters(planInfo: SparkPlanInfo, + accumStageMapper: AccumToStageRetriever): ToolsPlanGraph = { + val sGraph = ToolsPlanGraph(planInfo) + new ToolsPlanGraph(sGraph, accumStageMapper) + } + private def processPlanInfo(nodeName: String): String = { if (nodeName.startsWith("Gpu")) { nodeName.replaceFirst("Gpu", "") diff --git a/core/src/test/resources/QualificationExpectations/complex_dec_expectation.csv b/core/src/test/resources/QualificationExpectations/complex_dec_expectation.csv index f76c554be..8a9e1f666 100644 --- a/core/src/test/resources/QualificationExpectations/complex_dec_expectation.csv +++ b/core/src/test/resources/QualificationExpectations/complex_dec_expectation.csv @@ -1,2 +1,2 @@ App Name,App ID,SQL DF Duration,SQL Dataframe Task Duration,App Duration,GPU Opportunity,Executor CPU Time Percent,SQL Ids with Failures,Unsupported Read File Formats and Types,Unsupported Write Data Format,Complex Types,Nested Complex Types,Potential Problems,Longest SQL Duration,SQL Stage Durations Sum,NONSQL Task Duration Plus Overhead,Unsupported Task Duration,Supported SQL DF Task Duration,App Duration Estimated,Unsupported Execs,Unsupported Expressions,Estimated Job Frequency (monthly), Total Core Seconds -"Spark shell","local-1626104300434",1500,1469,131104,1315,88.35,"","","","struct,lastname:string>;struct,previous:struct,city:string>>;array>;map;map>;map>;array>;array","struct,lastname:string>;struct,previous:struct,city:string>>;array>;map>;map>;array>","NESTED COMPLEX TYPE",1260,1388,129598,181,1288,false,"CollectLimit","",30,1564 +"Spark shell","local-1626104300434",1500,1469,131104,996,88.35,"","","","struct,lastname:string>;struct,previous:struct,city:string>>;array>;map;map>;map>;array>;array","struct,lastname:string>;struct,previous:struct,city:string>>;array>;map>;map>;array>","NESTED COMPLEX TYPE",1260,1388,129598,493,976,false,"CollectLimit","",30,1564 diff --git a/core/src/test/resources/QualificationExpectations/read_dsv1_expectation.csv b/core/src/test/resources/QualificationExpectations/read_dsv1_expectation.csv index 4cb0d9940..b725defc0 100644 --- a/core/src/test/resources/QualificationExpectations/read_dsv1_expectation.csv +++ b/core/src/test/resources/QualificationExpectations/read_dsv1_expectation.csv @@ -1,2 +1,2 @@ App Name,App ID,SQL DF Duration,SQL Dataframe Task Duration,App Duration,GPU Opportunity,Executor CPU Time Percent,SQL Ids with Failures,Unsupported Read File Formats and Types,Unsupported Write Data Format,Complex Types,Nested Complex Types,Potential Problems,Longest SQL Duration,SQL Stage Durations Sum,NONSQL Task Duration Plus Overhead,Unsupported Task Duration,Supported SQL DF Task Duration,App Duration Estimated,Unsupported Execs,Unsupported Expressions,Estimated Job Frequency (monthly),Total Core Seconds -"Spark shell","local-1624371544219",4575,20421,175293,1557,72.15,"","JSON[string:double:date:int:bigint];Text[*]","JSON","","","",1859,5372,176916,13469,6952,false,"CollectLimit;Scan text;Execute InsertIntoHadoopFsRelationCommand json;Scan json","",30,2096 +"Spark shell","local-1624371544219",4575,20421,175293,1523,72.15,"","JSON[string:double:date:int:bigint];Text[*]","JSON","","","",1859,5372,176916,13622,6799,false,"CollectLimit;Scan text;Execute InsertIntoHadoopFsRelationCommand json;Scan json","",30,2096 diff --git a/core/src/test/resources/QualificationExpectations/read_dsv2_expectation.csv b/core/src/test/resources/QualificationExpectations/read_dsv2_expectation.csv index 404c02755..4c7726207 100644 --- a/core/src/test/resources/QualificationExpectations/read_dsv2_expectation.csv +++ b/core/src/test/resources/QualificationExpectations/read_dsv2_expectation.csv @@ -1,2 +1,2 @@ App Name,App ID,SQL DF Duration,SQL Dataframe Task Duration,App Duration,GPU Opportunity,Executor CPU Time Percent,SQL Ids with Failures,Unsupported Read File Formats and Types,Unsupported Write Data Format,Complex Types,Nested Complex Types,Potential Problems,Longest SQL Duration,SQL Stage Durations Sum,NONSQL Task Duration Plus Overhead,Unsupported Task Duration,Supported SQL DF Task Duration,App Duration Estimated,Unsupported Execs,Unsupported Expressions,Estimated Job Frequency (monthly),Total Core Seconds -"Spark shell","local-1624371906627",4917,21802,83738,1745,71.3,"","Text[*];json[double]","JSON","","","",1984,5438,83336,14064,7738,false,"CollectLimit;Scan text;Execute InsertIntoHadoopFsRelationCommand json;BatchScan json","",30,997 +"Spark shell","local-1624371906627",4917,21802,83738,2687,71.3,"","Text[*];json[double]","JSON","","","",1984,5438,83336,9889,11913,false,"CollectLimit;Scan text;Execute InsertIntoHadoopFsRelationCommand json;BatchScan json","",30,997 diff --git a/core/src/test/resources/QualificationExpectations/truncated_1_end_expectation.csv b/core/src/test/resources/QualificationExpectations/truncated_1_end_expectation.csv index 6ef5acd56..f7737d508 100644 --- a/core/src/test/resources/QualificationExpectations/truncated_1_end_expectation.csv +++ b/core/src/test/resources/QualificationExpectations/truncated_1_end_expectation.csv @@ -1,2 +1,2 @@ App Name,App ID,SQL DF Duration,SQL Dataframe Task Duration,App Duration,GPU Opportunity,Executor CPU Time Percent,SQL Ids with Failures,Unsupported Read File Formats and Types,Unsupported Write Data Format,Complex Types,Nested Complex Types,Potential Problems,Longest SQL Duration,SQL Stage Durations Sum,NONSQL Task Duration Plus Overhead,Unsupported Task Duration,Supported SQL DF Task Duration,App Duration Estimated,Unsupported Execs,Unsupported Expressions,Estimated Job Frequency (monthly),Total Core Seconds -"Rapids Spark Profiling Tool Unit Tests","local-1622043423018",395,14353,4872,142,62.67,"","","JSON","","","",1306,794,4477,9164,5189,true,"SerializeFromObject;Execute InsertIntoHadoopFsRelationCommand json;DeserializeToObject;Filter;MapElements;Scan","",30,49 +"Rapids Spark Profiling Tool Unit Tests","local-1622043423018",395,14353,4872,164,62.67,"","","JSON","","","",1306,794,4477,8376,5977,true,"SerializeFromObject;Execute InsertIntoHadoopFsRelationCommand json;DeserializeToObject;Filter;MapElements;Scan","",30,49 diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/BasePlanParserSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/BasePlanParserSuite.scala index 54cad1b1c..b7966d4d2 100644 --- a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/BasePlanParserSuite.scala +++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/BasePlanParserSuite.scala @@ -80,4 +80,27 @@ class BasePlanParserSuite extends BaseTestSuite { e.children.getOrElse(Seq.empty) :+ e } } + def verifyPlanExecToStageMap(toolsPlanInfo: PlanInfo): Unit = { + val allExecInfos = toolsPlanInfo.execInfo.flatMap { e => + e.children.getOrElse(Seq.empty) :+ e + } + // Test that all execs are assigned to stages + assert (allExecInfos.forall(_.stages.nonEmpty)) + // assert that exchange is assigned to a single stage + val exchangeExecs = allExecInfos.filter(_.exec == "Exchange") + if (exchangeExecs.nonEmpty) { + assert (exchangeExecs.forall(_.stages.size == 1)) + } + } + + def verifyExecToStageMapping(plans: Seq[PlanInfo], + qualApp: QualificationAppInfo, funcCB: Option[PlanInfo => Unit] = None): Unit = { + // Only iterate on plans with that are associated to jobs + val associatedSqls = qualApp.jobIdToSqlID.values.toSeq + val filteredPlans = plans.filter(p => associatedSqls.contains(p.sqlID)) + val func = funcCB.getOrElse(verifyPlanExecToStageMap(_)) + filteredPlans.foreach { plan => + func(plan) + } + } } diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala index a3d6211bf..6fc9357cf 100644 --- a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala +++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala @@ -109,7 +109,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { app.sqlPlans.foreach { case (sqlID, plan) => val planInfo = SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) - + verifyPlanExecToStageMap(planInfo) val wholeStages = planInfo.execInfo.filter(_.exec.contains("WholeStageCodegen")) val allChildren = wholeStages.flatMap(_.children).flatten val sorts = allChildren.filter(_.exec == "Sort") @@ -205,6 +205,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { app.sqlPlans.foreach { case (sqlID, plan) => val planInfo = SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) + verifyPlanExecToStageMap(planInfo) val allExecInfo = planInfo.execInfo // Note that: // Spark320+ will generate the following execs: @@ -247,6 +248,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { app.sqlPlans.foreach { case (sqlID, plan) => val planInfo = SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) + verifyPlanExecToStageMap(planInfo) val wholeStages = planInfo.execInfo.filter(_.exec.contains("WholeStageCodegen")) assert(wholeStages.size == 2) val numSupported = wholeStages.filter(_.isSupported).size @@ -267,6 +269,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val json = allExecInfo.filter(_.exec.contains("Scan json")) val orc = allExecInfo.filter(_.exec.contains("Scan orc")) @@ -289,6 +292,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) // Note that the text scan from this file is v1 so ignore it val json = allExecInfo.filter(_.exec.contains("BatchScan json")) @@ -319,6 +323,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val subqueryExecs = allExecInfo.filter(_.exec.contains(s"Subquery")) val summaryRecs = subqueryExecs.flatten { sqExec => @@ -354,6 +359,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val writeExecs = allExecInfo.filter(_.exec.contains(s"$dataWriteCMD")) val text = writeExecs.filter(_.expr.contains("text")) @@ -393,6 +399,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val writeExecs = allExecInfo.filter(_.exec.contains(s"$dataWriteCMD")) val text = writeExecs.filter(_.expr.contains("text")) @@ -425,6 +432,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val parquet = { allExecInfo.filter(_.exec.contains("CreateDataSourceTableAsSelectCommand")) @@ -454,6 +462,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val tableScan = allExecInfo.filter(_.exec == ("InMemoryTableScan")) assertSizeAndSupported(1, tableScan.toSeq) @@ -468,6 +477,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val broadcasts = allExecInfo.filter(_.exec == "BroadcastExchange") assertSizeAndSupported(3, broadcasts.toSeq, @@ -488,6 +498,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) // reusedExchange is added as a supportedExec val reusedExchangeExecs = allExecInfo.filter(_.exec == "ReusedExchange") @@ -509,6 +520,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val reader = allExecInfo.filter(_.exec == "CustomShuffleReader") assertSizeAndSupported(2, reader.toSeq) @@ -523,6 +535,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val reader = allExecInfo.filter(_.exec == "AQEShuffleRead") assertSizeAndSupported(2, reader.toSeq) @@ -552,15 +565,17 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) for (execName <- supportedExecs) { val execs = allExecInfo.filter(_.exec == execName) - assertSizeAndSupported(1, execs.toSeq, expectedDur = Seq.empty, extraText = execName) + assertSizeAndSupported(1, execs, expectedDur = Seq.empty, extraText = execName) } for (execName <- unsupportedExecs) { val execs = allExecInfo.filter(_.exec == execName) - assertSizeAndNotSupported(1, execs.toSeq) + assertSizeAndNotSupported(1, execs) } + verifyExecToStageMapping(parsedPlans.toSeq, app) } } @@ -581,6 +596,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) for (execName <- supportedExecs) { val supportedExec = allExecInfo.filter(_.exec == execName) @@ -617,6 +633,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val bhj = allExecInfo.filter(_.exec == "BroadcastHashJoin") assertSizeAndSupported(1, bhj) @@ -659,6 +676,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val bhj = allExecInfo.filter(_.exec == "BroadcastHashJoin") assertSizeAndNotSupported(1, bhj) @@ -692,6 +710,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val smj = allExecInfo.filter(_.exec == "SortMergeJoin") assertSizeAndNotSupported(1, smj) @@ -714,6 +733,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val execInfo = getAllExecsFromPlan(parsedPlans.toSeq) val sortAggregate = execInfo.filter(_.exec == "SortAggregate") assertSizeAndSupported(2, sortAggregate) @@ -734,6 +754,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val execInfo = getAllExecsFromPlan(parsedPlans.toSeq) val objectHashAggregate = execInfo.filter(_.exec == "ObjectHashAggregate") // OHA will get sql time metrics "time in aggregation build" @@ -760,6 +781,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val execInfo = getAllExecsFromPlan(parsedPlans.toSeq) val windowExecs = execInfo.filter(_.exec == "Window") assertSizeAndSupported(1, windowExecs) @@ -775,6 +797,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val flatMapGroups = allExecInfo.filter(_.exec == "FlatMapGroupsInPandas") assertSizeAndSupported(1, flatMapGroups) @@ -800,6 +823,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val supportedExecs = Array("GlobalLimit", "LocalLimit") app.sqlPlans.foreach { case (sqlID, plan) => val planInfo = SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) + verifyPlanExecToStageMap(planInfo) // GlobalLimit and LocalLimit are inside WholeStageCodegen. So getting the children of // WholeStageCodegenExec val wholeStages = planInfo.execInfo.filter(_.exec.contains("WholeStageCodegen")) @@ -825,6 +849,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { app.sqlPlans.foreach { case (sqlID, plan) => val planInfo = SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) + verifyPlanExecToStageMap(planInfo) val wholeStages = planInfo.execInfo.filter(_.exec.contains("WholeStageCodegen")) assert(wholeStages.size == 1) val allChildren = wholeStages.flatMap(_.children).flatten @@ -849,6 +874,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { app.sqlPlans.foreach { case (sqlID, plan) => val planInfo = SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) + verifyPlanExecToStageMap(planInfo) val wholeStages = planInfo.execInfo.filter(_.exec.contains("WholeStageCodegen")) assert(wholeStages.size == 2) val allChildren = wholeStages.flatMap(_.children).flatten @@ -873,6 +899,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val execInfo = getAllExecsFromPlan(parsedPlans.toSeq) val takeOrderedAndProject = execInfo.filter(_.exec == "TakeOrderedAndProject") assertSizeAndNotSupported(1, takeOrderedAndProject) @@ -897,6 +924,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val execInfo = getAllExecsFromPlan(parsedPlans.toSeq) val generateExprs = execInfo.filter(_.exec == "Generate") assertSizeAndSupported(1, generateExprs) @@ -924,6 +952,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val execInfo = getAllExecsFromPlan(parsedPlans.toSeq) val projectExprs = execInfo.filter(_.exec == "Project") assertSizeAndSupported(1, projectExprs) @@ -948,6 +977,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val execInfo = getAllExecsFromPlan(parsedPlans.toSeq) val sortAggregate = execInfo.filter(_.exec == "SortAggregate") assertSizeAndSupported(2, sortAggregate) @@ -974,6 +1004,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val sortExec = allExecInfo.filter(_.exec.contains("Sort")) assert(sortExec.size == 3) @@ -1001,6 +1032,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val wholeStages = allExecInfo.filter(_.exec.contains("WholeStageCodegen")) assert(wholeStages.size == 1) @@ -1029,6 +1061,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "test desc", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) parsedPlans.foreach { pInfo => assert(pInfo.sqlDesc == "test desc") } @@ -1061,6 +1094,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val wholeStages = allExecInfo.filter(_.exec.contains("WholeStageCodegen")) assert(wholeStages.size == 1) @@ -1091,6 +1125,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val wholeStages = allExecInfo.filter(_.exec.contains("WholeStageCodegen")) assert(wholeStages.size == 1) @@ -1121,6 +1156,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val wholeStages = allExecInfo.filter(_.exec.contains("WholeStageCodegen")) assert(wholeStages.size == 1) @@ -1183,6 +1219,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val projects = allExecInfo.filter(_.exec.contains("Project")) assertSizeAndNotSupported(1, projects) @@ -1211,6 +1248,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val wholeStages = allExecInfo.filter(_.exec.contains("WholeStageCodegen")) assert(wholeStages.size == 1) @@ -1239,6 +1277,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val execInfo = getAllExecsFromPlan(parsedPlans.toSeq) val hashAggregate = execInfo.filter(_.exec == "HashAggregate") assertSizeAndSupported(2, hashAggregate, checkDurations = false) @@ -1265,6 +1304,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val wholeStages = allExecInfo.filter(_.exec.contains("WholeStageCodegen")) assert(wholeStages.size == 1) @@ -1344,6 +1384,11 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + // TODO: Spark shows a weird behavior as the entire plan with SQL 65 loses the metrics + // and the associated jobs. This causes some execs to be isolated without stage. + // The UI won't be able to visualize that job anymore. So, we need to investigate + // what happens in that query before we test the execs-to-stage mapping. + val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val deltaLakeWrites = allExecInfo.filter(_.exec.contains(s"$dataWriteCMD")) assertSizeAndSupported(1, deltaLakeWrites) @@ -1406,6 +1451,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) // The promote_precision should be part of the project exec. val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val projExecs = allExecInfo.filter(_.exec.contains("Project")) @@ -1449,6 +1495,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) // The current_database should be part of the project-exec and the parser should ignore it. val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val projExecs = allExecInfo.filter(_.exec.contains("Project")) @@ -1697,6 +1744,18 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + // Note that the generated plan, there are skipped stages that causes some execs to appear + // without their relevant stages. so we skip the stage verification here. + verifyExecToStageMapping(parsedPlans.toSeq, app, Some( planInfo => + if (planInfo.sqlID == 73) { + // Nodes should not have any stages + val allExecInfos = planInfo.execInfo.flatMap { e => + e.children.getOrElse(Seq.empty) :+ e + } + // exclude all stages higher than 8 because those ones belong to a skipped stage + allExecInfos.filter(_.nodeId <= 8).forall(_.stages.nonEmpty) + }) + ) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val windowGroupLimitExecs = allExecInfo.filter(_.exec.contains(windowGroupLimitExecCmd)) // We should have two WindowGroupLimitExec operators (Partial and Final). @@ -1728,6 +1787,18 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + // Note that the generated plan, there are skipped stages that causes some execs to appear + // without their relevant stages. so we skip the stage verification here. + verifyExecToStageMapping(parsedPlans.toSeq, app, Some( planInfo => + if (planInfo.sqlID == 76) { + // Nodes should not have any stages + val allExecInfos = planInfo.execInfo.flatMap { e => + e.children.getOrElse(Seq.empty) :+ e + } + // exclude all stages higher than 8 because those ones belong to a skipped stage + allExecInfos.filter(_.nodeId <= 8).forall(_.stages.nonEmpty) + }) + ) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val windowExecNotSupportedExprs = allExecInfo.filter( _.exec.contains(windowGroupLimitExecCmd)).flatMap(x => x.unsupportedExprs) @@ -1793,6 +1864,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) // we should have 2 hash aggregates with min_by and max_by expressions // if the min_by and max_by were not recognized, the test would fail val hashAggExecs = @@ -1822,6 +1894,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) } + verifyExecToStageMapping(parsedPlans.toSeq, app) val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) val projectExecs = allExecInfo.filter(_.exec == "Project") assertSizeAndSupported(1, projectExecs)