Skip to content

Commit

Permalink
Remove creation of new instance for default case
Browse files Browse the repository at this point in the history
Signed-off-by: Partho Sarthi <[email protected]>
  • Loading branch information
parthosa committed Nov 28, 2023
1 parent b94975c commit 8533b52
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 18 deletions.
23 changes: 12 additions & 11 deletions core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package com.nvidia.spark.rapids.tool

import scala.annotation.tailrec

import org.apache.spark.internal.Logging

/**
Expand All @@ -33,6 +35,7 @@ object PlatformNames {
val EMR_A10 = "emr-a10"
val EMR_T4 = "emr-t4"
val ONPREM = "onprem"
val DEFAULT: String = ONPREM

/**
* Return a list of all platform names.
Expand Down Expand Up @@ -132,30 +135,28 @@ object PlatformFactory extends Logging {
* @return An instance of the specified platform.
* @throws IllegalArgumentException if the specified platform key is not supported.
*/
def createInstance(platformKey: String): Platform = {
@tailrec
def createInstance(platformKey: String = PlatformNames.DEFAULT): Platform = {
platformKey match {
case PlatformNames.DATABRICKS_AWS => new DatabricksPlatform(PlatformNames.DATABRICKS_AWS)
case PlatformNames.DATABRICKS_AZURE => new DatabricksPlatform(PlatformNames.DATABRICKS_AZURE)
case PlatformNames.DATABRICKS_AWS | PlatformNames.DATABRICKS_AZURE =>
new DatabricksPlatform(platformKey)
case PlatformNames.DATAPROC | PlatformNames.DATAPROC_T4 =>
// if no GPU specified, then default to dataproc-t4 for backward compatibility
new DataprocPlatform(PlatformNames.DATAPROC_T4)
case PlatformNames.DATAPROC_L4 => new DataprocPlatform(PlatformNames.DATAPROC_L4)
case PlatformNames.DATAPROC_SL_L4 => new DataprocPlatform(PlatformNames.DATAPROC_SL_L4)
case PlatformNames.DATAPROC_GKE_L4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_L4)
case PlatformNames.DATAPROC_GKE_T4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_T4)
case PlatformNames.DATAPROC_L4 | PlatformNames.DATAPROC_SL_L4 |
PlatformNames.DATAPROC_GKE_L4 | PlatformNames.DATAPROC_GKE_T4 =>
new DataprocPlatform(platformKey)
case PlatformNames.EMR | PlatformNames.EMR_T4 =>
// if no GPU specified, then default to emr-t4 for backward compatibility
new EmrPlatform(PlatformNames.EMR_T4)
case PlatformNames.EMR_A10 => new EmrPlatform(PlatformNames.EMR_A10)
case PlatformNames.ONPREM => new OnPremPlatform
case p if p.isEmpty =>
logInfo(s"Platform is not specified. Using ${PlatformFactory.getDefault.getName} " +
logInfo(s"Platform is not specified. Using ${PlatformNames.DEFAULT} " +
"as default.")
PlatformFactory.getDefault
PlatformFactory.createInstance(PlatformNames.DEFAULT)
case _ => throw new IllegalArgumentException(s"Unsupported platform: $platformKey. " +
s"Options include ${PlatformNames.getAllNames.mkString(", ")}.")
}
}

def getDefault: Platform = new OnPremPlatform()
}
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ object AutoTuner extends Logging {
def buildAutoTunerFromProps(
clusterProps: String,
singleAppProvider: AppSummaryInfoBaseProvider,
platform: Platform = PlatformFactory.getDefault): AutoTuner = {
platform: Platform = PlatformFactory.createInstance()): AutoTuner = {
try {
val clusterPropsOpt = loadClusterPropertiesFromContent(clusterProps)
new AutoTuner(clusterPropsOpt.getOrElse(new ClusterProperties()), singleAppProvider, platform)
Expand All @@ -1088,7 +1088,7 @@ object AutoTuner extends Logging {
def buildAutoTuner(
filePath: String,
singleAppProvider: AppSummaryInfoBaseProvider,
platform: Platform = PlatformFactory.getDefault): AutoTuner = {
platform: Platform = PlatformFactory.createInstance()): AutoTuner = {
try {
val clusterPropsOpt = loadClusterProps(filePath)
new AutoTuner(clusterPropsOpt.getOrElse(new ClusterProperties()), singleAppProvider, platform)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ Usage: java -cp rapids-4-spark-tools_2.12-<version>.jar:$SPARK_HOME/jars/*
opt[String](required = false,
descr = "Cluster platform where Spark GPU workloads were executed. Options include " +
s"${PlatformNames.getAllNames.mkString(", ")}. " +
s"Default is ${PlatformFactory.getDefault.getName}.",
default = Some(PlatformFactory.getDefault.getName))
s"Default is ${PlatformNames.DEFAULT}.",
default = Some(PlatformNames.DEFAULT))
val generateTimeline: ScallopOption[Boolean] =
opt[Boolean](required = false,
descr = "Write an SVG graph out for the full application timeline.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.internal.Logging
* by the plugin which lists the formats and types supported.
* The class also supports a custom speedup factor file as input.
*/
class PluginTypeChecker(platform: Platform = PlatformFactory.getDefault,
class PluginTypeChecker(platform: Platform = PlatformFactory.createInstance(),
speedupFactorFile: Option[String] = None) extends Logging {

private val NS = "NS"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ Usage: java -cp rapids-4-spark-tools_2.12-<version>.jar:$SPARK_HOME/jars/*
opt[String](required = false,
descr = "Cluster platform where Spark CPU workloads were executed. Options include " +
s"${PlatformNames.getAllNames.mkString(", ")}. " +
s"Default is ${PlatformFactory.getDefault.getName}.",
default = Some(PlatformFactory.getDefault.getName))
s"Default is ${PlatformNames.DEFAULT}.",
default = Some(PlatformNames.DEFAULT))
val speedupFactorFile: ScallopOption[String] =
opt[String](required = false,
descr = "Custom speedup factor file used to get estimated GPU speedup that is specific " +
Expand Down

0 comments on commit 8533b52

Please sign in to comment.