Skip to content

Commit

Permalink
Fix platform names as string constants and reduce redundancy in unit …
Browse files Browse the repository at this point in the history
…tests (#667)

* Improve string usage of default platform and docs

Signed-off-by: Partho Sarthi <[email protected]>

* Replace redundant tests with for loops

Signed-off-by: Partho Sarthi <[email protected]>

* Fix typo in tests

Signed-off-by: Partho Sarthi <[email protected]>

* Remove creation of new instance for default case

Signed-off-by: Partho Sarthi <[email protected]>

---------

Signed-off-by: Partho Sarthi <[email protected]>
  • Loading branch information
parthosa authored Nov 29, 2023
1 parent 88d420c commit 6d2fb4e
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 379 deletions.
24 changes: 15 additions & 9 deletions core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package com.nvidia.spark.rapids.tool

import scala.annotation.tailrec

import org.apache.spark.internal.Logging

/**
Expand All @@ -33,6 +35,7 @@ object PlatformNames {
val EMR_A10 = "emr-a10"
val EMR_T4 = "emr-t4"
val ONPREM = "onprem"
val DEFAULT: String = ONPREM

/**
* Return a list of all platform names.
Expand Down Expand Up @@ -93,6 +96,8 @@ class Platform(platformName: String) {
recommendationsToExclude.forall(excluded => !comment.contains(excluded))
}

def getName: String = platformName

def getOperatorScoreFile: String = {
s"operatorsScore-$platformName.csv"
}
Expand Down Expand Up @@ -130,25 +135,26 @@ object PlatformFactory extends Logging {
* @return An instance of the specified platform.
* @throws IllegalArgumentException if the specified platform key is not supported.
*/
def createInstance(platformKey: String): Platform = {
@tailrec
def createInstance(platformKey: String = PlatformNames.DEFAULT): Platform = {
platformKey match {
case PlatformNames.DATABRICKS_AWS => new DatabricksPlatform(PlatformNames.DATABRICKS_AWS)
case PlatformNames.DATABRICKS_AZURE => new DatabricksPlatform(PlatformNames.DATABRICKS_AZURE)
case PlatformNames.DATABRICKS_AWS | PlatformNames.DATABRICKS_AZURE =>
new DatabricksPlatform(platformKey)
case PlatformNames.DATAPROC | PlatformNames.DATAPROC_T4 =>
// if no GPU specified, then default to dataproc-t4 for backward compatibility
new DataprocPlatform(PlatformNames.DATAPROC_T4)
case PlatformNames.DATAPROC_L4 => new DataprocPlatform(PlatformNames.DATAPROC_L4)
case PlatformNames.DATAPROC_SL_L4 => new DataprocPlatform(PlatformNames.DATAPROC_SL_L4)
case PlatformNames.DATAPROC_GKE_L4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_L4)
case PlatformNames.DATAPROC_GKE_T4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_T4)
case PlatformNames.DATAPROC_L4 | PlatformNames.DATAPROC_SL_L4 |
PlatformNames.DATAPROC_GKE_L4 | PlatformNames.DATAPROC_GKE_T4 =>
new DataprocPlatform(platformKey)
case PlatformNames.EMR | PlatformNames.EMR_T4 =>
// if no GPU specified, then default to emr-t4 for backward compatibility
new EmrPlatform(PlatformNames.EMR_T4)
case PlatformNames.EMR_A10 => new EmrPlatform(PlatformNames.EMR_A10)
case PlatformNames.ONPREM => new OnPremPlatform
case p if p.isEmpty =>
logInfo(s"Platform is not specified. Using ${PlatformNames.ONPREM} as default.")
new OnPremPlatform
logInfo(s"Platform is not specified. Using ${PlatformNames.DEFAULT} " +
"as default.")
PlatformFactory.createInstance(PlatformNames.DEFAULT)
case _ => throw new IllegalArgumentException(s"Unsupported platform: $platformKey. " +
s"Options include ${PlatformNames.getAllNames.mkString(", ")}.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class RecommendationEntry(val name: String,
class AutoTuner(
val clusterProps: ClusterProperties,
val appInfoProvider: AppSummaryInfoBaseProvider,
val platform: String) extends Logging {
val platform: Platform) extends Logging {

import AutoTuner._

Expand All @@ -344,7 +344,6 @@ class AutoTuner(
private val limitedLogicRecommendations: mutable.HashSet[String] = mutable.HashSet[String]()
// When enabled, the profiler recommendations should only include updated settings.
private var filterByUpdatedPropertiesEnabled: Boolean = true
val selectedPlatform: Platform = PlatformFactory.createInstance(platform)

private def isCalculationEnabled(prop: String) : Boolean = {
!limitedLogicRecommendations.contains(prop)
Expand Down Expand Up @@ -908,7 +907,7 @@ class AutoTuner(
limitedSeq.foreach(_ => limitedLogicRecommendations.add(_))
}
skipList.foreach(skipSeq => skipSeq.foreach(_ => skippedRecommendations.add(_)))
skippedRecommendations ++= selectedPlatform.recommendationsToExclude
skippedRecommendations ++= platform.recommendationsToExclude
initRecommendations()
calculateJobLevelRecommendations()
if (processPropsAndCheck) {
Expand All @@ -918,7 +917,7 @@ class AutoTuner(
addDefaultComments()
}
// add all platform specific recommendations
selectedPlatform.recommendationsToInclude.foreach {
platform.recommendationsToInclude.foreach {
case (property, value) => appendRecommendation(property, value)
}
}
Expand Down Expand Up @@ -1024,7 +1023,7 @@ object AutoTuner extends Logging {
private def handleException(
ex: Exception,
appInfo: AppSummaryInfoBaseProvider,
platform: String): AutoTuner = {
platform: Platform): AutoTuner = {
logError("Exception: " + ex.getStackTrace.mkString("Array(", ", ", ")"))
val tuning = new AutoTuner(new ClusterProperties(), appInfo, platform)
val msg = ex match {
Expand Down Expand Up @@ -1076,7 +1075,7 @@ object AutoTuner extends Logging {
def buildAutoTunerFromProps(
clusterProps: String,
singleAppProvider: AppSummaryInfoBaseProvider,
platform: String = Profiler.DEFAULT_PLATFORM): AutoTuner = {
platform: Platform = PlatformFactory.createInstance()): AutoTuner = {
try {
val clusterPropsOpt = loadClusterPropertiesFromContent(clusterProps)
new AutoTuner(clusterPropsOpt.getOrElse(new ClusterProperties()), singleAppProvider, platform)
Expand All @@ -1089,7 +1088,7 @@ object AutoTuner extends Logging {
def buildAutoTuner(
filePath: String,
singleAppProvider: AppSummaryInfoBaseProvider,
platform: String = Profiler.DEFAULT_PLATFORM): AutoTuner = {
platform: Platform = PlatformFactory.createInstance()): AutoTuner = {
try {
val clusterPropsOpt = loadClusterProps(filePath)
new AutoTuner(clusterPropsOpt.getOrElse(new ClusterProperties()), singleAppProvider, platform)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package com.nvidia.spark.rapids.tool.profiling

import com.nvidia.spark.rapids.tool.PlatformNames
import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames}
import org.rogach.scallop.{ScallopConf, ScallopOption}
import org.rogach.scallop.exceptions.ScallopException

Expand Down Expand Up @@ -71,8 +71,9 @@ Usage: java -cp rapids-4-spark-tools_2.12-<version>.jar:$SPARK_HOME/jars/*
val platform: ScallopOption[String] =
opt[String](required = false,
descr = "Cluster platform where Spark GPU workloads were executed. Options include " +
s"${PlatformNames.getAllNames.mkString(", ")}. Default is ${PlatformNames.ONPREM}.",
default = Some(PlatformNames.ONPREM))
s"${PlatformNames.getAllNames.mkString(", ")}. " +
s"Default is ${PlatformNames.DEFAULT}.",
default = Some(PlatformNames.DEFAULT))
val generateTimeline: ScallopOption[Boolean] =
opt[Boolean](required = false,
descr = "Write an SVG graph out for the full application timeline.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal

import com.nvidia.spark.rapids.ThreadFactoryBuilder
import com.nvidia.spark.rapids.tool.{EventLogInfo, EventLogPathProcessor, PlatformNames}
import com.nvidia.spark.rapids.tool.{EventLogInfo, EventLogPathProcessor, PlatformFactory}
import org.apache.hadoop.conf.Configuration

import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -511,9 +511,10 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea

if (useAutoTuner) {
val workerInfoPath = appArgs.workerInfo.getOrElse(AutoTuner.DEFAULT_WORKER_INFO_PATH)
val platform = appArgs.platform.getOrElse(Profiler.DEFAULT_PLATFORM)
val platform = appArgs.platform()
val autoTuner: AutoTuner = AutoTuner.buildAutoTuner(workerInfoPath,
new SingleAppSummaryInfoProvider(app), platform)
new SingleAppSummaryInfoProvider(app),
PlatformFactory.createInstance(platform))
// the autotuner allows skipping some properties
// e.g. getRecommendedProperties(Some(Seq("spark.executor.instances"))) skips the
// recommendation related to executor instances.
Expand Down Expand Up @@ -548,7 +549,6 @@ object Profiler {
val COMPARE_LOG_FILE_NAME_PREFIX = "rapids_4_spark_tools_compare"
val COMBINED_LOG_FILE_NAME_PREFIX = "rapids_4_spark_tools_combined"
val SUBDIR = "rapids_4_spark_profile"
val DEFAULT_PLATFORM: String = PlatformNames.ONPREM

def getAutoTunerResultsAsString(props: Seq[RecommendedPropertyResult],
comments: Seq[RecommendedCommentResult]): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import scala.collection.mutable.{ArrayBuffer,HashMap}
import scala.io.{BufferedSource, Source}
import scala.util.control.NonFatal

import com.nvidia.spark.rapids.tool.PlatformFactory
import com.nvidia.spark.rapids.tool.{Platform, PlatformFactory}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

Expand All @@ -33,7 +33,7 @@ import org.apache.spark.internal.Logging
* by the plugin which lists the formats and types supported.
* The class also supports a custom speedup factor file as input.
*/
class PluginTypeChecker(platform: String = "onprem",
class PluginTypeChecker(platform: Platform = PlatformFactory.createInstance(),
speedupFactorFile: Option[String] = None) extends Logging {

private val NS = "NS"
Expand Down Expand Up @@ -92,7 +92,7 @@ class PluginTypeChecker(platform: String = "onprem",
speedupFactorFile match {
case None =>
logInfo(s"Reading operators scores with platform: $platform")
val file = PlatformFactory.createInstance(platform).getOperatorScoreFile
val file = platform.getOperatorScoreFile
val source = Source.fromResource(file)
readSupportedOperators(source, "score").map(x => (x._1, x._2.toDouble))
case Some(file) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package com.nvidia.spark.rapids.tool.qualification

import com.nvidia.spark.rapids.tool.PlatformNames
import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames}
import org.rogach.scallop.{ScallopConf, ScallopOption}
import org.rogach.scallop.exceptions.ScallopException

Expand Down Expand Up @@ -156,8 +156,9 @@ Usage: java -cp rapids-4-spark-tools_2.12-<version>.jar:$SPARK_HOME/jars/*
val platform: ScallopOption[String] =
opt[String](required = false,
descr = "Cluster platform where Spark CPU workloads were executed. Options include " +
s"${PlatformNames.getAllNames.mkString(", ")}. Default is ${PlatformNames.ONPREM}.",
default = Some(PlatformNames.ONPREM))
s"${PlatformNames.getAllNames.mkString(", ")}. " +
s"Default is ${PlatformNames.DEFAULT}.",
default = Some(PlatformNames.DEFAULT))
val speedupFactorFile: ScallopOption[String] =
opt[String](required = false,
descr = "Custom speedup factor file used to get estimated GPU speedup that is specific " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.nvidia.spark.rapids.tool.qualification

import com.nvidia.spark.rapids.tool.EventLogPathProcessor
import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformFactory}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.rapids.tool.AppFilterImpl
Expand Down Expand Up @@ -58,14 +58,16 @@ object QualificationMain extends Logging {
val order = appArgs.order.getOrElse("desc")
val uiEnabled = appArgs.htmlReport.getOrElse(false)
val reportSqlLevel = appArgs.perSql.getOrElse(false)
val platform = appArgs.platform.getOrElse("onprem")
val platform = appArgs.platform()
val mlOpsEnabled = appArgs.mlFunctions.getOrElse(false)
val penalizeTransitions = appArgs.penalizeTransitions.getOrElse(true)

val hadoopConf = RapidsToolsConfUtil.newHadoopConf

val pluginTypeChecker = try {
new PluginTypeChecker(platform, appArgs.speedupFactorFile.toOption)
new PluginTypeChecker(
PlatformFactory.createInstance(platform),
appArgs.speedupFactorFile.toOption)
} catch {
case ie: IllegalStateException =>
logError("Error creating the plugin type checker!", ie)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.util
import scala.collection.JavaConverters._
import scala.collection.mutable

import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames}
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import org.scalatest.Matchers.convertToAnyShouldWrapper
import org.yaml.snakeyaml.{DumperOptions, Yaml}
Expand Down Expand Up @@ -1285,14 +1286,15 @@ class AutoTunerSuite extends FunSuite with BeforeAndAfterEach with Logging {

test("test recommendations for databricks-aws platform argument") {
val databricksWorkerInfo = buildWorkerInfoAsString()
val platform = PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS)
val autoTuner = AutoTuner.buildAutoTunerFromProps(databricksWorkerInfo,
getGpuAppMockInfoProvider, "databricks-aws")
getGpuAppMockInfoProvider, platform)
val (properties, comments) = autoTuner.getRecommendedProperties()

// Assert recommendations are excluded in properties
assert(properties.map(_.property).forall(autoTuner.selectedPlatform.isValidRecommendation))
assert(properties.map(_.property).forall(autoTuner.platform.isValidRecommendation))
// Assert recommendations are skipped in comments
assert(comments.map(_.comment).forall(autoTuner.selectedPlatform.isValidComment))
assert(comments.map(_.comment).forall(autoTuner.platform.isValidComment))
}

// When spark is running as a standalone, the memoryOverhead should not be listed as a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids.tool.qualification
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}

import com.nvidia.spark.rapids.tool.ToolTestUtils
import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames, ToolTestUtils}
import com.nvidia.spark.rapids.tool.planparser.DataWritingCommandExecParser
import org.scalatest.FunSuite

Expand Down Expand Up @@ -153,68 +153,33 @@ class PluginTypeCheckerSuite extends FunSuite with Logging {
assert(result(2) == "ORC")
}

test("supported operator score from onprem") {
val checker = new PluginTypeChecker("onprem")
assert(checker.getSpeedupFactor("UnionExec") == 3.0)
assert(checker.getSpeedupFactor("Ceil") == 4)
}

test("supported operator score from dataproc-t4") {
val checker = new PluginTypeChecker("dataproc-t4")
assert(checker.getSpeedupFactor("UnionExec") == 4.88)
assert(checker.getSpeedupFactor("Ceil") == 4.88)
}

test("supported operator score from emr-t4") {
val checker = new PluginTypeChecker("emr-t4")
assert(checker.getSpeedupFactor("UnionExec") == 2.07)
assert(checker.getSpeedupFactor("Ceil") == 2.07)
}

test("supported operator score from databricks-aws") {
val checker = new PluginTypeChecker("databricks-aws")
assert(checker.getSpeedupFactor("UnionExec") == 2.45)
assert(checker.getSpeedupFactor("Ceil") == 2.45)
}

test("supported operator score from databricks-azure") {
val checker = new PluginTypeChecker("databricks-azure")
assert(checker.getSpeedupFactor("UnionExec") == 2.73)
assert(checker.getSpeedupFactor("Ceil") == 2.73)
}

test("supported operator score from dataproc-serverless-l4") {
val checker = new PluginTypeChecker("dataproc-serverless-l4")
assert(checker.getSpeedupFactor("WindowExec") == 4.25)
assert(checker.getSpeedupFactor("Ceil") == 4.25)
}

test("supported operator score from dataproc-l4") {
val checker = new PluginTypeChecker("dataproc-l4")
assert(checker.getSpeedupFactor("UnionExec") == 4.16)
assert(checker.getSpeedupFactor("Ceil") == 4.16)
}

test("supported operator score from dataproc-gke-t4") {
val checker = new PluginTypeChecker("dataproc-gke-t4")
assert(checker.getSpeedupFactor("WindowExec") == 3.65)
assert(checker.getSpeedupFactor("Ceil") == 3.65)
}

test("supported operator score from dataproc-gke-l4") {
val checker = new PluginTypeChecker("dataproc-gke-l4")
assert(checker.getSpeedupFactor("WindowExec") == 3.74)
assert(checker.getSpeedupFactor("Ceil") == 3.74)
}

test("supported operator score from emr-a10") {
val checker = new PluginTypeChecker("emr-a10")
assert(checker.getSpeedupFactor("UnionExec") == 2.59)
assert(checker.getSpeedupFactor("Ceil") == 2.59)
val platformSpeedupEntries: Seq[(String, Map[String, Double])] = Seq(
(PlatformNames.ONPREM, Map("UnionExec" -> 3.0, "Ceil" -> 4.0)),
(PlatformNames.DATAPROC_T4, Map("UnionExec" -> 4.88, "Ceil" -> 4.88)),
(PlatformNames.EMR_T4, Map("UnionExec" -> 2.07, "Ceil" -> 2.07)),
(PlatformNames.DATABRICKS_AWS, Map("UnionExec" -> 2.45, "Ceil" -> 2.45)),
(PlatformNames.DATABRICKS_AZURE, Map("UnionExec" -> 2.73, "Ceil" -> 2.73)),
(PlatformNames.DATAPROC_SL_L4, Map("WindowExec" -> 4.25, "Ceil" -> 4.25)),
(PlatformNames.DATAPROC_L4, Map("UnionExec" -> 4.16, "Ceil" -> 4.16)),
(PlatformNames.DATAPROC_GKE_T4, Map("WindowExec" -> 3.65, "Ceil" -> 3.65)),
(PlatformNames.DATAPROC_GKE_L4, Map("WindowExec" -> 3.74, "Ceil" -> 3.74)),
(PlatformNames.EMR_A10, Map("UnionExec" -> 2.59, "Ceil" -> 2.59))
)

platformSpeedupEntries.foreach { case (platformName, speedupMap) =>
test(s"supported operator score from $platformName") {
val platform = PlatformFactory.createInstance(platformName)
val checker = new PluginTypeChecker(platform)
speedupMap.foreach { case (operator, speedup) =>
assert(checker.getSpeedupFactor(operator) == speedup)
}
}
}

test("supported operator score from custom speedup factor file") {
val speedupFactorFile = ToolTestUtils.getTestResourcePath("operatorsScore-databricks-azure.csv")
// Using databricks azure speedup factor as custom file
val platform = PlatformFactory.createInstance(PlatformNames.DATABRICKS_AZURE)
val speedupFactorFile = ToolTestUtils.getTestResourcePath(platform.getOperatorScoreFile)
val checker = new PluginTypeChecker(speedupFactorFile=Some(speedupFactorFile))
assert(checker.getSpeedupFactor("SortExec") == 13.11)
assert(checker.getSpeedupFactor("FilterExec") == 3.14)
Expand Down
Loading

0 comments on commit 6d2fb4e

Please sign in to comment.