Fix platform names as string constants and reduce redundancy in unit tests #667

Merged: 4 commits, Nov 29, 2023
24 changes: 15 additions & 9 deletions core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala
@@ -15,6 +15,8 @@
  */
 package com.nvidia.spark.rapids.tool
 
+import scala.annotation.tailrec
+
 import org.apache.spark.internal.Logging
 
 /**
@@ -33,6 +35,7 @@ object PlatformNames {
   val EMR_A10 = "emr-a10"
   val EMR_T4 = "emr-t4"
   val ONPREM = "onprem"
+  val DEFAULT: String = ONPREM
 
   /**
    * Return a list of all platform names.
@@ -93,6 +96,8 @@ class Platform(platformName: String) {
     recommendationsToExclude.forall(excluded => !comment.contains(excluded))
   }
 
+  def getName: String = platformName
+
   def getOperatorScoreFile: String = {
     s"operatorsScore-$platformName.csv"
   }
@@ -130,25 +135,26 @@ object PlatformFactory extends Logging {
    * @return An instance of the specified platform.
    * @throws IllegalArgumentException if the specified platform key is not supported.
    */
-  def createInstance(platformKey: String): Platform = {
+  @tailrec
+  def createInstance(platformKey: String = PlatformNames.DEFAULT): Platform = {
     platformKey match {
-      case PlatformNames.DATABRICKS_AWS => new DatabricksPlatform(PlatformNames.DATABRICKS_AWS)
-      case PlatformNames.DATABRICKS_AZURE => new DatabricksPlatform(PlatformNames.DATABRICKS_AZURE)
+      case PlatformNames.DATABRICKS_AWS | PlatformNames.DATABRICKS_AZURE =>
+        new DatabricksPlatform(platformKey)
       case PlatformNames.DATAPROC | PlatformNames.DATAPROC_T4 =>
         // if no GPU specified, then default to dataproc-t4 for backward compatibility
        new DataprocPlatform(PlatformNames.DATAPROC_T4)
-      case PlatformNames.DATAPROC_L4 => new DataprocPlatform(PlatformNames.DATAPROC_L4)
-      case PlatformNames.DATAPROC_SL_L4 => new DataprocPlatform(PlatformNames.DATAPROC_SL_L4)
-      case PlatformNames.DATAPROC_GKE_L4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_L4)
-      case PlatformNames.DATAPROC_GKE_T4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_T4)
+      case PlatformNames.DATAPROC_L4 | PlatformNames.DATAPROC_SL_L4 |
+           PlatformNames.DATAPROC_GKE_L4 | PlatformNames.DATAPROC_GKE_T4 =>
+        new DataprocPlatform(platformKey)
       case PlatformNames.EMR | PlatformNames.EMR_T4 =>
         // if no GPU specified, then default to emr-t4 for backward compatibility
         new EmrPlatform(PlatformNames.EMR_T4)
       case PlatformNames.EMR_A10 => new EmrPlatform(PlatformNames.EMR_A10)
       case PlatformNames.ONPREM => new OnPremPlatform
       case p if p.isEmpty =>
-        logInfo(s"Platform is not specified. Using ${PlatformNames.ONPREM} as default.")
-        new OnPremPlatform
+        logInfo(s"Platform is not specified. Using ${PlatformNames.DEFAULT} " +
+          "as default.")
+        PlatformFactory.createInstance(PlatformNames.DEFAULT)
       case _ => throw new IllegalArgumentException(s"Unsupported platform: $platformKey. " +
         s"Options include ${PlatformNames.getAllNames.mkString(", ")}.")
     }
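
Net effect of the factory change: an omitted key, an empty key, and an explicit "onprem" key all resolve to the same platform, with the empty-key case logging once and making a single tail-recursive call into the default. A minimal sketch of the expected behavior (illustrative only; it assumes OnPremPlatform reports PlatformNames.ONPREM as its name via the new getName accessor):

import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames}

// All three calls should resolve to an on-prem platform instance.
val byDefault = PlatformFactory.createInstance()                   // uses PlatformNames.DEFAULT
val byEmptyKey = PlatformFactory.createInstance("")                // logs, then recurses into the default
val byName = PlatformFactory.createInstance(PlatformNames.ONPREM)

assert(byDefault.getName == PlatformNames.ONPREM)
assert(byEmptyKey.getName == PlatformNames.ONPREM)
assert(byName.getName == PlatformNames.ONPREM)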
core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala
@@ -330,7 +330,7 @@ class RecommendationEntry(val name: String,
 class AutoTuner(
     val clusterProps: ClusterProperties,
     val appInfoProvider: AppSummaryInfoBaseProvider,
-    val platform: String) extends Logging {
+    val platform: Platform) extends Logging {
 
   import AutoTuner._
 
@@ -344,7 +344,6 @@ class AutoTuner(
   private val limitedLogicRecommendations: mutable.HashSet[String] = mutable.HashSet[String]()
   // When enabled, the profiler recommendations should only include updated settings.
   private var filterByUpdatedPropertiesEnabled: Boolean = true
-  val selectedPlatform: Platform = PlatformFactory.createInstance(platform)
 
   private def isCalculationEnabled(prop: String) : Boolean = {
     !limitedLogicRecommendations.contains(prop)
@@ -908,7 +907,7 @@ class AutoTuner(
       limitedSeq.foreach(_ => limitedLogicRecommendations.add(_))
     }
     skipList.foreach(skipSeq => skipSeq.foreach(_ => skippedRecommendations.add(_)))
-    skippedRecommendations ++= selectedPlatform.recommendationsToExclude
+    skippedRecommendations ++= platform.recommendationsToExclude
     initRecommendations()
     calculateJobLevelRecommendations()
     if (processPropsAndCheck) {
@@ -918,7 +917,7 @@ class AutoTuner(
       addDefaultComments()
     }
     // add all platform specific recommendations
-    selectedPlatform.recommendationsToInclude.foreach {
+    platform.recommendationsToInclude.foreach {
       case (property, value) => appendRecommendation(property, value)
     }
   }
@@ -1024,7 +1023,7 @@ object AutoTuner extends Logging {
   private def handleException(
       ex: Exception,
       appInfo: AppSummaryInfoBaseProvider,
-      platform: String): AutoTuner = {
+      platform: Platform): AutoTuner = {
     logError("Exception: " + ex.getStackTrace.mkString("Array(", ", ", ")"))
     val tuning = new AutoTuner(new ClusterProperties(), appInfo, platform)
     val msg = ex match {
Expand Down Expand Up @@ -1076,7 +1075,7 @@ object AutoTuner extends Logging {
def buildAutoTunerFromProps(
clusterProps: String,
singleAppProvider: AppSummaryInfoBaseProvider,
platform: String = Profiler.DEFAULT_PLATFORM): AutoTuner = {
platform: Platform = PlatformFactory.createInstance()): AutoTuner = {
try {
val clusterPropsOpt = loadClusterPropertiesFromContent(clusterProps)
new AutoTuner(clusterPropsOpt.getOrElse(new ClusterProperties()), singleAppProvider, platform)
@@ -1089,7 +1088,7 @@ object AutoTuner extends Logging {
   def buildAutoTuner(
       filePath: String,
       singleAppProvider: AppSummaryInfoBaseProvider,
-      platform: String = Profiler.DEFAULT_PLATFORM): AutoTuner = {
+      platform: Platform = PlatformFactory.createInstance()): AutoTuner = {
     try {
       val clusterPropsOpt = loadClusterProps(filePath)
       new AutoTuner(clusterPropsOpt.getOrElse(new ClusterProperties()), singleAppProvider, platform)
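
With buildAutoTuner and buildAutoTunerFromProps now taking a Platform (defaulting to PlatformFactory.createInstance()), callers construct the platform object once and pass it through instead of re-parsing a string key. A hedged sketch of the new call shape; the worker-info path is a placeholder and appSummaryProvider stands in for whatever AppSummaryInfoBaseProvider the caller already has:

val platform = PlatformFactory.createInstance(PlatformNames.DATAPROC_L4)
val autoTuner = AutoTuner.buildAutoTuner(
  "/path/to/worker_info.yaml", // hypothetical worker-info file
  appSummaryProvider,          // an AppSummaryInfoBaseProvider assumed in scope
  platform)
val (recommendations, comments) = autoTuner.getRecommendedProperties()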
core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileArgs.scala
@@ -15,7 +15,7 @@
  */
 package com.nvidia.spark.rapids.tool.profiling
 
-import com.nvidia.spark.rapids.tool.PlatformNames
+import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames}
 import org.rogach.scallop.{ScallopConf, ScallopOption}
 import org.rogach.scallop.exceptions.ScallopException
 
@@ -71,8 +71,9 @@ Usage: java -cp rapids-4-spark-tools_2.12-<version>.jar:$SPARK_HOME/jars/*
   val platform: ScallopOption[String] =
     opt[String](required = false,
       descr = "Cluster platform where Spark GPU workloads were executed. Options include " +
-        s"${PlatformNames.getAllNames.mkString(", ")}. Default is ${PlatformNames.ONPREM}.",
-      default = Some(PlatformNames.ONPREM))
+        s"${PlatformNames.getAllNames.mkString(", ")}. " +
+        s"Default is ${PlatformNames.DEFAULT}.",
+      default = Some(PlatformNames.DEFAULT))
   val generateTimeline: ScallopOption[Boolean] =
     opt[Boolean](required = false,
       descr = "Write an SVG graph out for the full application timeline.")
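
The CLI surface is unchanged: --platform still takes a string key, and the fallback now flows from the single PlatformNames.DEFAULT constant rather than a hard-coded ONPREM reference. A small sketch of the parsed option, assuming ProfileArgs accepts an argument array as below (the event-log path is a placeholder):

val withFlag = new ProfileArgs(Array("--platform", "dataproc-l4", "/path/to/eventlog"))
assert(withFlag.platform() == "dataproc-l4")

val withoutFlag = new ProfileArgs(Array("/path/to/eventlog"))
assert(withoutFlag.platform() == PlatformNames.DEFAULT) // "onprem"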
core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.util.control.NonFatal
 
 import com.nvidia.spark.rapids.ThreadFactoryBuilder
-import com.nvidia.spark.rapids.tool.{EventLogInfo, EventLogPathProcessor, PlatformNames}
+import com.nvidia.spark.rapids.tool.{EventLogInfo, EventLogPathProcessor, PlatformFactory}
 import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.internal.Logging
@@ -511,9 +511,10 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolean
 
     if (useAutoTuner) {
       val workerInfoPath = appArgs.workerInfo.getOrElse(AutoTuner.DEFAULT_WORKER_INFO_PATH)
-      val platform = appArgs.platform.getOrElse(Profiler.DEFAULT_PLATFORM)
+      val platform = appArgs.platform()
       val autoTuner: AutoTuner = AutoTuner.buildAutoTuner(workerInfoPath,
-        new SingleAppSummaryInfoProvider(app), platform)
+        new SingleAppSummaryInfoProvider(app),
+        PlatformFactory.createInstance(platform))
       // the autotuner allows skipping some properties
       // e.g. getRecommendedProperties(Some(Seq("spark.executor.instances"))) skips the
       // recommendation related to executor instances.
Expand Down Expand Up @@ -548,7 +549,6 @@ object Profiler {
val COMPARE_LOG_FILE_NAME_PREFIX = "rapids_4_spark_tools_compare"
val COMBINED_LOG_FILE_NAME_PREFIX = "rapids_4_spark_tools_combined"
val SUBDIR = "rapids_4_spark_profile"
val DEFAULT_PLATFORM: String = PlatformNames.ONPREM

def getAutoTunerResultsAsString(props: Seq[RecommendedPropertyResult],
comments: Seq[RecommendedCommentResult]): String = {
Expand Down
core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala
@@ -20,7 +20,7 @@ import scala.collection.mutable.{ArrayBuffer,HashMap}
 import scala.io.{BufferedSource, Source}
 import scala.util.control.NonFatal
 
-import com.nvidia.spark.rapids.tool.PlatformFactory
+import com.nvidia.spark.rapids.tool.{Platform, PlatformFactory}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path}
 
@@ -33,7 +33,7 @@ import org.apache.spark.internal.Logging
  * by the plugin which lists the formats and types supported.
  * The class also supports a custom speedup factor file as input.
  */
-class PluginTypeChecker(platform: String = "onprem",
+class PluginTypeChecker(platform: Platform = PlatformFactory.createInstance(),
     speedupFactorFile: Option[String] = None) extends Logging {
 
   private val NS = "NS"
@@ -92,7 +92,7 @@ class PluginTypeChecker(platform: String = "onprem",
     speedupFactorFile match {
       case None =>
         logInfo(s"Reading operators scores with platform: $platform")
-        val file = PlatformFactory.createInstance(platform).getOperatorScoreFile
+        val file = platform.getOperatorScoreFile
         val source = Source.fromResource(file)
         readSupportedOperators(source, "score").map(x => (x._1, x._2.toDouble))
       case Some(file) =>
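
With the checker holding a Platform object, the operator-score lookup becomes a plain method call instead of a second trip through the factory. A short usage sketch; the expected speedup value is taken from the emr-t4 entry in the updated test suite below:

val platform = PlatformFactory.createInstance(PlatformNames.EMR_T4)
val checker = new PluginTypeChecker(platform)
// Scores load from the platform-specific resource, here operatorsScore-emr-t4.csv
assert(checker.getSpeedupFactor("UnionExec") == 2.07)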
core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala
@@ -15,7 +15,7 @@
  */
 package com.nvidia.spark.rapids.tool.qualification
 
-import com.nvidia.spark.rapids.tool.PlatformNames
+import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames}
 import org.rogach.scallop.{ScallopConf, ScallopOption}
 import org.rogach.scallop.exceptions.ScallopException
 
@@ -156,8 +156,9 @@ Usage: java -cp rapids-4-spark-tools_2.12-<version>.jar:$SPARK_HOME/jars/*
   val platform: ScallopOption[String] =
     opt[String](required = false,
      descr = "Cluster platform where Spark CPU workloads were executed. Options include " +
-        s"${PlatformNames.getAllNames.mkString(", ")}. Default is ${PlatformNames.ONPREM}.",
-      default = Some(PlatformNames.ONPREM))
+        s"${PlatformNames.getAllNames.mkString(", ")}. " +
+        s"Default is ${PlatformNames.DEFAULT}.",
+      default = Some(PlatformNames.DEFAULT))
   val speedupFactorFile: ScallopOption[String] =
     opt[String](required = false,
       descr = "Custom speedup factor file used to get estimated GPU speedup that is specific " +
core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationMain.scala
@@ -16,7 +16,7 @@
 
 package com.nvidia.spark.rapids.tool.qualification
 
-import com.nvidia.spark.rapids.tool.EventLogPathProcessor
+import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformFactory}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.rapids.tool.AppFilterImpl
@@ -58,14 +58,16 @@ object QualificationMain extends Logging {
     val order = appArgs.order.getOrElse("desc")
     val uiEnabled = appArgs.htmlReport.getOrElse(false)
     val reportSqlLevel = appArgs.perSql.getOrElse(false)
-    val platform = appArgs.platform.getOrElse("onprem")
+    val platform = appArgs.platform()
     val mlOpsEnabled = appArgs.mlFunctions.getOrElse(false)
     val penalizeTransitions = appArgs.penalizeTransitions.getOrElse(true)
 
     val hadoopConf = RapidsToolsConfUtil.newHadoopConf
 
     val pluginTypeChecker = try {
-      new PluginTypeChecker(platform, appArgs.speedupFactorFile.toOption)
+      new PluginTypeChecker(
+        PlatformFactory.createInstance(platform),
+        appArgs.speedupFactorFile.toOption)
     } catch {
       case ie: IllegalStateException =>
         logError("Error creating the plugin type checker!", ie)
core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/AutoTunerSuite.scala
@@ -21,6 +21,7 @@ import java.util
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
+import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames}
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
 import org.scalatest.Matchers.convertToAnyShouldWrapper
 import org.yaml.snakeyaml.{DumperOptions, Yaml}
@@ -1285,14 +1286,15 @@ class AutoTunerSuite extends FunSuite with BeforeAndAfterEach with Logging {
 
   test("test recommendations for databricks-aws platform argument") {
     val databricksWorkerInfo = buildWorkerInfoAsString()
+    val platform = PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS)
     val autoTuner = AutoTuner.buildAutoTunerFromProps(databricksWorkerInfo,
-      getGpuAppMockInfoProvider, "databricks-aws")
+      getGpuAppMockInfoProvider, platform)
     val (properties, comments) = autoTuner.getRecommendedProperties()
 
     // Assert recommendations are excluded in properties
-    assert(properties.map(_.property).forall(autoTuner.selectedPlatform.isValidRecommendation))
+    assert(properties.map(_.property).forall(autoTuner.platform.isValidRecommendation))
     // Assert recommendations are skipped in comments
-    assert(comments.map(_.comment).forall(autoTuner.selectedPlatform.isValidComment))
+    assert(comments.map(_.comment).forall(autoTuner.platform.isValidComment))
   }
 
   // When spark is running as a standalone, the memoryOverhead should not be listed as a
core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeCheckerSuite.scala
@@ -19,7 +19,7 @@ package com.nvidia.spark.rapids.tool.qualification
 import java.nio.charset.StandardCharsets
 import java.nio.file.{Files, Paths}
 
-import com.nvidia.spark.rapids.tool.ToolTestUtils
+import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames, ToolTestUtils}
 import com.nvidia.spark.rapids.tool.planparser.DataWritingCommandExecParser
 import org.scalatest.FunSuite
 
@@ -153,68 +153,33 @@ class PluginTypeCheckerSuite extends FunSuite with Logging {
     assert(result(2) == "ORC")
   }
 
-  test("supported operator score from onprem") {
-    val checker = new PluginTypeChecker("onprem")
-    assert(checker.getSpeedupFactor("UnionExec") == 3.0)
-    assert(checker.getSpeedupFactor("Ceil") == 4)
-  }
-
-  test("supported operator score from dataproc-t4") {
-    val checker = new PluginTypeChecker("dataproc-t4")
-    assert(checker.getSpeedupFactor("UnionExec") == 4.88)
-    assert(checker.getSpeedupFactor("Ceil") == 4.88)
-  }
-
-  test("supported operator score from emr-t4") {
-    val checker = new PluginTypeChecker("emr-t4")
-    assert(checker.getSpeedupFactor("UnionExec") == 2.07)
-    assert(checker.getSpeedupFactor("Ceil") == 2.07)
-  }
-
-  test("supported operator score from databricks-aws") {
-    val checker = new PluginTypeChecker("databricks-aws")
-    assert(checker.getSpeedupFactor("UnionExec") == 2.45)
-    assert(checker.getSpeedupFactor("Ceil") == 2.45)
-  }
-
-  test("supported operator score from databricks-azure") {
-    val checker = new PluginTypeChecker("databricks-azure")
-    assert(checker.getSpeedupFactor("UnionExec") == 2.73)
-    assert(checker.getSpeedupFactor("Ceil") == 2.73)
-  }
-
-  test("supported operator score from dataproc-serverless-l4") {
-    val checker = new PluginTypeChecker("dataproc-serverless-l4")
-    assert(checker.getSpeedupFactor("WindowExec") == 4.25)
-    assert(checker.getSpeedupFactor("Ceil") == 4.25)
-  }
-
-  test("supported operator score from dataproc-l4") {
-    val checker = new PluginTypeChecker("dataproc-l4")
-    assert(checker.getSpeedupFactor("UnionExec") == 4.16)
-    assert(checker.getSpeedupFactor("Ceil") == 4.16)
-  }
-
-  test("supported operator score from dataproc-gke-t4") {
-    val checker = new PluginTypeChecker("dataproc-gke-t4")
-    assert(checker.getSpeedupFactor("WindowExec") == 3.65)
-    assert(checker.getSpeedupFactor("Ceil") == 3.65)
-  }
-
-  test("supported operator score from dataproc-gke-l4") {
-    val checker = new PluginTypeChecker("dataproc-gke-l4")
-    assert(checker.getSpeedupFactor("WindowExec") == 3.74)
-    assert(checker.getSpeedupFactor("Ceil") == 3.74)
-  }
-
-  test("supported operator score from emr-a10") {
-    val checker = new PluginTypeChecker("emr-a10")
-    assert(checker.getSpeedupFactor("UnionExec") == 2.59)
-    assert(checker.getSpeedupFactor("Ceil") == 2.59)
+  val platformSpeedupEntries: Seq[(String, Map[String, Double])] = Seq(
+    (PlatformNames.ONPREM, Map("UnionExec" -> 3.0, "Ceil" -> 4.0)),
+    (PlatformNames.DATAPROC_T4, Map("UnionExec" -> 4.88, "Ceil" -> 4.88)),
+    (PlatformNames.EMR_T4, Map("UnionExec" -> 2.07, "Ceil" -> 2.07)),
+    (PlatformNames.DATABRICKS_AWS, Map("UnionExec" -> 2.45, "Ceil" -> 2.45)),
+    (PlatformNames.DATABRICKS_AZURE, Map("UnionExec" -> 2.73, "Ceil" -> 2.73)),
+    (PlatformNames.DATAPROC_SL_L4, Map("WindowExec" -> 4.25, "Ceil" -> 4.25)),
+    (PlatformNames.DATAPROC_L4, Map("UnionExec" -> 4.16, "Ceil" -> 4.16)),
+    (PlatformNames.DATAPROC_GKE_T4, Map("WindowExec" -> 3.65, "Ceil" -> 3.65)),
+    (PlatformNames.DATAPROC_GKE_L4, Map("WindowExec" -> 3.74, "Ceil" -> 3.74)),
+    (PlatformNames.EMR_A10, Map("UnionExec" -> 2.59, "Ceil" -> 2.59))
+  )
+
+  platformSpeedupEntries.foreach { case (platformName, speedupMap) =>
+    test(s"supported operator score from $platformName") {
+      val platform = PlatformFactory.createInstance(platformName)
+      val checker = new PluginTypeChecker(platform)
+      speedupMap.foreach { case (operator, speedup) =>
+        assert(checker.getSpeedupFactor(operator) == speedup)
+      }
+    }
   }
 
   test("supported operator score from custom speedup factor file") {
-    val speedupFactorFile = ToolTestUtils.getTestResourcePath("operatorsScore-databricks-azure.csv")
+    // Using databricks azure speedup factor as custom file
+    val platform = PlatformFactory.createInstance(PlatformNames.DATABRICKS_AZURE)
+    val speedupFactorFile = ToolTestUtils.getTestResourcePath(platform.getOperatorScoreFile)
     val checker = new PluginTypeChecker(speedupFactorFile=Some(speedupFactorFile))
     assert(checker.getSpeedupFactor("SortExec") == 13.11)
     assert(checker.getSpeedupFactor("FilterExec") == 3.14)
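
The table-driven rewrite above registers one ScalaTest case per platform (the test(...) call sits inside the foreach, so each platform still passes or fails independently) while collapsing ten near-identical bodies into a single loop over PlatformNames constants; supporting a new platform now means adding one entry to platformSpeedupEntries instead of copying a whole test block.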