Replace Map with pattern matching and rename PlatformTypes to PlatformNames

Signed-off-by: Partho Sarthi <[email protected]>
parthosa committed Nov 16, 2023
1 parent 4978952 commit 032436f
Showing 4 changed files with 34 additions and 37 deletions.
54 changes: 26 additions & 28 deletions core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
/**
* Utility object containing constants for various platform names.
*/
object PlatformTypes {
object PlatformNames {
val DATABRICKS_AWS = "databricks-aws"
val DATABRICKS_AZURE = "databricks-azure"
val DATAPROC = "dataproc"
@@ -37,15 +37,16 @@ object PlatformTypes {
/**
* Return a list of all platform names.
*/
def getAllPlatformNames: List[String] = List(
def getAllNames: List[String] = List(
DATABRICKS_AWS, DATABRICKS_AZURE, DATAPROC, DATAPROC_GKE_L4, DATAPROC_GKE_T4,
DATAPROC_L4, DATAPROC_SL_L4, DATAPROC_T4, EMR, EMR_A10, EMR_T4, ONPREM
)
}

/**
* Represents a platform and its associated recommendations.
* @param platformType Type of the platform. See [[PlatformTypes]] for supported platform types.
*
* @param platformType Type of the platform. See [[PlatformNames]] for supported platform types.
*/
class Platform(platformType: String) {
/**
@@ -113,31 +114,14 @@ class DataprocPlatform(platformType: String) extends Platform(platformType)

class EmrPlatform(platformType: String) extends Platform(platformType)

class OnPremPlatform extends Platform(PlatformTypes.ONPREM)
class OnPremPlatform extends Platform(PlatformNames.ONPREM)

/**
* Factory for creating instances of different platforms.
* This factory supports various platforms and provides methods for creating
* corresponding platform instances.
*/
object PlatformFactory extends Logging {
private lazy val platformInstancesMap: Map[String, Platform] = Map(
PlatformTypes.DATABRICKS_AWS -> new DatabricksPlatform(PlatformTypes.DATABRICKS_AWS),
PlatformTypes.DATABRICKS_AZURE -> new DatabricksPlatform(PlatformTypes.DATABRICKS_AZURE),
// if no GPU specified, then default to dataproc-t4 for backward compatibility
PlatformTypes.DATAPROC -> new DataprocPlatform(PlatformTypes.DATAPROC_T4),
PlatformTypes.DATAPROC_T4 -> new DataprocPlatform(PlatformTypes.DATAPROC_T4),
PlatformTypes.DATAPROC_L4 -> new DataprocPlatform(PlatformTypes.DATAPROC_L4),
PlatformTypes.DATAPROC_SL_L4 -> new DataprocPlatform(PlatformTypes.DATAPROC_SL_L4),
PlatformTypes.DATAPROC_GKE_L4 -> new DataprocPlatform(PlatformTypes.DATAPROC_GKE_L4),
PlatformTypes.DATAPROC_GKE_T4 -> new DataprocPlatform(PlatformTypes.DATAPROC_GKE_T4),
// if no GPU specified, then default to emr-t4 for backward compatibility
PlatformTypes.EMR -> new EmrPlatform(PlatformTypes.EMR_T4),
PlatformTypes.EMR_T4 -> new EmrPlatform(PlatformTypes.EMR_T4),
PlatformTypes.EMR_A10 -> new EmrPlatform(PlatformTypes.EMR_A10),
PlatformTypes.ONPREM -> new OnPremPlatform
)

/**
* Creates an instance of a platform based on the specified platform key.
*
@@ -146,12 +130,26 @@ object PlatformFactory extends Logging {
* @throws IllegalArgumentException if the specified platform key is not supported.
*/
def createInstance(platformKey: String): Platform = {
val platformToUse = if (platformKey.isEmpty) {
logInfo(s"Platform is not specified, defaulting to ${PlatformTypes.ONPREM}")
PlatformTypes.ONPREM
} else platformKey

platformInstancesMap.getOrElse(platformToUse,
throw new IllegalArgumentException(s"Platform $platformToUse is not supported"))
platformKey match {
case PlatformNames.DATABRICKS_AWS => new DatabricksPlatform(PlatformNames.DATABRICKS_AWS)
case PlatformNames.DATABRICKS_AZURE => new DatabricksPlatform(PlatformNames.DATABRICKS_AZURE)
case PlatformNames.DATAPROC | PlatformNames.DATAPROC_T4 =>
// if no GPU specified, then default to dataproc-t4 for backward compatibility
new DataprocPlatform(PlatformNames.DATAPROC_T4)
case PlatformNames.DATAPROC_L4 => new DataprocPlatform(PlatformNames.DATAPROC_L4)
case PlatformNames.DATAPROC_SL_L4 => new DataprocPlatform(PlatformNames.DATAPROC_SL_L4)
case PlatformNames.DATAPROC_GKE_L4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_L4)
case PlatformNames.DATAPROC_GKE_T4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_T4)
case PlatformNames.EMR | PlatformNames.EMR_T4 =>
// if no GPU specified, then default to emr-t4 for backward compatibility
new EmrPlatform(PlatformNames.EMR_T4)
case PlatformNames.EMR_A10 => new EmrPlatform(PlatformNames.EMR_A10)
case PlatformNames.ONPREM => new OnPremPlatform
case p if p.isEmpty =>
logInfo(s"Platform is not specified. Using ${PlatformNames.ONPREM} as default.")
new OnPremPlatform
case _ => throw new IllegalArgumentException(s"Unsupported platform: $platformKey. " +
s"Options include ${PlatformNames.getAllNames.mkString(", ")}.")
}
}
}
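
For reference, a minimal usage sketch of the reworked factory (the call sites and values below are illustrative and not part of this commit; the behavior follows the pattern match above):

import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames}

// "dataproc" and "emr" without a GPU suffix fall back to their T4 variants for backward compatibility.
val dataproc = PlatformFactory.createInstance(PlatformNames.DATAPROC)
val emr = PlatformFactory.createInstance(PlatformNames.EMR)

// An empty platform key logs a message and defaults to onprem, matching the previous behavior.
val default = PlatformFactory.createInstance("")

// Any other key now throws IllegalArgumentException listing PlatformNames.getAllNames.
// PlatformFactory.createInstance("some-unknown-cloud")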
@@ -15,7 +15,7 @@
*/
package com.nvidia.spark.rapids.tool.profiling

import com.nvidia.spark.rapids.tool.PlatformTypes
import com.nvidia.spark.rapids.tool.PlatformNames
import org.rogach.scallop.{ScallopConf, ScallopOption}
import org.rogach.scallop.exceptions.ScallopException

@@ -68,8 +68,8 @@ Usage: java -cp rapids-4-spark-tools_2.12-<version>.jar:$SPARK_HOME/jars/*
val platform: ScallopOption[String] =
opt[String](required = false,
descr = "Cluster platform where Spark GPU workloads were executed. Options include " +
s"${PlatformTypes.getAllPlatformNames.mkString(", ")}. Default is ${PlatformTypes.ONPREM}.",
default = Some(Profiler.DEFAULT_PLATFORM))
s"${PlatformNames.getAllNames.mkString(", ")}. Default is ${PlatformNames.ONPREM}.",
default = Some(PlatformNames.ONPREM))
val generateTimeline: ScallopOption[Boolean] =
opt[Boolean](required = false,
descr = "Write an SVG graph out for the full application timeline.")
@@ -23,7 +23,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal

import com.nvidia.spark.rapids.ThreadFactoryBuilder
import com.nvidia.spark.rapids.tool.{EventLogInfo, EventLogPathProcessor, PlatformTypes}
import com.nvidia.spark.rapids.tool.{EventLogInfo, EventLogPathProcessor, PlatformNames}
import org.apache.hadoop.conf.Configuration

import org.apache.spark.internal.Logging
@@ -533,7 +533,7 @@ object Profiler {
val COMPARE_LOG_FILE_NAME_PREFIX = "rapids_4_spark_tools_compare"
val COMBINED_LOG_FILE_NAME_PREFIX = "rapids_4_spark_tools_combined"
val SUBDIR = "rapids_4_spark_profile"
val DEFAULT_PLATFORM: String = PlatformTypes.ONPREM
val DEFAULT_PLATFORM: String = PlatformNames.ONPREM

def getAutoTunerResultsAsString(props: Seq[RecommendedPropertyResult],
comments: Seq[RecommendedCommentResult]): String = {
@@ -15,8 +15,7 @@
*/
package com.nvidia.spark.rapids.tool.qualification

import com.nvidia.spark.rapids.tool.PlatformTypes

import com.nvidia.spark.rapids.tool.PlatformNames
import org.rogach.scallop.{ScallopConf, ScallopOption}
import org.rogach.scallop.exceptions.ScallopException

@@ -157,8 +156,8 @@ Usage: java -cp rapids-4-spark-tools_2.12-<version>.jar:$SPARK_HOME/jars/*
val platform: ScallopOption[String] =
opt[String](required = false,
descr = "Cluster platform where Spark CPU workloads were executed. Options include " +
s"${PlatformTypes.getAllPlatformNames.mkString(", ")}. Default is ${PlatformTypes.ONPREM}.",
default = Some("onprem"))
s"${PlatformNames.getAllNames.mkString(", ")}. Default is ${PlatformNames.ONPREM}.",
default = Some(PlatformNames.ONPREM))
val speedupFactorFile: ScallopOption[String] =
opt[String](required = false,
descr = "Custom speedup factor file used to get estimated GPU speedup that is specific " +