diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala index 45e614376..560571a26 100644 --- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala @@ -198,34 +198,16 @@ class QualificationAppInfo( } } - private def singleExecMatch(execInfo: ExecInfo): (Seq[(Int, ExecInfo)], Option[ExecInfo]) = { + private def checkStageIdInExec(prev: Option[ExecInfo], + execInfo: ExecInfo, next: Option[ExecInfo]): (Seq[(Int, ExecInfo)], Option[ExecInfo]) = { val associatedStages = { if (execInfo.stages.size > 1) { execInfo.stages.toSeq } else if (execInfo.stages.size < 1) { - // we don't know what stage its in or its duration - logDebug(s"No stage associated with ${execInfo.exec} " + - s"so speedup factor isn't applied anywhere.") - Seq.empty - } else { - Seq(execInfo.stages.head) - } - } - if (associatedStages.nonEmpty) { - (associatedStages.map((_, execInfo)), None) - } else { - (Seq.empty, Some(execInfo)) - } - } - - private def doubleExecMatch(neighbor: ExecInfo, execInfo: ExecInfo): ( - Seq[(Int, ExecInfo)], Option[ExecInfo]) = { - val associatedStages = { - if (execInfo.stages.size > 1) { - execInfo.stages.toSeq - } else if (execInfo.stages.size < 1) { - if (neighbor.stages.size >= 1) { - neighbor.stages.headOption.toSeq + if (prev.exists(_.stages.size >= 1)) { + prev.flatMap(_.stages.headOption).toSeq + } else if (next.exists(_.stages.size >= 1)) { + next.flatMap(_.stages.headOption).toSeq } else { // we don't know what stage its in or its duration logDebug(s"No stage associated with ${execInfo.exec} " + @@ -243,32 +225,6 @@ class QualificationAppInfo( } } - private def tripleExecMatch(prev: ExecInfo, execInfo: ExecInfo, next: ExecInfo): - (Seq[(Int, ExecInfo)], Option[ExecInfo]) = { - val associatedStages = { - if (execInfo.stages.size > 1) { - execInfo.stages.toSeq - } else if (execInfo.stages.size < 1) { - if (prev.stages.size >= 1) { - prev.stages.headOption.toSeq - } else if (next.stages.size >= 1) { - next.stages.headOption.toSeq - } else { - // we don't know what stage its in or its duration - logDebug(s"No stage associated with ${execInfo.exec} " + - s"so speedup factor isn't applied anywhere.") - Seq.empty - } - } else { - Seq(execInfo.stages.head) - } - } - if (associatedStages.nonEmpty) { - (associatedStages.map((_, execInfo)), None) - } else { - (Seq.empty, Some(execInfo)) - } - } private def getStageToExec(execInfos: Seq[ExecInfo]): (Map[Int, Seq[ExecInfo]], Seq[ExecInfo]) = { val execsWithoutStages = new ArrayBuffer[ExecInfo]() @@ -285,29 +241,25 @@ class QualificationAppInfo( // corner case to handle first element case 0 => if (execInfosInOrder.size > 1) { // If there are more than one Execs, then check if the next Exec has a stageId. - doubleExecMatch(execInfosInOrder(1), execInfosInOrder(0)) + checkStageIdInExec(None, execInfosInOrder(0), Some(execInfosInOrder(1))) } else { - singleExecMatch(execInfosInOrder(0)) + checkStageIdInExec(None, execInfosInOrder(0), None) } // corner case to handle last element - case i if i == execInfosInOrder.size - 1 => - if (execInfosInOrder.size > 1) { - // If there are more than one Execs, then check if the previous Exec has a stageId. - doubleExecMatch(execInfosInOrder(i - 1), execInfosInOrder(i)) - } else { - singleExecMatch(execInfosInOrder(i)) - } + case i if i == execInfosInOrder.size - 1 && execInfosInOrder.size > 1 => + // If there are more than one Execs, then check if the previous Exec has a stageId. + checkStageIdInExec(Some(execInfosInOrder(i - 1)), execInfosInOrder(i), None) case i => - tripleExecMatch(execInfosInOrder(i - 1), execInfosInOrder(i), execInfosInOrder(i + 1)) - case _ => - (Seq.empty, None) + checkStageIdInExec(Some(execInfosInOrder(i - 1)), + execInfosInOrder(i), Some(execInfosInOrder(i + 1))) } val perStageSum = execsToStageMap.map(_._1).toList.flatten .groupBy(_._1).map { case (stage, execInfo) => - (stage, execInfo.map(_._2))} + (stage, execInfo.map(_._2)) + } // Add all the execs that don't have a stageId to execsWithoutStages. - execsWithoutStages ++= execsToStageMap.map(_._2).toList.flatten + execsWithoutStages ++= execsToStageMap.map(_._2).toList.flatten (perStageSum, execsWithoutStages) }