diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala index 17159842b..899ce1e2e 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, WeakHashMap} import scala.util.control.NonFatal import scala.util.matching.Regex -import com.nvidia.spark.rapids.tool.planparser.ops.{OperatorRefBase, OpRef, UnsupportedExprOpRef} +import com.nvidia.spark.rapids.tool.planparser.ops.{ExprOpRef, OperatorRefTrait, OpRef, UnsupportedExprOpRef} import com.nvidia.spark.rapids.tool.planparser.photon.{PhotonPlanParser, PhotonStageExecParser} import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker @@ -75,7 +75,7 @@ object UnsupportedReasons extends Enumeration { case class UnsupportedExecSummary( sqlId: Long, execId: Long, - execRef: OperatorRefBase, + execRef: OperatorRefTrait, opType: OpTypes.OpType, reason: UnsupportedReasons.UnsupportedReason, opAction: OpActions.OpAction) { @@ -89,8 +89,6 @@ case class UnsupportedExecSummary( val unsupportedOperatorCSVFormat: String = execRef.getOpNameCSV val details: String = UnsupportedReasons.reportUnsupportedReason(reason) - - def isExpression: Boolean = execRef.isInstanceOf[UnsupportedExprOpRef] } case class ExecInfo( @@ -110,7 +108,7 @@ case class ExecInfo( dataSet: Boolean, udf: Boolean, shouldIgnore: Boolean, - expressions: Seq[OpRef]) { + expressions: Seq[ExprOpRef]) { private def childrenToString = { val str = children.map { c => @@ -141,10 +139,6 @@ case class ExecInfo( stages = stageIDs } - def appendToStages(stageIDs: Set[Int]): Unit = { - stages ++= stageIDs - } - def setShouldRemove(value: Boolean): Unit = { shouldRemove ||= value } @@ -286,7 +280,7 @@ object ExecInfo { udf, shouldIgnore, // convert array of string expressions to OpRefs - expressions = expressions.map(OpRef.fromExpr) + expressions = ExprOpRef.fromRawExprSeq(expressions) ) } @@ -351,7 +345,7 @@ case class PlanInfo( sqlID: Long, sqlDesc: String, execInfo: Seq[ExecInfo]) { - def getUnsupportedExpressions: Seq[OperatorRefBase] = { + def getUnsupportedExpressions: Seq[OperatorRefTrait] = { execInfo.flatMap { e => if (e.isClusterNode) { // wholeStageCodeGen does not have expressions/unsupported-expressions @@ -727,21 +721,21 @@ object SQLPlanParser extends Logging { } private def getAllFunctionNames(regPattern: Regex, expr: String, - groupInd: Int = 1, isAggr: Boolean = true): Set[String] = { + groupInd: Int = 1, isAggr: Boolean = true): Array[String] = { // Returns all matches in an expression. This can be used when the SQL expression is not // tokenized. val newExpr = processSpecialFunctions(expr) // first get all the functionNames val exprss = - regPattern.findAllMatchIn(newExpr).map(_.group(groupInd)).toSet + regPattern.findAllMatchIn(newExpr).map(_.group(groupInd)).toSeq // For aggregate expressions we want to process the results to remove the prefix // DB: remove the "^partial_" and "^finalmerge_" prefixes // TODO: // for performance sake, we can turn off the aggregate processing by enabling it only // when needed. However, for now, we always do this processing until we are confident we know - // the correct place to turn on/off that flag.we can use the argument isAgg only when needed + // the correct place to turn on/off that flag. We can use the argument isAgg only when needed val results = if (isAggr) { exprss.collect { case func => @@ -750,7 +744,7 @@ object SQLPlanParser extends Logging { } else { exprss } - results.filterNot(ignoreExpression(_)) + results.filterNot(ignoreExpression(_)).toArray } def parseProjectExpressions(exprStr: String): Array[String] = { @@ -758,7 +752,7 @@ object SQLPlanParser extends Logging { // This is to split the string such that only function names are extracted. The pattern is // such that function name is succeeded by `(`. We use regex to extract all the function names // below: - getAllFunctionNames(functionPrefixPattern, exprStr).toArray + getAllFunctionNames(functionPrefixPattern, exprStr) } // This parser is used for SortAggregateExec, HashAggregateExec and ObjectHashAggregateExec @@ -792,7 +786,7 @@ object SQLPlanParser extends Logging { } } } - parsedExpressions.distinct.toArray + parsedExpressions.toArray } def parseWindowExpressions(exprStr:String): Array[String] = { @@ -828,7 +822,7 @@ object SQLPlanParser extends Logging { } } } - parsedExpressions.distinct.toArray + parsedExpressions.toArray } def parseWindowGroupLimitExpressions(exprStr: String): Array[String] = { @@ -858,10 +852,10 @@ object SQLPlanParser extends Logging { // - Some values can be NULLs. That's why we cannot limit the extract to the first row. // - Nested brackets/parenthesis makes it challenging to use regex that contains // brackets/parenthesis to extract expressions. - // The implementation Use regex to extract all function names and return distinct set of + // The implementation Use regex to extract all function names and return da list of // function names. // This implementation is 1 line implementation, but it can be a memory/time bottleneck. - getAllFunctionNames(functionPrefixPattern, exprStr).toArray + getAllFunctionNames(functionPrefixPattern, exprStr) } def parseTakeOrderedExpressions(exprStr: String): Array[String] = { @@ -889,7 +883,7 @@ object SQLPlanParser extends Logging { } } } - parsedExpressions.distinct.toArray + parsedExpressions.toArray } def parseGenerateExpressions(exprStr: String): Array[String] = { @@ -897,11 +891,11 @@ object SQLPlanParser extends Logging { // 1. Generate explode(arrays#1306), [id#1304], true, [col#1426] // 2. Generate json_tuple(values#1305, Zipcode, ZipCodeType, City), [id#1304], // false, [c0#1407, c1#1408, c2#1409] - getAllFunctionNames(functionPrefixPattern, exprStr).toArray + getAllFunctionNames(functionPrefixPattern, exprStr) } private def addFunctionNames(exprs: String, parsedExpressions: ArrayBuffer[String]): Unit = { - val functionNames = getAllFunctionNames(functionPrefixPattern, exprs).toArray + val functionNames = getAllFunctionNames(functionPrefixPattern, exprs) functionNames.foreach(parsedExpressions += _) } @@ -946,7 +940,7 @@ object SQLPlanParser extends Logging { val isSortMergeSupported = !(isSortMergeJoin && joinCondition.nonEmpty && isSMJConditionUnsupported(joinCondition)) - (parsedExpressions.distinct.toArray, equiJoinSupportedTypes(buildSide, joinType) + (parsedExpressions.toArray, equiJoinSupportedTypes(buildSide, joinType) && isSortMergeSupported) } @@ -1054,7 +1048,7 @@ object SQLPlanParser extends Logging { } } } - parsedExpressions.distinct.toArray + parsedExpressions.toArray } def parseFilterExpressions(exprStr: String): Array[String] = { @@ -1109,7 +1103,7 @@ object SQLPlanParser extends Logging { processedExpr = nonBinaryOperatorsRegEx.replaceAllIn(processedExpr, " ") // Step-4: remove remaining parentheses '(', ')' and commas if we had functionCalls - if (!functionMatches.isEmpty) { + if (functionMatches.nonEmpty) { // remove "," processedExpr = processedExpr.replaceAll(",", " ") } @@ -1148,6 +1142,6 @@ object SQLPlanParser extends Logging { } } - parsedExpressions.distinct.toArray + parsedExpressions.toArray } } diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala index 5b8730938..8754b69c6 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala @@ -59,7 +59,7 @@ abstract class WholeStageExecParserBase( // The node should be marked as shouldRemove when all the children of the // wholeStageCodeGen are marked as shouldRemove. val removeNode = isDupNode || childNodes.forall(_.shouldRemove) - // Remove any suffix in order to get the node label without any trailing number. + // Remove any suffix to get the node label without any trailing number. val nodeLabel = nodeNameRegeX.findFirstMatchIn(node.name) match { case Some(m) => m.group(1) // in case not found, use the full exec name diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WindowGroupLimitParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WindowGroupLimitParser.scala index 853c1b767..1dd50680b 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WindowGroupLimitParser.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WindowGroupLimitParser.scala @@ -37,7 +37,7 @@ case class WindowGroupLimitParser( override def getUnsupportedExprReasonsForExec( expressions: Array[String]): Seq[UnsupportedExprOpRef] = { - expressions.flatMap { expr => + expressions.distinct.flatMap { expr => if (!supportedRankingExprs.contains(expr)) { Some(UnsupportedExprOpRef(expr, s"Ranking function $expr is not supported in $fullExecName")) diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/ExprOpRef.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/ExprOpRef.scala new file mode 100644 index 000000000..48be98c50 --- /dev/null +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/ExprOpRef.scala @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.tool.planparser.ops + +/** + * Represents a reference to an expression operator that is stored in the ExecInfo expressions + * @param opRef the opRef to wrap + * @param count the count of that expression within the exec. + */ +case class ExprOpRef(opRef: OpRef, count: Int = 1) extends OpRefWrapperBase(opRef) + +object ExprOpRef extends OpRefWrapperBaseTrait[ExprOpRef] { + def fromRawExprSeq(exprArr: Seq[String]): Seq[ExprOpRef] = { + exprArr.groupBy(identity) + .mapValues(expr => ExprOpRef(OpRef.fromExpr(expr.head), expr.size)).values.toSeq + } +} diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OpRefWrapperBase.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OpRefWrapperBase.scala new file mode 100644 index 000000000..09cd6bb83 --- /dev/null +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OpRefWrapperBase.scala @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.tool.planparser.ops + +/** + * An instance that wraps OpRef and exposes the OpRef methods. + * This is used to provide common interface for all classes that wrap OpRef along with other + * metadata. + * @param opRef the opRef to wrap + */ +class OpRefWrapperBase(opRef: OpRef) extends OperatorRefTrait { + override def getOpName: String = opRef.getOpName + + override def getOpNameCSV: String = opRef.getOpNameCSV + + override def getOpType: String = opRef.getOpType + + override def getOpTypeCSV: String = opRef.getOpTypeCSV + + def getOpRef: OpRef = opRef +} + +/** + * A trait that provides a factory method to create instances of OpRefWrapperBase from a sequence of + * @tparam R the type of the OpRefWrapperBase + */ +trait OpRefWrapperBaseTrait[R <: OpRefWrapperBase] { + /** + * Create instances of OpRefWrapperBase from a sequence of expressions. + * The expressions are grouped by their value and the count of each expression is stored in the + * OpRefWrapperBase entry. + * @param exprArr the sequence of expressions + * @return a sequence of OpRefWrapperBase instances + */ + def fromRawExprSeq(exprArr: Seq[String]): Seq[R] +} diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorCounter.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorCounter.scala index fab9422aa..f91bedc49 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorCounter.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorCounter.scala @@ -38,21 +38,21 @@ case class OperatorCounter(planInfo: PlanInfo) { * @param stages The set of stages where the operator appears. */ case class OperatorData( - opRef: OperatorRefBase, + opRef: OpRef, var count: Int = 0, - var stages: Set[Int] = Set()) + var stages: Set[Int] = Set()) extends OpRefWrapperBase(opRef) // Summarizes the count information for an exec or expression, including whether it is supported. case class OperatorCountSummary( opData: OperatorData, isSupported: Boolean) - private val supportedMap: mutable.Map[OperatorRefBase, OperatorData] = mutable.Map() - private val unsupportedMap: mutable.Map[OperatorRefBase, OperatorData] = mutable.Map() + private val supportedMap: mutable.Map[OpRef, OperatorData] = mutable.Map() + private val unsupportedMap: mutable.Map[OpRef, OperatorData] = mutable.Map() // Returns a sequence of `OperatorCountSummary`, combining both supported and // unsupported operators. - def getOpsCountSummary(): Seq[OperatorCountSummary] = { + def getOpsCountSummary: Seq[OperatorCountSummary] = { supportedMap.values.map(OperatorCountSummary(_, isSupported = true)).toSeq ++ unsupportedMap.values.map(OperatorCountSummary(_, isSupported = false)).toSeq } @@ -60,44 +60,49 @@ case class OperatorCounter(planInfo: PlanInfo) { // Updates the operator data in the given map (supported or unsupported). // Increments the count and updates the stages where the operator appears. - private def updateOpRefEntry(opRef: OperatorRefBase, stages: Set[Int], - targetMap: mutable.Map[OperatorRefBase, OperatorData]): Unit = { + private def updateOpRefEntry(opRef: OpRef, stages: Set[Int], + targetMap: mutable.Map[OpRef, OperatorData], incrValue: Int = 1): Unit = { val operatorData = targetMap.getOrElseUpdate(opRef, OperatorData(opRef)) - operatorData.count += 1 + operatorData.count += incrValue operatorData.stages ++= stages } // Processes an `ExecInfo` node to update exec and expression counts. // Separates supported and unsupported execs and expressions into their respective maps. private def processExecInfo(execInfo: ExecInfo): Unit = { - val opMap = execInfo.isSupported match { - case true => supportedMap - case false => unsupportedMap + val opMap = if (execInfo.isSupported) { + supportedMap + } else { + unsupportedMap } updateOpRefEntry(execInfo.execRef, execInfo.stages, opMap) - // update the map for supported expressions. We should exclude the unsupported expressions. - execInfo.expressions.filterNot( - e => execInfo.unsupportedExprs.exists(exp => exp.opRef.equals(e))).foreach { expr => - updateOpRefEntry(expr, execInfo.stages, supportedMap) - } - // update the map for unsupported expressions - execInfo.unsupportedExprs.foreach { expr => - updateOpRefEntry(expr, execInfo.stages, unsupportedMap) + // Update the map for supported expressions. For unsupported expressions, + // we use the count stored in the supported expressions. + execInfo.expressions.foreach { expr => + val exprMap = + if (execInfo.unsupportedExprs.exists(unsupExec => + unsupExec.getOpRef.equals(expr.getOpRef))) { + // The expression skips because it exists in the unsupported expressions. + unsupportedMap + } else { + supportedMap + } + updateOpRefEntry(expr.getOpRef, execInfo.stages, exprMap, expr.count) } } - // Counts the execs and expressions in the execution plan. + // Counts the execs and expressions in the execution plan excluding clusterNodes + // (i.e., WholeStageCodeGen). private def countOperators(): Unit = { planInfo.execInfo.foreach { exec => - exec.isClusterNode match { - // we do not want to count the cluster nodes in that aggregation - case true => - if (exec.children.nonEmpty) { - exec.children.get.foreach { child => - processExecInfo(child) - } + if (exec.isClusterNode) { + if (exec.children.nonEmpty) { + exec.children.get.foreach { child => + processExecInfo(child) } - case false => processExecInfo(exec) + } + } else { + processExecInfo(exec) } } } diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/UnsupportedExprOpRef.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/UnsupportedExprOpRef.scala index 0e17d8f0d..e81f10903 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/UnsupportedExprOpRef.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/UnsupportedExprOpRef.scala @@ -25,7 +25,17 @@ package com.nvidia.spark.rapids.tool.planparser.ops * @param unsupportedReason A string describing why the expression is unsupported. */ case class UnsupportedExprOpRef(opRef: OpRef, - unsupportedReason: String) extends OperatorRefBase(opRef.value, opRef.opType) + unsupportedReason: String) extends OpRefWrapperBase(opRef) { + + override def getOpName: String = opRef.getOpName + + override def getOpNameCSV: String = opRef.getOpNameCSV + + override def getOpType: String = opRef.getOpType + + override def getOpTypeCSV: String = opRef.getOpTypeCSV + +} // Provides a factory method to create an instance from an expression name and unsupported reason. object UnsupportedExprOpRef { diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala index 44bbf9fa3..9e58ba967 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala @@ -388,7 +388,7 @@ class PluginTypeChecker(platform: Platform = PlatformFactory.createInstance(), } def getNotSupportedExprs(exprs: Seq[String]): Seq[UnsupportedExprOpRef] = { - exprs.collect { + exprs.distinct.collect { case expr if !isExprSupported(expr) => val reason = unsupportedOpsReasons.getOrElse(expr, "") UnsupportedExprOpRef(expr, reason) diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala index 9dbe7c920..f82124feb 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala @@ -166,7 +166,7 @@ class QualOutputWriter(outputDir: String, reportReadSchema: Boolean, sums.foreach { sum => QualOutputWriter.constructUnsupportedDetailedStagesDurationInfo(csvFileWriter, sum, headersAndSizes, - QualOutputWriter.CSV_DELIMITER, false) + QualOutputWriter.CSV_DELIMITER, prettyPrint = false) } } finally { csvFileWriter.close() @@ -888,15 +888,15 @@ object QualOutputWriter { planInfos.foreach { planInfo => val sqlIDCSVStr = planInfo.sqlID.toString val allOpsCount = OperatorCounter(planInfo) - .getOpsCountSummary().sortBy(oInfo => (-oInfo.opData.count, oInfo.opData.opRef.getOpName)) + .getOpsCountSummary.sortBy(oInfo => (-oInfo.opData.count, oInfo.opData.opRef.getOpName)) if (allOpsCount.nonEmpty) { val planBuffer = allOpsCount.map { opInfo => val supportFlag = if (opInfo.isSupported) supportedCSVStr else unsupportedCSVStr val stageStr = StringUtils.reformatCSVString(opInfo.opData.stages.mkString(":")) s"$appIDCSVStr$delimiter" + s"$sqlIDCSVStr$delimiter" + - s"${opInfo.opData.opRef.getOpTypeCSV}$delimiter" + - s"${opInfo.opData.opRef.getOpNameCSV}$delimiter${opInfo.opData.count}$delimiter" + + s"${opInfo.opData.getOpTypeCSV}$delimiter" + + s"${opInfo.opData.getOpNameCSV}$delimiter${opInfo.opData.count}$delimiter" + s"$supportFlag$delimiter" + s"$stageStr" } @@ -1102,32 +1102,36 @@ object QualOutputWriter { reformatCSV: Boolean = true): Unit = { val reformatCSVFunc = getReformatCSVFunc(reformatCSV) val appId = sumInfo.appId + val appIDStr = reformatCSVFunc(appId) val appDuration = sumInfo.estimatedInfo.appDur val dummyStageID = -1 val dummyStageDur = 0 val execIdGenerator = new AtomicLong(0) - def constructDetailedUnsupportedRow(unSupExecInfo: UnsupportedExecSummary, - stageId: Int, stageAppDuration: Long): String = { - val data = ListBuffer[(String, Int)]( - reformatCSVFunc(appId) -> headersAndSizes(APP_ID_STR), - unSupExecInfo.sqlId.toString -> headersAndSizes(SQL_ID_STR), - stageId.toString -> headersAndSizes(STAGE_ID_STR), - reformatCSVFunc(unSupExecInfo.execId.toString) -> headersAndSizes(EXEC_ID), - reformatCSVFunc(unSupExecInfo.finalOpType) -> headersAndSizes(UNSUPPORTED_TYPE), - unSupExecInfo.unsupportedOperatorCSVFormat -> headersAndSizes(UNSUPPORTED_OPERATOR), - reformatCSVFunc(unSupExecInfo.details) -> headersAndSizes(DETAILS), - stageAppDuration.toString -> headersAndSizes(STAGE_WALLCLOCK_DUR_STR), - appDuration.toString -> headersAndSizes(APP_DUR_STR), - reformatCSVFunc(unSupExecInfo.opAction.toString) -> headersAndSizes(EXEC_ACTION) - ) - constructOutputRow(data, delimiter, prettyPrint) + def constructDetailedUnsupportedRow( + appID: String, + unSupExecInfo: UnsupportedExecSummary, + stageId: String, + stageAppDuration: String): String = { + val reformatCSVFunc = getReformatCSVFunc(reformatCSV) + s"$appID" + delimiter + + s"${unSupExecInfo.sqlId}" + delimiter + + s"$stageId" + delimiter + + s"${unSupExecInfo.execId.toString}" + delimiter + + s"${reformatCSVFunc(unSupExecInfo.finalOpType)}" + delimiter + + s"${unSupExecInfo.unsupportedOperatorCSVFormat}" + delimiter + + s"${reformatCSVFunc(unSupExecInfo.details)}" + delimiter + + s"$stageAppDuration" + delimiter + + s"$appDuration" + delimiter + + s"${unSupExecInfo.opAction}\n" } def getUnsupportedRows(execI: ExecInfo, stageId: Int, stageDur: Long): String = { val results = execI.getUnsupportedExecSummaryRecord(execIdGenerator.getAndIncrement()) + val stageIDStr = stageId.toString + val stageDurStr = stageDur.toString results.map { unsupportedExecSummary => - constructDetailedUnsupportedRow(unsupportedExecSummary, stageId, stageDur) + constructDetailedUnsupportedRow(appIDStr, unsupportedExecSummary, stageIDStr, stageDurStr) }.mkString } diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala index e31064851..771484c14 100644 --- a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala +++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import scala.util.control.NonFatal import com.nvidia.spark.rapids.tool.ToolTestUtils +import com.nvidia.spark.rapids.tool.planparser.ops.{ExprOpRef, OpRef} import com.nvidia.spark.rapids.tool.qualification._ import org.scalatest.Matchers.{be, contain, convertToAnyShouldWrapper} import org.scalatest.exceptions.TestFailedException @@ -1319,46 +1320,69 @@ class SQLPlanParserSuite extends BasePlanParserSuite { test("Parsing Conditional Expressions") { // scalastyle:off line.size.limit - val expressionsMap: mutable.HashMap[String, Array[String]] = mutable.HashMap( - "(((lower(partition_act#90) = moduleview) && (isnotnull(productarr#22) && NOT (productarr#22=[]))) || (lower(moduletype#13) = saveforlater))" -> - Array("lower", "isnotnull", "EqualTo", "And", "Not", "Or"), + val expressionsMap: mutable.HashMap[String, Map[String, Int]] = mutable.HashMap( + "(((lower(partition_act#90) = moduleview) && (isnotnull(productarr#22) && NOT (productarr#22 = []))) || (lower(moduletype#13) = saveforlater))" -> + Map("lower" -> 2, + "isnotnull" -> 1, + "EqualTo" -> 3, + "And" -> 2, + "Not" -> 1, + "Or" -> 1), "(IsNotNull(c_customer_id))" -> - Array("IsNotNull"), + Map("IsNotNull" -> 1), "(isnotnull(names#15) AND StartsWith(names#15, OR))" -> - Array("isnotnull", "And", "StartsWith"), + Map("isnotnull" -> 1, + "StartsWith" -> 1, + "And" -> 1), "((isnotnull(s_state#68) AND (s_state#68 = TN)) OR (hex(cast(value#0 as bigint)) = B))" -> - Array("isnotnull", "And", "Or", "hex", "EqualTo"), + Map("isnotnull" -> 1, + "And" -> 1, + "Or" -> 1, + "hex" -> 1, + "EqualTo" -> 2), // Test that AND followed by '(' without space can be parsed "((isnotnull(s_state#68) AND(s_state#68 = TN)) OR (hex(cast(value#0 as bigint)) = B))" -> - Array("isnotnull", "And", "Or", "hex", "EqualTo"), + Map("isnotnull" -> 1, + "And" -> 1, + "Or" -> 1, + "hex" -> 1, + "EqualTo" -> 2), "(((isnotnull(d_year#498) AND isnotnull(d_moy#500)) AND (d_year#498 = 1998)) AND (d_moy#500 = 12))" -> - Array("isnotnull", "And", "EqualTo"), + Map("isnotnull" -> 2, + "And" -> 3, + "EqualTo" -> 2), "IsNotNull(d_year) AND IsNotNull(d_moy) AND EqualTo(d_year,1998) AND EqualTo(d_moy,12)" -> - Array("IsNotNull", "And", "EqualTo"), + Map("IsNotNull" -> 2, + "EqualTo" -> 2, + "And" -> 3), // check that a predicate with a single variable name is fine "flagVariable" -> - Array(), + Map(), // check that a predicate with a single function call "isnotnull(c_customer_sk#412)" -> - Array("isnotnull"), + Map("isnotnull" -> 1), "((substr(ca_zip#457, 1, 5) IN (85669,86197,88274,83405,86475,85392,85460,80348,81792) OR ca_state#456 IN (CA,WA,GA)) OR (cs_sales_price#20 > 500.00))" -> - Array("substr", "In", "Or", "GreaterThan"), + Map("substr" -> 1, + "In" -> 2, + "Or" -> 2, + "GreaterThan" -> 1), // test the operator is at the beginning of expression and not followed by space "NOT(isnotnull(d_moy))" -> - Array("Not", "isnotnull"), + Map("isnotnull" -> 1, + "Not" -> 1), // test the shiftright operator(since spark-4.0) "((isnotnull(d_year#498) AND isnotnull(d_moy#500)) AND (d_year#498 >> 1) >= 100)" -> - Array("isnotnull", "And", "GreaterThanOrEqual", "ShiftRight") + Map("isnotnull" -> 2, + "And" -> 2, + "ShiftRight" -> 1, + "GreaterThanOrEqual" -> 1) ) // scalastyle:on line.size.limit - for ((condExpr, expectedExpression) <- expressionsMap) { - val parsedExpressionsMine = SQLPlanParser.parseConditionalExpressions(condExpr) - val currOutput = parsedExpressionsMine.sorted - val expectedOutput = expectedExpression.sorted - assert(currOutput sameElements expectedOutput, - s"The parsed expressions are not as expected. Expression: ${condExpr}, " + - s"Expected: ${expectedOutput.mkString}, " + - s"Output: ${currOutput.mkString}") + for ((condExpr, expectedExpressionCounts) <- expressionsMap) { + val rawExpressions = SQLPlanParser.parseConditionalExpressions(condExpr) + val expected = expectedExpressionCounts.map(e => ExprOpRef(OpRef.fromExpr(e._1), e._2)) + val actualExpressions = ExprOpRef.fromRawExprSeq(rawExpressions) + actualExpressions should ===(expected) } } @@ -1403,10 +1427,18 @@ class SQLPlanParserSuite extends BasePlanParserSuite { "AND (content_name_16#197L = 1)) AND NOT (split(split(split(replace(replace(replace" + "(replace(trim(replace(cast(unbase64(content#192) as string), , ), Some( )), *., ), *, ), " + "https://, ), http://, ), /, -1)[0], :, -1)[0], \\?, -1)[0] = ))" - val expected = Array("isnotnull", "split", "replace", "trim", "unbase64", "And", - "EqualTo", "Not") - val expressions = SQLPlanParser.parseFilterExpressions(exprString) - expressions should ===(expected) + val expected = Map( + "isnotnull" -> 3, + "split" -> 3, + "replace" -> 5, + "trim" -> 1, + "unbase64" -> 1, + "And" -> 6, + "EqualTo" -> 4, // EqualTo comes from the = operator + "Not" -> 2).map(e => ExprOpRef(OpRef.fromExpr(e._1), e._2)) + val rawExpressions = SQLPlanParser.parseFilterExpressions(exprString) + val actualExpressions = ExprOpRef.fromRawExprSeq(rawExpressions) + actualExpressions should ===(expected) } @@ -1418,9 +1450,16 @@ class SQLPlanParserSuite extends BasePlanParserSuite { "THEN concat(replace(cast(unbase64(content#192) as string), , ), %) " + "ELSE replace(replace(replace(cast(unbase64(content#192) as string), , ), %, " + "\\%), *, %) END#200])" - val expected = Array("replace", "concat", "instr", "split", "trim", "unbase64") - val expressions = SQLPlanParser.parseAggregateExpressions(exprString) - expressions should ===(expected) + val expected = Map( + "replace" -> 10, + "concat" -> 1, + "instr" -> 1, + "split" -> 3, + "trim" -> 1, + "unbase64" -> 4).map(e => ExprOpRef(OpRef.fromExpr(e._1), e._2)) + val rawExpressions = SQLPlanParser.parseAggregateExpressions(exprString) + val actualExpressions = ExprOpRef.fromRawExprSeq(rawExpressions) + actualExpressions should ===(expected) } runConditionalTest("promote_precision is supported for Spark LT 3.4.0: issue-517", @@ -1522,7 +1561,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite { "gid#1296," + "CAST((IF((supersql_t12.`ret_type` = 2), 1, 0)) AS BIGINT)#1300L]]" // Only "IF" should be picked up as a function name - val expected = Array("IF") + val expected = Array("IF", "IF", "IF") val expressions = SQLPlanParser.parseExpandExpressions(exprString) expressions should ===(expected) }