From c15aac98bf1219a10afaf557e4989ed7c6b07b6e Mon Sep 17 00:00:00 2001
From: "Ahmed Hussein (amahussein)"
Date: Thu, 5 Dec 2024 15:48:59 -0600
Subject: [PATCH] Count expressions per Exec in SQLPlanParser
Signed-off-by: Ahmed Hussein (amahussein)
Fixes #1447
Improves the operator stats by counting the occurrences of each
expression within its Exec node.
---
.../tool/planparser/SQLPlanParser.scala | 49 ++++-----
.../planparser/WholeStageExecParser.scala | 2 +-
.../planparser/WindowGroupLimitParser.scala | 2 +-
.../tool/planparser/ops/ExprOpRef.scala | 31 ++++++
.../planparser/ops/OpRefWrapperBase.scala | 50 ++++++++++
.../tool/planparser/ops/OperatorCounter.scala | 61 ++++++------
.../planparser/ops/UnsupportedExprOpRef.scala | 12 ++-
.../qualification/PluginTypeChecker.scala | 2 +-
.../tool/qualification/QualOutputWriter.scala | 44 +++++----
.../tool/planparser/SqlPlanParserSuite.scala | 99 +++++++++++++------
10 files changed, 242 insertions(+), 110 deletions(-)
create mode 100644 core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/ExprOpRef.scala
create mode 100644 core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OpRefWrapperBase.scala
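
Note: the grouping scheme introduced by ExprOpRef.fromRawExprSeq boils down
to the following minimal sketch (illustrative values, not the patched code):

    // raw function/operator names parsed out of a single Exec node
    val rawExprs = Seq("lower", "isnotnull", "lower", "EqualTo")
    // fold duplicates into per-Exec occurrence counts
    val counted: Map[String, Int] =
      rawExprs.groupBy(identity).map { case (name, hits) => name -> hits.size }
    // counted == Map("lower" -> 2, "isnotnull" -> 1, "EqualTo" -> 1)

This is also why the parsers below stop de-duplicating their results and why
getAllFunctionNames now returns an Array instead of a Set.
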
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
index 17159842b..4d9c59dd8 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, WeakHashMap}
import scala.util.control.NonFatal
import scala.util.matching.Regex
-import com.nvidia.spark.rapids.tool.planparser.ops.{OperatorRefBase, OpRef, UnsupportedExprOpRef}
+import com.nvidia.spark.rapids.tool.planparser.ops.{ExprOpRef, OperatorRefTrait, OpRef, UnsupportedExprOpRef}
import com.nvidia.spark.rapids.tool.planparser.photon.{PhotonPlanParser, PhotonStageExecParser}
import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker
@@ -75,7 +75,7 @@ object UnsupportedReasons extends Enumeration {
case class UnsupportedExecSummary(
sqlId: Long,
execId: Long,
- execRef: OperatorRefBase,
+ execRef: OperatorRefTrait,
opType: OpTypes.OpType,
reason: UnsupportedReasons.UnsupportedReason,
opAction: OpActions.OpAction) {
@@ -89,8 +89,6 @@ case class UnsupportedExecSummary(
val unsupportedOperatorCSVFormat: String = execRef.getOpNameCSV
val details: String = UnsupportedReasons.reportUnsupportedReason(reason)
-
- def isExpression: Boolean = execRef.isInstanceOf[UnsupportedExprOpRef]
}
case class ExecInfo(
@@ -110,7 +108,7 @@ case class ExecInfo(
dataSet: Boolean,
udf: Boolean,
shouldIgnore: Boolean,
- expressions: Seq[OpRef]) {
+ expressions: Seq[ExprOpRef]) {
private def childrenToString = {
val str = children.map { c =>
@@ -141,10 +139,6 @@ case class ExecInfo(
stages = stageIDs
}
- def appendToStages(stageIDs: Set[Int]): Unit = {
- stages ++= stageIDs
- }
-
def setShouldRemove(value: Boolean): Unit = {
shouldRemove ||= value
}
@@ -286,7 +280,7 @@ object ExecInfo {
udf,
shouldIgnore,
-      // convert array of string expressions to OpRefs
-      expressions = expressions.map(OpRef.fromExpr)
+      // convert the raw expression strings to ExprOpRefs carrying per-Exec counts
+      expressions = ExprOpRef.fromRawExprSeq(expressions)
)
}
@@ -351,7 +345,7 @@ case class PlanInfo(
sqlID: Long,
sqlDesc: String,
execInfo: Seq[ExecInfo]) {
- def getUnsupportedExpressions: Seq[OperatorRefBase] = {
+ def getUnsupportedExpressions: Seq[OperatorRefTrait] = {
execInfo.flatMap { e =>
if (e.isClusterNode) {
// wholeStageCodeGen does not have expressions/unsupported-expressions
@@ -727,21 +721,21 @@ object SQLPlanParser extends Logging {
}
private def getAllFunctionNames(regPattern: Regex, expr: String,
- groupInd: Int = 1, isAggr: Boolean = true): Set[String] = {
+ groupInd: Int = 1, isAggr: Boolean = true): Array[String] = {
// Returns all matches in an expression. This can be used when the SQL expression is not
// tokenized.
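+    // NOTE: duplicates are intentionally preserved (no Set/distinct) so that
+    // callers can count how often each function occurs within an Exec node.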
val newExpr = processSpecialFunctions(expr)
// first get all the functionNames
val exprss =
- regPattern.findAllMatchIn(newExpr).map(_.group(groupInd)).toSet
+ regPattern.findAllMatchIn(newExpr).map(_.group(groupInd)).toSeq
// For aggregate expressions we want to process the results to remove the prefix
// DB: remove the "^partial_" and "^finalmerge_" prefixes
// TODO:
// for performance sake, we can turn off the aggregate processing by enabling it only
// when needed. However, for now, we always do this processing until we are confident we know
- // the correct place to turn on/off that flag.we can use the argument isAgg only when needed
+    // the correct place to turn on/off that flag. We can use the argument isAggr only when needed.
val results = if (isAggr) {
exprss.collect {
case func =>
@@ -750,7 +744,7 @@ object SQLPlanParser extends Logging {
} else {
exprss
}
- results.filterNot(ignoreExpression(_))
+ results.filterNot(ignoreExpression(_)).toArray
}
def parseProjectExpressions(exprStr: String): Array[String] = {
@@ -758,7 +752,7 @@ object SQLPlanParser extends Logging {
// This is to split the string such that only function names are extracted. The pattern is
// such that function name is succeeded by `(`. We use regex to extract all the function names
// below:
- getAllFunctionNames(functionPrefixPattern, exprStr).toArray
+ getAllFunctionNames(functionPrefixPattern, exprStr)
}
// This parser is used for SortAggregateExec, HashAggregateExec and ObjectHashAggregateExec
@@ -792,7 +786,7 @@ object SQLPlanParser extends Logging {
}
}
}
- parsedExpressions.distinct.toArray
+ parsedExpressions.toArray
}
def parseWindowExpressions(exprStr:String): Array[String] = {
@@ -828,7 +822,7 @@ object SQLPlanParser extends Logging {
}
}
}
- parsedExpressions.distinct.toArray
+ parsedExpressions.toArray
}
def parseWindowGroupLimitExpressions(exprStr: String): Array[String] = {
@@ -858,10 +852,10 @@ object SQLPlanParser extends Logging {
// - Some values can be NULLs. That's why we cannot limit the extract to the first row.
// - Nested brackets/parenthesis makes it challenging to use regex that contains
// brackets/parenthesis to extract expressions.
- // The implementation Use regex to extract all function names and return distinct set of
+    // The implementation uses regex to extract all function names and returns a list of
// function names.
    // This is a one-line implementation, but it can be a memory/time bottleneck.
- getAllFunctionNames(functionPrefixPattern, exprStr).toArray
+ getAllFunctionNames(functionPrefixPattern, exprStr)
}
def parseTakeOrderedExpressions(exprStr: String): Array[String] = {
@@ -889,7 +883,7 @@ object SQLPlanParser extends Logging {
}
}
}
- parsedExpressions.distinct.toArray
+ parsedExpressions.toArray
}
def parseGenerateExpressions(exprStr: String): Array[String] = {
@@ -897,11 +891,11 @@ object SQLPlanParser extends Logging {
// 1. Generate explode(arrays#1306), [id#1304], true, [col#1426]
// 2. Generate json_tuple(values#1305, Zipcode, ZipCodeType, City), [id#1304],
// false, [c0#1407, c1#1408, c2#1409]
- getAllFunctionNames(functionPrefixPattern, exprStr).toArray
+ getAllFunctionNames(functionPrefixPattern, exprStr)
}
private def addFunctionNames(exprs: String, parsedExpressions: ArrayBuffer[String]): Unit = {
- val functionNames = getAllFunctionNames(functionPrefixPattern, exprs).toArray
+ val functionNames = getAllFunctionNames(functionPrefixPattern, exprs)
functionNames.foreach(parsedExpressions += _)
}
@@ -946,7 +940,7 @@ object SQLPlanParser extends Logging {
val isSortMergeSupported = !(isSortMergeJoin &&
joinCondition.nonEmpty && isSMJConditionUnsupported(joinCondition))
- (parsedExpressions.distinct.toArray, equiJoinSupportedTypes(buildSide, joinType)
+ (parsedExpressions.toArray, equiJoinSupportedTypes(buildSide, joinType)
&& isSortMergeSupported)
}
@@ -1054,7 +1048,7 @@ object SQLPlanParser extends Logging {
}
}
}
- parsedExpressions.distinct.toArray
+ parsedExpressions.toArray
}
def parseFilterExpressions(exprStr: String): Array[String] = {
@@ -1109,7 +1103,7 @@ object SQLPlanParser extends Logging {
processedExpr = nonBinaryOperatorsRegEx.replaceAllIn(processedExpr, " ")
// Step-4: remove remaining parentheses '(', ')' and commas if we had functionCalls
- if (!functionMatches.isEmpty) {
+ if (functionMatches.nonEmpty) {
// remove ","
processedExpr = processedExpr.replaceAll(",", " ")
}
@@ -1147,7 +1141,6 @@ object SQLPlanParser extends Logging {
logDebug(s"Unrecognized Token - $token")
}
}
-
- parsedExpressions.distinct.toArray
+ parsedExpressions.toArray
}
}
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala
index 5b8730938..8754b69c6 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WholeStageExecParser.scala
@@ -59,7 +59,7 @@ abstract class WholeStageExecParserBase(
// The node should be marked as shouldRemove when all the children of the
// wholeStageCodeGen are marked as shouldRemove.
val removeNode = isDupNode || childNodes.forall(_.shouldRemove)
- // Remove any suffix in order to get the node label without any trailing number.
+ // Remove any suffix to get the node label without any trailing number.
val nodeLabel = nodeNameRegeX.findFirstMatchIn(node.name) match {
case Some(m) => m.group(1)
// in case not found, use the full exec name
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WindowGroupLimitParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WindowGroupLimitParser.scala
index 853c1b767..1dd50680b 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WindowGroupLimitParser.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/WindowGroupLimitParser.scala
@@ -37,7 +37,7 @@ case class WindowGroupLimitParser(
override def getUnsupportedExprReasonsForExec(
expressions: Array[String]): Seq[UnsupportedExprOpRef] = {
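+    // Report each unsupported ranking expression once; occurrence counts are
+    // tracked separately in ExecInfo.expressions.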
- expressions.flatMap { expr =>
+ expressions.distinct.flatMap { expr =>
if (!supportedRankingExprs.contains(expr)) {
Some(UnsupportedExprOpRef(expr,
s"Ranking function $expr is not supported in $fullExecName"))
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/ExprOpRef.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/ExprOpRef.scala
new file mode 100644
index 000000000..48be98c50
--- /dev/null
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/ExprOpRef.scala
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.tool.planparser.ops
+
+/**
+ * Represents a reference to an expression operator stored in the ExecInfo expressions.
+ * @param opRef the opRef to wrap
+ * @param count the count of that expression within the exec.
+ */
+case class ExprOpRef(opRef: OpRef, count: Int = 1) extends OpRefWrapperBase(opRef)
+
+object ExprOpRef extends OpRefWrapperBaseTrait[ExprOpRef] {
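+  /**
+   * Groups the raw expression names and counts duplicates. For example
+   * (illustrative values), Seq("lower", "lower", "isnotnull") yields an
+   * ExprOpRef for "lower" with count = 2 and one for "isnotnull" with count = 1.
+   */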
+ def fromRawExprSeq(exprArr: Seq[String]): Seq[ExprOpRef] = {
+ exprArr.groupBy(identity)
+ .mapValues(expr => ExprOpRef(OpRef.fromExpr(expr.head), expr.size)).values.toSeq
+ }
+}
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OpRefWrapperBase.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OpRefWrapperBase.scala
new file mode 100644
index 000000000..09cd6bb83
--- /dev/null
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OpRefWrapperBase.scala
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.tool.planparser.ops
+
+/**
+ * A base class that wraps an OpRef and exposes its methods.
+ * It provides a common interface for all classes that wrap an OpRef along with other
+ * metadata.
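+ * Concrete wrappers in this patch include ExprOpRef (adds a per-Exec count) and
+ * UnsupportedExprOpRef (adds a reason string).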
+ * @param opRef the opRef to wrap
+ */
+class OpRefWrapperBase(opRef: OpRef) extends OperatorRefTrait {
+ override def getOpName: String = opRef.getOpName
+
+ override def getOpNameCSV: String = opRef.getOpNameCSV
+
+ override def getOpType: String = opRef.getOpType
+
+ override def getOpTypeCSV: String = opRef.getOpTypeCSV
+
+ def getOpRef: OpRef = opRef
+}
+
+/**
+ * A trait that provides a factory method to create instances of OpRefWrapperBase from a
+ * sequence of raw expression strings.
+ * @tparam R the type of the OpRefWrapperBase
+ */
+trait OpRefWrapperBaseTrait[R <: OpRefWrapperBase] {
+ /**
+ * Create instances of OpRefWrapperBase from a sequence of expressions.
+ * The expressions are grouped by their value and the count of each expression is stored in the
+ * OpRefWrapperBase entry.
+ * @param exprArr the sequence of expressions
+ * @return a sequence of OpRefWrapperBase instances
+ */
+ def fromRawExprSeq(exprArr: Seq[String]): Seq[R]
+}
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorCounter.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorCounter.scala
index fab9422aa..f91bedc49 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorCounter.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorCounter.scala
@@ -38,21 +38,21 @@ case class OperatorCounter(planInfo: PlanInfo) {
* @param stages The set of stages where the operator appears.
*/
case class OperatorData(
- opRef: OperatorRefBase,
+ opRef: OpRef,
var count: Int = 0,
- var stages: Set[Int] = Set())
+ var stages: Set[Int] = Set()) extends OpRefWrapperBase(opRef)
// Summarizes the count information for an exec or expression, including whether it is supported.
case class OperatorCountSummary(
opData: OperatorData,
isSupported: Boolean)
- private val supportedMap: mutable.Map[OperatorRefBase, OperatorData] = mutable.Map()
- private val unsupportedMap: mutable.Map[OperatorRefBase, OperatorData] = mutable.Map()
+ private val supportedMap: mutable.Map[OpRef, OperatorData] = mutable.Map()
+ private val unsupportedMap: mutable.Map[OpRef, OperatorData] = mutable.Map()
// Returns a sequence of `OperatorCountSummary`, combining both supported and
// unsupported operators.
- def getOpsCountSummary(): Seq[OperatorCountSummary] = {
+ def getOpsCountSummary: Seq[OperatorCountSummary] = {
supportedMap.values.map(OperatorCountSummary(_, isSupported = true)).toSeq ++
unsupportedMap.values.map(OperatorCountSummary(_, isSupported = false)).toSeq
}
@@ -60,44 +60,49 @@ case class OperatorCounter(planInfo: PlanInfo) {
// Updates the operator data in the given map (supported or unsupported).
// Increments the count and updates the stages where the operator appears.
- private def updateOpRefEntry(opRef: OperatorRefBase, stages: Set[Int],
- targetMap: mutable.Map[OperatorRefBase, OperatorData]): Unit = {
+ private def updateOpRefEntry(opRef: OpRef, stages: Set[Int],
+ targetMap: mutable.Map[OpRef, OperatorData], incrValue: Int = 1): Unit = {
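+    // incrValue is left at 1 for execs; expression entries pass their per-Exec
+    // occurrence count so repeated expressions are tallied correctly.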
val operatorData = targetMap.getOrElseUpdate(opRef, OperatorData(opRef))
- operatorData.count += 1
+ operatorData.count += incrValue
operatorData.stages ++= stages
}
// Processes an `ExecInfo` node to update exec and expression counts.
// Separates supported and unsupported execs and expressions into their respective maps.
private def processExecInfo(execInfo: ExecInfo): Unit = {
- val opMap = execInfo.isSupported match {
- case true => supportedMap
- case false => unsupportedMap
+ val opMap = if (execInfo.isSupported) {
+ supportedMap
+ } else {
+ unsupportedMap
}
updateOpRefEntry(execInfo.execRef, execInfo.stages, opMap)
- // update the map for supported expressions. We should exclude the unsupported expressions.
- execInfo.expressions.filterNot(
- e => execInfo.unsupportedExprs.exists(exp => exp.opRef.equals(e))).foreach { expr =>
- updateOpRefEntry(expr, execInfo.stages, supportedMap)
- }
- // update the map for unsupported expressions
- execInfo.unsupportedExprs.foreach { expr =>
- updateOpRefEntry(expr, execInfo.stages, unsupportedMap)
+    // Update the map for each expression using its occurrence count within the exec.
+    // Expressions that appear in the unsupported list are routed to the unsupported map.
+ execInfo.expressions.foreach { expr =>
+ val exprMap =
+ if (execInfo.unsupportedExprs.exists(unsupExec =>
+ unsupExec.getOpRef.equals(expr.getOpRef))) {
+          // This expression counts as unsupported because it appears in unsupportedExprs.
+ unsupportedMap
+ } else {
+ supportedMap
+ }
+ updateOpRefEntry(expr.getOpRef, execInfo.stages, exprMap, expr.count)
}
}
- // Counts the execs and expressions in the execution plan.
+ // Counts the execs and expressions in the execution plan excluding clusterNodes
+ // (i.e., WholeStageCodeGen).
private def countOperators(): Unit = {
planInfo.execInfo.foreach { exec =>
- exec.isClusterNode match {
- // we do not want to count the cluster nodes in that aggregation
- case true =>
- if (exec.children.nonEmpty) {
- exec.children.get.foreach { child =>
- processExecInfo(child)
- }
+ if (exec.isClusterNode) {
+ if (exec.children.nonEmpty) {
+ exec.children.get.foreach { child =>
+ processExecInfo(child)
}
- case false => processExecInfo(exec)
+ }
+ } else {
+ processExecInfo(exec)
}
}
}
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/UnsupportedExprOpRef.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/UnsupportedExprOpRef.scala
index 0e17d8f0d..e81f10903 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/UnsupportedExprOpRef.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/UnsupportedExprOpRef.scala
@@ -25,7 +25,17 @@ package com.nvidia.spark.rapids.tool.planparser.ops
* @param unsupportedReason A string describing why the expression is unsupported.
*/
case class UnsupportedExprOpRef(opRef: OpRef,
-    unsupportedReason: String) extends OperatorRefBase(opRef.value, opRef.opType)
+    unsupportedReason: String) extends OpRefWrapperBase(opRef)
// Provides a factory method to create an instance from an expression name and unsupported reason.
object UnsupportedExprOpRef {
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala
index 44bbf9fa3..9e58ba967 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala
@@ -388,7 +388,7 @@ class PluginTypeChecker(platform: Platform = PlatformFactory.createInstance(),
}
def getNotSupportedExprs(exprs: Seq[String]): Seq[UnsupportedExprOpRef] = {
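+    // De-duplicate first: per-Exec occurrence counts are tracked separately in
+    // ExecInfo.expressions, so each unsupported expression is reported once.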
- exprs.collect {
+ exprs.distinct.collect {
case expr if !isExprSupported(expr) =>
val reason = unsupportedOpsReasons.getOrElse(expr, "")
UnsupportedExprOpRef(expr, reason)
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala
index 9dbe7c920..f82124feb 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala
@@ -166,7 +166,7 @@ class QualOutputWriter(outputDir: String, reportReadSchema: Boolean,
sums.foreach { sum =>
QualOutputWriter.constructUnsupportedDetailedStagesDurationInfo(csvFileWriter,
sum, headersAndSizes,
- QualOutputWriter.CSV_DELIMITER, false)
+ QualOutputWriter.CSV_DELIMITER, prettyPrint = false)
}
} finally {
csvFileWriter.close()
@@ -888,15 +888,15 @@ object QualOutputWriter {
planInfos.foreach { planInfo =>
val sqlIDCSVStr = planInfo.sqlID.toString
val allOpsCount = OperatorCounter(planInfo)
- .getOpsCountSummary().sortBy(oInfo => (-oInfo.opData.count, oInfo.opData.opRef.getOpName))
+ .getOpsCountSummary.sortBy(oInfo => (-oInfo.opData.count, oInfo.opData.opRef.getOpName))
if (allOpsCount.nonEmpty) {
val planBuffer = allOpsCount.map { opInfo =>
val supportFlag = if (opInfo.isSupported) supportedCSVStr else unsupportedCSVStr
val stageStr = StringUtils.reformatCSVString(opInfo.opData.stages.mkString(":"))
s"$appIDCSVStr$delimiter" +
s"$sqlIDCSVStr$delimiter" +
- s"${opInfo.opData.opRef.getOpTypeCSV}$delimiter" +
- s"${opInfo.opData.opRef.getOpNameCSV}$delimiter${opInfo.opData.count}$delimiter" +
+ s"${opInfo.opData.getOpTypeCSV}$delimiter" +
+ s"${opInfo.opData.getOpNameCSV}$delimiter${opInfo.opData.count}$delimiter" +
s"$supportFlag$delimiter" +
s"$stageStr"
}
@@ -1102,32 +1102,36 @@ object QualOutputWriter {
reformatCSV: Boolean = true): Unit = {
val reformatCSVFunc = getReformatCSVFunc(reformatCSV)
val appId = sumInfo.appId
+ val appIDStr = reformatCSVFunc(appId)
val appDuration = sumInfo.estimatedInfo.appDur
val dummyStageID = -1
val dummyStageDur = 0
val execIdGenerator = new AtomicLong(0)
- def constructDetailedUnsupportedRow(unSupExecInfo: UnsupportedExecSummary,
- stageId: Int, stageAppDuration: Long): String = {
- val data = ListBuffer[(String, Int)](
- reformatCSVFunc(appId) -> headersAndSizes(APP_ID_STR),
- unSupExecInfo.sqlId.toString -> headersAndSizes(SQL_ID_STR),
- stageId.toString -> headersAndSizes(STAGE_ID_STR),
- reformatCSVFunc(unSupExecInfo.execId.toString) -> headersAndSizes(EXEC_ID),
- reformatCSVFunc(unSupExecInfo.finalOpType) -> headersAndSizes(UNSUPPORTED_TYPE),
- unSupExecInfo.unsupportedOperatorCSVFormat -> headersAndSizes(UNSUPPORTED_OPERATOR),
- reformatCSVFunc(unSupExecInfo.details) -> headersAndSizes(DETAILS),
- stageAppDuration.toString -> headersAndSizes(STAGE_WALLCLOCK_DUR_STR),
- appDuration.toString -> headersAndSizes(APP_DUR_STR),
- reformatCSVFunc(unSupExecInfo.opAction.toString) -> headersAndSizes(EXEC_ACTION)
- )
- constructOutputRow(data, delimiter, prettyPrint)
+  def constructDetailedUnsupportedRow(
+      appID: String,
+      unSupExecInfo: UnsupportedExecSummary,
+      stageId: String,
+      stageAppDuration: String): String = {
+    s"$appID$delimiter" +
+      s"${unSupExecInfo.sqlId}$delimiter" +
+      s"$stageId$delimiter" +
+      s"${unSupExecInfo.execId}$delimiter" +
+      s"${reformatCSVFunc(unSupExecInfo.finalOpType)}$delimiter" +
+      s"${unSupExecInfo.unsupportedOperatorCSVFormat}$delimiter" +
+      s"${reformatCSVFunc(unSupExecInfo.details)}$delimiter" +
+      s"$stageAppDuration$delimiter" +
+      s"$appDuration$delimiter" +
+      s"${unSupExecInfo.opAction}\n"
+  }
def getUnsupportedRows(execI: ExecInfo, stageId: Int, stageDur: Long): String = {
val results = execI.getUnsupportedExecSummaryRecord(execIdGenerator.getAndIncrement())
+ val stageIDStr = stageId.toString
+ val stageDurStr = stageDur.toString
results.map { unsupportedExecSummary =>
- constructDetailedUnsupportedRow(unsupportedExecSummary, stageId, stageDur)
+ constructDetailedUnsupportedRow(appIDStr, unsupportedExecSummary, stageIDStr, stageDurStr)
}.mkString
}
diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala
index e31064851..771484c14 100644
--- a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala
+++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
import scala.util.control.NonFatal
import com.nvidia.spark.rapids.tool.ToolTestUtils
+import com.nvidia.spark.rapids.tool.planparser.ops.{ExprOpRef, OpRef}
import com.nvidia.spark.rapids.tool.qualification._
import org.scalatest.Matchers.{be, contain, convertToAnyShouldWrapper}
import org.scalatest.exceptions.TestFailedException
@@ -1319,46 +1320,69 @@ class SQLPlanParserSuite extends BasePlanParserSuite {
test("Parsing Conditional Expressions") {
// scalastyle:off line.size.limit
- val expressionsMap: mutable.HashMap[String, Array[String]] = mutable.HashMap(
- "(((lower(partition_act#90) = moduleview) && (isnotnull(productarr#22) && NOT (productarr#22=[]))) || (lower(moduletype#13) = saveforlater))" ->
- Array("lower", "isnotnull", "EqualTo", "And", "Not", "Or"),
+ val expressionsMap: mutable.HashMap[String, Map[String, Int]] = mutable.HashMap(
+ "(((lower(partition_act#90) = moduleview) && (isnotnull(productarr#22) && NOT (productarr#22 = []))) || (lower(moduletype#13) = saveforlater))" ->
+ Map("lower" -> 2,
+ "isnotnull" -> 1,
+ "EqualTo" -> 3,
+ "And" -> 2,
+ "Not" -> 1,
+ "Or" -> 1),
"(IsNotNull(c_customer_id))" ->
- Array("IsNotNull"),
+ Map("IsNotNull" -> 1),
"(isnotnull(names#15) AND StartsWith(names#15, OR))" ->
- Array("isnotnull", "And", "StartsWith"),
+ Map("isnotnull" -> 1,
+ "StartsWith" -> 1,
+ "And" -> 1),
"((isnotnull(s_state#68) AND (s_state#68 = TN)) OR (hex(cast(value#0 as bigint)) = B))" ->
- Array("isnotnull", "And", "Or", "hex", "EqualTo"),
+ Map("isnotnull" -> 1,
+ "And" -> 1,
+ "Or" -> 1,
+ "hex" -> 1,
+ "EqualTo" -> 2),
// Test that AND followed by '(' without space can be parsed
"((isnotnull(s_state#68) AND(s_state#68 = TN)) OR (hex(cast(value#0 as bigint)) = B))" ->
- Array("isnotnull", "And", "Or", "hex", "EqualTo"),
+ Map("isnotnull" -> 1,
+ "And" -> 1,
+ "Or" -> 1,
+ "hex" -> 1,
+ "EqualTo" -> 2),
"(((isnotnull(d_year#498) AND isnotnull(d_moy#500)) AND (d_year#498 = 1998)) AND (d_moy#500 = 12))" ->
- Array("isnotnull", "And", "EqualTo"),
+ Map("isnotnull" -> 2,
+ "And" -> 3,
+ "EqualTo" -> 2),
"IsNotNull(d_year) AND IsNotNull(d_moy) AND EqualTo(d_year,1998) AND EqualTo(d_moy,12)" ->
- Array("IsNotNull", "And", "EqualTo"),
+ Map("IsNotNull" -> 2,
+ "EqualTo" -> 2,
+ "And" -> 3),
// check that a predicate with a single variable name is fine
"flagVariable" ->
- Array(),
+ Map(),
// check that a predicate with a single function call
"isnotnull(c_customer_sk#412)" ->
- Array("isnotnull"),
+ Map("isnotnull" -> 1),
"((substr(ca_zip#457, 1, 5) IN (85669,86197,88274,83405,86475,85392,85460,80348,81792) OR ca_state#456 IN (CA,WA,GA)) OR (cs_sales_price#20 > 500.00))" ->
- Array("substr", "In", "Or", "GreaterThan"),
+ Map("substr" -> 1,
+ "In" -> 2,
+ "Or" -> 2,
+ "GreaterThan" -> 1),
// test an operator at the beginning of the expression and not followed by a space
"NOT(isnotnull(d_moy))" ->
- Array("Not", "isnotnull"),
+ Map("isnotnull" -> 1,
+ "Not" -> 1),
// test the shiftright operator (since Spark 4.0)
"((isnotnull(d_year#498) AND isnotnull(d_moy#500)) AND (d_year#498 >> 1) >= 100)" ->
- Array("isnotnull", "And", "GreaterThanOrEqual", "ShiftRight")
+ Map("isnotnull" -> 2,
+ "And" -> 2,
+ "ShiftRight" -> 1,
+ "GreaterThanOrEqual" -> 1)
)
// scalastyle:on line.size.limit
- for ((condExpr, expectedExpression) <- expressionsMap) {
- val parsedExpressionsMine = SQLPlanParser.parseConditionalExpressions(condExpr)
- val currOutput = parsedExpressionsMine.sorted
- val expectedOutput = expectedExpression.sorted
- assert(currOutput sameElements expectedOutput,
- s"The parsed expressions are not as expected. Expression: ${condExpr}, " +
- s"Expected: ${expectedOutput.mkString}, " +
- s"Output: ${currOutput.mkString}")
+ for ((condExpr, expectedExpressionCounts) <- expressionsMap) {
+ val rawExpressions = SQLPlanParser.parseConditionalExpressions(condExpr)
+ val expected = expectedExpressionCounts.map(e => ExprOpRef(OpRef.fromExpr(e._1), e._2))
+ val actualExpressions = ExprOpRef.fromRawExprSeq(rawExpressions)
+ actualExpressions should ===(expected)
}
}
@@ -1403,10 +1427,18 @@ class SQLPlanParserSuite extends BasePlanParserSuite {
"AND (content_name_16#197L = 1)) AND NOT (split(split(split(replace(replace(replace" +
"(replace(trim(replace(cast(unbase64(content#192) as string), , ), Some( )), *., ), *, ), " +
"https://, ), http://, ), /, -1)[0], :, -1)[0], \\?, -1)[0] = ))"
- val expected = Array("isnotnull", "split", "replace", "trim", "unbase64", "And",
- "EqualTo", "Not")
- val expressions = SQLPlanParser.parseFilterExpressions(exprString)
- expressions should ===(expected)
+ val expected = Map(
+ "isnotnull" -> 3,
+ "split" -> 3,
+ "replace" -> 5,
+ "trim" -> 1,
+ "unbase64" -> 1,
+ "And" -> 6,
+ "EqualTo" -> 4, // EqualTo comes from the = operator
+ "Not" -> 2).map(e => ExprOpRef(OpRef.fromExpr(e._1), e._2))
+ val rawExpressions = SQLPlanParser.parseFilterExpressions(exprString)
+ val actualExpressions = ExprOpRef.fromRawExprSeq(rawExpressions)
+ actualExpressions should ===(expected)
}
@@ -1418,9 +1450,16 @@ class SQLPlanParserSuite extends BasePlanParserSuite {
"THEN concat(replace(cast(unbase64(content#192) as string), , ), %) " +
"ELSE replace(replace(replace(cast(unbase64(content#192) as string), , ), %, " +
"\\%), *, %) END#200])"
- val expected = Array("replace", "concat", "instr", "split", "trim", "unbase64")
- val expressions = SQLPlanParser.parseAggregateExpressions(exprString)
- expressions should ===(expected)
+ val expected = Map(
+ "replace" -> 10,
+ "concat" -> 1,
+ "instr" -> 1,
+ "split" -> 3,
+ "trim" -> 1,
+ "unbase64" -> 4).map(e => ExprOpRef(OpRef.fromExpr(e._1), e._2))
+ val rawExpressions = SQLPlanParser.parseAggregateExpressions(exprString)
+ val actualExpressions = ExprOpRef.fromRawExprSeq(rawExpressions)
+ actualExpressions should ===(expected)
}
runConditionalTest("promote_precision is supported for Spark LT 3.4.0: issue-517",
@@ -1522,7 +1561,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite {
"gid#1296," +
"CAST((IF((supersql_t12.`ret_type` = 2), 1, 0)) AS BIGINT)#1300L]]"
// Only "IF" should be picked up as a function name
- val expected = Array("IF")
+ val expected = Array("IF", "IF", "IF")
val expressions = SQLPlanParser.parseExpandExpressions(exprString)
expressions should ===(expected)
}