Skip to content

Commit

Permalink
Count expressions per Exec in SQLPlanParser
Browse files Browse the repository at this point in the history
Signed-off-by: Ahmed Hussein (amahussein) <[email protected]>

Fixes NVIDIA#1447

Improves the operators-stats by counting the occurence of each
expression within the Exec node.
  • Loading branch information
amahussein committed Dec 5, 2024
1 parent a414e09 commit c15aac9
Show file tree
Hide file tree
Showing 10 changed files with 242 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, WeakHashMap}
import scala.util.control.NonFatal
import scala.util.matching.Regex

import com.nvidia.spark.rapids.tool.planparser.ops.{OperatorRefBase, OpRef, UnsupportedExprOpRef}
import com.nvidia.spark.rapids.tool.planparser.ops.{ExprOpRef, OperatorRefTrait, OpRef, UnsupportedExprOpRef}
import com.nvidia.spark.rapids.tool.planparser.photon.{PhotonPlanParser, PhotonStageExecParser}
import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker

Expand Down Expand Up @@ -75,7 +75,7 @@ object UnsupportedReasons extends Enumeration {
case class UnsupportedExecSummary(
sqlId: Long,
execId: Long,
execRef: OperatorRefBase,
execRef: OperatorRefTrait,
opType: OpTypes.OpType,
reason: UnsupportedReasons.UnsupportedReason,
opAction: OpActions.OpAction) {
Expand All @@ -89,8 +89,6 @@ case class UnsupportedExecSummary(
val unsupportedOperatorCSVFormat: String = execRef.getOpNameCSV

val details: String = UnsupportedReasons.reportUnsupportedReason(reason)

def isExpression: Boolean = execRef.isInstanceOf[UnsupportedExprOpRef]
}

case class ExecInfo(
Expand All @@ -110,7 +108,7 @@ case class ExecInfo(
dataSet: Boolean,
udf: Boolean,
shouldIgnore: Boolean,
expressions: Seq[OpRef]) {
expressions: Seq[ExprOpRef]) {

private def childrenToString = {
val str = children.map { c =>
Expand Down Expand Up @@ -141,10 +139,6 @@ case class ExecInfo(
stages = stageIDs
}

def appendToStages(stageIDs: Set[Int]): Unit = {
stages ++= stageIDs
}

def setShouldRemove(value: Boolean): Unit = {
shouldRemove ||= value
}
Expand Down Expand Up @@ -286,7 +280,7 @@ object ExecInfo {
udf,
shouldIgnore,
// convert array of string expressions to OpRefs
expressions = expressions.map(OpRef.fromExpr)
expressions = ExprOpRef.fromRawExprSeq(expressions)
)
}

Expand Down Expand Up @@ -351,7 +345,7 @@ case class PlanInfo(
sqlID: Long,
sqlDesc: String,
execInfo: Seq[ExecInfo]) {
def getUnsupportedExpressions: Seq[OperatorRefBase] = {
def getUnsupportedExpressions: Seq[OperatorRefTrait] = {
execInfo.flatMap { e =>
if (e.isClusterNode) {
// wholeStageCodeGen does not have expressions/unsupported-expressions
Expand Down Expand Up @@ -727,21 +721,21 @@ object SQLPlanParser extends Logging {
}

private def getAllFunctionNames(regPattern: Regex, expr: String,
groupInd: Int = 1, isAggr: Boolean = true): Set[String] = {
groupInd: Int = 1, isAggr: Boolean = true): Array[String] = {
// Returns all matches in an expression. This can be used when the SQL expression is not
// tokenized.
val newExpr = processSpecialFunctions(expr)

// first get all the functionNames
val exprss =
regPattern.findAllMatchIn(newExpr).map(_.group(groupInd)).toSet
regPattern.findAllMatchIn(newExpr).map(_.group(groupInd)).toSeq

// For aggregate expressions we want to process the results to remove the prefix
// DB: remove the "^partial_" and "^finalmerge_" prefixes
// TODO:
// for performance sake, we can turn off the aggregate processing by enabling it only
// when needed. However, for now, we always do this processing until we are confident we know
// the correct place to turn on/off that flag.we can use the argument isAgg only when needed
// the correct place to turn on/off that flag. We can use the argument isAgg only when needed
val results = if (isAggr) {
exprss.collect {
case func =>
Expand All @@ -750,15 +744,15 @@ object SQLPlanParser extends Logging {
} else {
exprss
}
results.filterNot(ignoreExpression(_))
results.filterNot(ignoreExpression(_)).toArray
}

def parseProjectExpressions(exprStr: String): Array[String] = {
// Project [cast(value#136 as string) AS value#144, CEIL(value#136) AS CEIL(value)#143L]
// This is to split the string such that only function names are extracted. The pattern is
// such that function name is succeeded by `(`. We use regex to extract all the function names
// below:
getAllFunctionNames(functionPrefixPattern, exprStr).toArray
getAllFunctionNames(functionPrefixPattern, exprStr)
}

// This parser is used for SortAggregateExec, HashAggregateExec and ObjectHashAggregateExec
Expand Down Expand Up @@ -792,7 +786,7 @@ object SQLPlanParser extends Logging {
}
}
}
parsedExpressions.distinct.toArray
parsedExpressions.toArray
}

def parseWindowExpressions(exprStr:String): Array[String] = {
Expand Down Expand Up @@ -828,7 +822,7 @@ object SQLPlanParser extends Logging {
}
}
}
parsedExpressions.distinct.toArray
parsedExpressions.toArray
}

def parseWindowGroupLimitExpressions(exprStr: String): Array[String] = {
Expand Down Expand Up @@ -858,10 +852,10 @@ object SQLPlanParser extends Logging {
// - Some values can be NULLs. That's why we cannot limit the extract to the first row.
// - Nested brackets/parenthesis makes it challenging to use regex that contains
// brackets/parenthesis to extract expressions.
// The implementation Use regex to extract all function names and return distinct set of
// The implementation Use regex to extract all function names and return a list of
// function names.
// This implementation is 1 line implementation, but it can be a memory/time bottleneck.
getAllFunctionNames(functionPrefixPattern, exprStr).toArray
getAllFunctionNames(functionPrefixPattern, exprStr)
}

def parseTakeOrderedExpressions(exprStr: String): Array[String] = {
Expand Down Expand Up @@ -889,19 +883,19 @@ object SQLPlanParser extends Logging {
}
}
}
parsedExpressions.distinct.toArray
parsedExpressions.toArray
}

def parseGenerateExpressions(exprStr: String): Array[String] = {
// Get the function names from the GenerateExec. The GenerateExec has the following format:
// 1. Generate explode(arrays#1306), [id#1304], true, [col#1426]
// 2. Generate json_tuple(values#1305, Zipcode, ZipCodeType, City), [id#1304],
// false, [c0#1407, c1#1408, c2#1409]
getAllFunctionNames(functionPrefixPattern, exprStr).toArray
getAllFunctionNames(functionPrefixPattern, exprStr)
}

private def addFunctionNames(exprs: String, parsedExpressions: ArrayBuffer[String]): Unit = {
val functionNames = getAllFunctionNames(functionPrefixPattern, exprs).toArray
val functionNames = getAllFunctionNames(functionPrefixPattern, exprs)
functionNames.foreach(parsedExpressions += _)
}

Expand Down Expand Up @@ -946,7 +940,7 @@ object SQLPlanParser extends Logging {
val isSortMergeSupported = !(isSortMergeJoin &&
joinCondition.nonEmpty && isSMJConditionUnsupported(joinCondition))

(parsedExpressions.distinct.toArray, equiJoinSupportedTypes(buildSide, joinType)
(parsedExpressions.toArray, equiJoinSupportedTypes(buildSide, joinType)
&& isSortMergeSupported)
}

Expand Down Expand Up @@ -1054,7 +1048,7 @@ object SQLPlanParser extends Logging {
}
}
}
parsedExpressions.distinct.toArray
parsedExpressions.toArray
}

def parseFilterExpressions(exprStr: String): Array[String] = {
Expand Down Expand Up @@ -1109,7 +1103,7 @@ object SQLPlanParser extends Logging {
processedExpr = nonBinaryOperatorsRegEx.replaceAllIn(processedExpr, " ")

// Step-4: remove remaining parentheses '(', ')' and commas if we had functionCalls
if (!functionMatches.isEmpty) {
if (functionMatches.nonEmpty) {
// remove ","
processedExpr = processedExpr.replaceAll(",", " ")
}
Expand Down Expand Up @@ -1147,7 +1141,6 @@ object SQLPlanParser extends Logging {
logDebug(s"Unrecognized Token - $token")
}
}

parsedExpressions.distinct.toArray
parsedExpressions.toArray
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ abstract class WholeStageExecParserBase(
// The node should be marked as shouldRemove when all the children of the
// wholeStageCodeGen are marked as shouldRemove.
val removeNode = isDupNode || childNodes.forall(_.shouldRemove)
// Remove any suffix in order to get the node label without any trailing number.
// Remove any suffix to get the node label without any trailing number.
val nodeLabel = nodeNameRegeX.findFirstMatchIn(node.name) match {
case Some(m) => m.group(1)
// in case not found, use the full exec name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ case class WindowGroupLimitParser(

override def getUnsupportedExprReasonsForExec(
expressions: Array[String]): Seq[UnsupportedExprOpRef] = {
expressions.flatMap { expr =>
expressions.distinct.flatMap { expr =>
if (!supportedRankingExprs.contains(expr)) {
Some(UnsupportedExprOpRef(expr,
s"Ranking function $expr is not supported in $fullExecName"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.tool.planparser.ops

/**
* Represents a reference to an expression operator that is stored in the ExecInfo expressions
* @param opRef the opRef to wrap
* @param count the count of that expression within the exec.
*/
case class ExprOpRef(opRef: OpRef, count: Int = 1) extends OpRefWrapperBase(opRef)

object ExprOpRef extends OpRefWrapperBaseTrait[ExprOpRef] {
def fromRawExprSeq(exprArr: Seq[String]): Seq[ExprOpRef] = {
exprArr.groupBy(identity)
.mapValues(expr => ExprOpRef(OpRef.fromExpr(expr.head), expr.size)).values.toSeq
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.tool.planparser.ops

/**
* An instance that wraps OpRef and exposes the OpRef methods.
* This is used to provide common interface for all classes that wrap OpRef along with other
* metadata.
* @param opRef the opRef to wrap
*/
class OpRefWrapperBase(opRef: OpRef) extends OperatorRefTrait {
override def getOpName: String = opRef.getOpName

override def getOpNameCSV: String = opRef.getOpNameCSV

override def getOpType: String = opRef.getOpType

override def getOpTypeCSV: String = opRef.getOpTypeCSV

def getOpRef: OpRef = opRef
}

/**
* A trait that provides a factory method to create instances of OpRefWrapperBase from a sequence of
* @tparam R the type of the OpRefWrapperBase
*/
trait OpRefWrapperBaseTrait[R <: OpRefWrapperBase] {
/**
* Create instances of OpRefWrapperBase from a sequence of expressions.
* The expressions are grouped by their value and the count of each expression is stored in the
* OpRefWrapperBase entry.
* @param exprArr the sequence of expressions
* @return a sequence of OpRefWrapperBase instances
*/
def fromRawExprSeq(exprArr: Seq[String]): Seq[R]
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,66 +38,71 @@ case class OperatorCounter(planInfo: PlanInfo) {
* @param stages The set of stages where the operator appears.
*/
case class OperatorData(
opRef: OperatorRefBase,
opRef: OpRef,
var count: Int = 0,
var stages: Set[Int] = Set())
var stages: Set[Int] = Set()) extends OpRefWrapperBase(opRef)

// Summarizes the count information for an exec or expression, including whether it is supported.
case class OperatorCountSummary(
opData: OperatorData,
isSupported: Boolean)

private val supportedMap: mutable.Map[OperatorRefBase, OperatorData] = mutable.Map()
private val unsupportedMap: mutable.Map[OperatorRefBase, OperatorData] = mutable.Map()
private val supportedMap: mutable.Map[OpRef, OperatorData] = mutable.Map()
private val unsupportedMap: mutable.Map[OpRef, OperatorData] = mutable.Map()

// Returns a sequence of `OperatorCountSummary`, combining both supported and
// unsupported operators.
def getOpsCountSummary(): Seq[OperatorCountSummary] = {
def getOpsCountSummary: Seq[OperatorCountSummary] = {
supportedMap.values.map(OperatorCountSummary(_, isSupported = true)).toSeq ++
unsupportedMap.values.map(OperatorCountSummary(_, isSupported = false)).toSeq
}


// Updates the operator data in the given map (supported or unsupported).
// Increments the count and updates the stages where the operator appears.
private def updateOpRefEntry(opRef: OperatorRefBase, stages: Set[Int],
targetMap: mutable.Map[OperatorRefBase, OperatorData]): Unit = {
private def updateOpRefEntry(opRef: OpRef, stages: Set[Int],
targetMap: mutable.Map[OpRef, OperatorData], incrValue: Int = 1): Unit = {
val operatorData = targetMap.getOrElseUpdate(opRef, OperatorData(opRef))
operatorData.count += 1
operatorData.count += incrValue
operatorData.stages ++= stages
}

// Processes an `ExecInfo` node to update exec and expression counts.
// Separates supported and unsupported execs and expressions into their respective maps.
private def processExecInfo(execInfo: ExecInfo): Unit = {
val opMap = execInfo.isSupported match {
case true => supportedMap
case false => unsupportedMap
val opMap = if (execInfo.isSupported) {
supportedMap
} else {
unsupportedMap
}
updateOpRefEntry(execInfo.execRef, execInfo.stages, opMap)
// update the map for supported expressions. We should exclude the unsupported expressions.
execInfo.expressions.filterNot(
e => execInfo.unsupportedExprs.exists(exp => exp.opRef.equals(e))).foreach { expr =>
updateOpRefEntry(expr, execInfo.stages, supportedMap)
}
// update the map for unsupported expressions
execInfo.unsupportedExprs.foreach { expr =>
updateOpRefEntry(expr, execInfo.stages, unsupportedMap)
// Update the map for supported expressions. For unsupported expressions,
// we use the count stored in the supported expressions.
execInfo.expressions.foreach { expr =>
val exprMap =
if (execInfo.unsupportedExprs.exists(unsupExec =>
unsupExec.getOpRef.equals(expr.getOpRef))) {
// The expression skips because it exists in the unsupported expressions.
unsupportedMap
} else {
supportedMap
}
updateOpRefEntry(expr.getOpRef, execInfo.stages, exprMap, expr.count)
}
}

// Counts the execs and expressions in the execution plan.
// Counts the execs and expressions in the execution plan excluding clusterNodes
// (i.e., WholeStageCodeGen).
private def countOperators(): Unit = {
planInfo.execInfo.foreach { exec =>
exec.isClusterNode match {
// we do not want to count the cluster nodes in that aggregation
case true =>
if (exec.children.nonEmpty) {
exec.children.get.foreach { child =>
processExecInfo(child)
}
if (exec.isClusterNode) {
if (exec.children.nonEmpty) {
exec.children.get.foreach { child =>
processExecInfo(child)
}
case false => processExecInfo(exec)
}
} else {
processExecInfo(exec)
}
}
}
Expand Down
Loading

0 comments on commit c15aac9

Please sign in to comment.