diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala index d1b689245..67a36e8b9 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala @@ -15,6 +15,8 @@ */ package com.nvidia.spark.rapids.tool +import scala.annotation.tailrec + import org.apache.spark.internal.Logging /** @@ -33,6 +35,7 @@ object PlatformNames { val EMR_A10 = "emr-a10" val EMR_T4 = "emr-t4" val ONPREM = "onprem" + val DEFAULT: String = ONPREM /** * Return a list of all platform names. @@ -132,30 +135,28 @@ object PlatformFactory extends Logging { * @return An instance of the specified platform. * @throws IllegalArgumentException if the specified platform key is not supported. */ - def createInstance(platformKey: String): Platform = { + @tailrec + def createInstance(platformKey: String = PlatformNames.DEFAULT): Platform = { platformKey match { - case PlatformNames.DATABRICKS_AWS => new DatabricksPlatform(PlatformNames.DATABRICKS_AWS) - case PlatformNames.DATABRICKS_AZURE => new DatabricksPlatform(PlatformNames.DATABRICKS_AZURE) + case PlatformNames.DATABRICKS_AWS | PlatformNames.DATABRICKS_AZURE => + new DatabricksPlatform(platformKey) case PlatformNames.DATAPROC | PlatformNames.DATAPROC_T4 => // if no GPU specified, then default to dataproc-t4 for backward compatibility new DataprocPlatform(PlatformNames.DATAPROC_T4) - case PlatformNames.DATAPROC_L4 => new DataprocPlatform(PlatformNames.DATAPROC_L4) - case PlatformNames.DATAPROC_SL_L4 => new DataprocPlatform(PlatformNames.DATAPROC_SL_L4) - case PlatformNames.DATAPROC_GKE_L4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_L4) - case PlatformNames.DATAPROC_GKE_T4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_T4) + case PlatformNames.DATAPROC_L4 | PlatformNames.DATAPROC_SL_L4 | + PlatformNames.DATAPROC_GKE_L4 | PlatformNames.DATAPROC_GKE_T4 => + new DataprocPlatform(platformKey) case PlatformNames.EMR | PlatformNames.EMR_T4 => // if no GPU specified, then default to emr-t4 for backward compatibility new EmrPlatform(PlatformNames.EMR_T4) case PlatformNames.EMR_A10 => new EmrPlatform(PlatformNames.EMR_A10) case PlatformNames.ONPREM => new OnPremPlatform case p if p.isEmpty => - logInfo(s"Platform is not specified. Using ${PlatformFactory.getDefault.getName} " + + logInfo(s"Platform is not specified. Using ${PlatformNames.DEFAULT} " + "as default.") - PlatformFactory.getDefault + PlatformFactory.createInstance(PlatformNames.DEFAULT) case _ => throw new IllegalArgumentException(s"Unsupported platform: $platformKey. " + s"Options include ${PlatformNames.getAllNames.mkString(", ")}.") } } - - def getDefault: Platform = new OnPremPlatform() } diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala index cdb8b89e7..e0f923bfb 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala @@ -1075,7 +1075,7 @@ object AutoTuner extends Logging { def buildAutoTunerFromProps( clusterProps: String, singleAppProvider: AppSummaryInfoBaseProvider, - platform: Platform = PlatformFactory.getDefault): AutoTuner = { + platform: Platform = PlatformFactory.createInstance()): AutoTuner = { try { val clusterPropsOpt = loadClusterPropertiesFromContent(clusterProps) new AutoTuner(clusterPropsOpt.getOrElse(new ClusterProperties()), singleAppProvider, platform) @@ -1088,7 +1088,7 @@ object AutoTuner extends Logging { def buildAutoTuner( filePath: String, singleAppProvider: AppSummaryInfoBaseProvider, - platform: Platform = PlatformFactory.getDefault): AutoTuner = { + platform: Platform = PlatformFactory.createInstance()): AutoTuner = { try { val clusterPropsOpt = loadClusterProps(filePath) new AutoTuner(clusterPropsOpt.getOrElse(new ClusterProperties()), singleAppProvider, platform) diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileArgs.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileArgs.scala index d4e452863..3723e4aea 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileArgs.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileArgs.scala @@ -72,8 +72,8 @@ Usage: java -cp rapids-4-spark-tools_2.12-.jar:$SPARK_HOME/jars/* opt[String](required = false, descr = "Cluster platform where Spark GPU workloads were executed. Options include " + s"${PlatformNames.getAllNames.mkString(", ")}. " + - s"Default is ${PlatformFactory.getDefault.getName}.", - default = Some(PlatformFactory.getDefault.getName)) + s"Default is ${PlatformNames.DEFAULT}.", + default = Some(PlatformNames.DEFAULT)) val generateTimeline: ScallopOption[Boolean] = opt[Boolean](required = false, descr = "Write an SVG graph out for the full application timeline.") diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala index ea6eb3d04..0ff5bd614 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala @@ -33,7 +33,7 @@ import org.apache.spark.internal.Logging * by the plugin which lists the formats and types supported. * The class also supports a custom speedup factor file as input. */ -class PluginTypeChecker(platform: Platform = PlatformFactory.getDefault, +class PluginTypeChecker(platform: Platform = PlatformFactory.createInstance(), speedupFactorFile: Option[String] = None) extends Logging { private val NS = "NS" diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala index fb604588e..24942b4b9 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala @@ -157,8 +157,8 @@ Usage: java -cp rapids-4-spark-tools_2.12-.jar:$SPARK_HOME/jars/* opt[String](required = false, descr = "Cluster platform where Spark CPU workloads were executed. Options include " + s"${PlatformNames.getAllNames.mkString(", ")}. " + - s"Default is ${PlatformFactory.getDefault.getName}.", - default = Some(PlatformFactory.getDefault.getName)) + s"Default is ${PlatformNames.DEFAULT}.", + default = Some(PlatformNames.DEFAULT)) val speedupFactorFile: ScallopOption[String] = opt[String](required = false, descr = "Custom speedup factor file used to get estimated GPU speedup that is specific " +