diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala
index 6bf62a537..c76e20fe2 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala
@@ -22,7 +22,7 @@ import com.nvidia.spark.rapids.tool.profiling.ClusterProperties
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.rapids.tool.{ExistingClusterInfo, RecommendedClusterInfo}
-import org.apache.spark.sql.rapids.tool.util.StringUtils
+import org.apache.spark.sql.rapids.tool.util.{SparkRuntime, StringUtils}
 
 /**
  * Utility object containing constants for various platform names.
@@ -132,6 +132,19 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
   var recommendedClusterInfo: Option[RecommendedClusterInfo] = None
   // the number of GPUs to use, this might be updated as we handle different cases
   var numGpus: Int = 1
+  // Default runtime for the platform
+  val defaultRuntime: SparkRuntime.SparkRuntime = SparkRuntime.SPARK
+  // Set of supported runtimes for the platform
+  protected val supportedRuntimes: Set[SparkRuntime.SparkRuntime] = Set(
+    SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS
+  )
+
+  /**
+   * Checks if the given runtime is supported by the platform.
+   */
+  def isRuntimeSupported(runtime: SparkRuntime.SparkRuntime): Boolean = {
+    supportedRuntimes.contains(runtime)
+  }
 
   // This function allows us to have one gpu type used by the auto
   // tuner recommendations but have a different GPU used for speedup
@@ -506,6 +519,10 @@ abstract class DatabricksPlatform(gpuDevice: Option[GpuDevice],
   override val defaultGpuDevice: GpuDevice = T4Gpu
   override def isPlatformCSP: Boolean = true
 
+  override val supportedRuntimes: Set[SparkRuntime.SparkRuntime] = Set(
+    SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS, SparkRuntime.PHOTON
+  )
+
   // note that Databricks generally sets the spark.executor.memory for the user. Our
   // auto tuner heuristics generally set it lower than Databricks so go ahead and
   // allow our auto tuner to take effect for this in anticipation that we will use more
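The pattern above is a protected set on the base class plus one override per platform family. For reviewers, here is a minimal, self-contained sketch of the intended behavior; the enumeration and class shapes below are simplified stand-ins, not the tool's actual types:

```scala
// Simplified stand-ins; the real types live in com.nvidia.spark.rapids.tool
// and org.apache.spark.sql.rapids.tool.util.
object SparkRuntime extends Enumeration {
  val SPARK, SPARK_RAPIDS, PHOTON = Value
}

abstract class Platform {
  // Runtime to fall back to when a parsed runtime is not supported.
  val defaultRuntime: SparkRuntime.Value = SparkRuntime.SPARK
  // Base platforms accept plain Spark and Spark RAPIDS event logs.
  protected val supportedRuntimes: Set[SparkRuntime.Value] =
    Set(SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS)

  def isRuntimeSupported(runtime: SparkRuntime.Value): Boolean =
    supportedRuntimes.contains(runtime)
}

// Databricks platforms additionally accept Photon.
class DatabricksPlatform extends Platform {
  override protected val supportedRuntimes: Set[SparkRuntime.Value] =
    Set(SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS, SparkRuntime.PHOTON)
}

object RuntimeSupportDemo extends App {
  println(new DatabricksPlatform().isRuntimeSupported(SparkRuntime.PHOTON)) // true
  println(new Platform {}.isRuntimeSupported(SparkRuntime.PHOTON))          // false
}
```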
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala
index bad5524e3..1b7d1cced 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.util.control.NonFatal
 
-import com.nvidia.spark.rapids.tool.{AppSummaryInfoBaseProvider, EventLogInfo, EventLogPathProcessor, FailedEventLog, PlatformFactory, ToolBase}
+import com.nvidia.spark.rapids.tool.{AppSummaryInfoBaseProvider, EventLogInfo, EventLogPathProcessor, FailedEventLog, Platform, PlatformFactory, ToolBase}
 import com.nvidia.spark.rapids.tool.profiling.AutoTuner.loadClusterProps
 import com.nvidia.spark.rapids.tool.views._
 import org.apache.hadoop.conf.Configuration
@@ -43,6 +43,8 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea
   private val outputCombined: Boolean = appArgs.combined()
   private val useAutoTuner: Boolean = appArgs.autoTuner()
   private val outputAlignedSQLIds: Boolean = appArgs.outputSqlIdsAligned()
+  // Unlike the qualification tool, the profiler does not create a platform per app
+  private val platform: Platform = PlatformFactory.createInstance(appArgs.platform())
 
   override def getNumThreads: Int = appArgs.numThreads.getOrElse(
     Math.ceil(Runtime.getRuntime.availableProcessors() / 4f).toInt)
@@ -286,9 +288,9 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea
   private def createApp(path: EventLogInfo, index: Int,
       hadoopConf: Configuration): Either[FailureApp, ApplicationInfo] = {
     try {
-      // This apps only contains 1 app in each loop.
+      // These apps contain only 1 app in each loop.
       val startTime = System.currentTimeMillis()
-      val app = new ApplicationInfo(hadoopConf, path, index)
+      val app = new ApplicationInfo(hadoopConf, path, index, platform)
       EventLogPathProcessor.logApplicationInfo(app)
       val endTime = System.currentTimeMillis()
       if (!app.isAppMetaDefined) {
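The design point in this file is that the profiler builds a single `Platform` from the `--platform` argument and shares it across every event log it processes. A runnable sketch of that sharing; the types below are hypothetical stand-ins, not the tool's real classes:

```scala
// Hypothetical stand-ins used only to illustrate instance sharing.
case class Platform(name: String)
object PlatformFactory {
  def createInstance(name: String = "onprem"): Platform = Platform(name)
}
class ApplicationInfo(val index: Int, val platform: Platform)

object SharedPlatformDemo extends App {
  // Built once per profiler run, from the CLI's --platform value.
  val platform = PlatformFactory.createInstance("databricks-aws")
  val apps = (1 to 3).map(i => new ApplicationInfo(i, platform))
  // Every app references the same instance.
  assert(apps.forall(_.platform eq platform))
  println(s"${apps.size} apps share platform '${platform.name}'")
}
```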
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
index e3313b832..eb1f69fc7 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
@@ -23,7 +23,7 @@ import scala.collection.immutable
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashSet, Map}
 
 import com.nvidia.spark.rapids.SparkRapidsBuildInfoEvent
-import com.nvidia.spark.rapids.tool.{DatabricksEventLog, DatabricksRollingEventLogFilesFileReader, EventLogInfo}
+import com.nvidia.spark.rapids.tool.{DatabricksEventLog, DatabricksRollingEventLogFilesFileReader, EventLogInfo, Platform}
 import com.nvidia.spark.rapids.tool.planparser.{HiveParseHelper, ReadParser}
 import com.nvidia.spark.rapids.tool.planparser.HiveParseHelper.isHiveTableScanNode
 import com.nvidia.spark.rapids.tool.profiling.{BlockManagerRemovedCase, DriverAccumCase, JobInfoClass, ResourceProfileInfoCase, SQLExecutionInfoClass, SQLPlanMetricsCase}
@@ -37,12 +37,13 @@ import org.apache.spark.scheduler.{SparkListenerEvent, StageInfo}
 import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
 import org.apache.spark.sql.rapids.tool.store.{AccumManager, DataSourceRecord, SQLPlanModelManager, StageModel, StageModelManager, TaskModelManager}
-import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph, UTF8Source}
+import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, SparkRuntime, ToolsPlanGraph, UTF8Source}
 import org.apache.spark.util.Utils
 
 abstract class AppBase(
     val eventLogInfo: Option[EventLogInfo],
-    val hadoopConf: Option[Configuration]) extends Logging with ClusterTagPropHandler {
+    val hadoopConf: Option[Configuration],
+    val platform: Option[Platform] = None) extends Logging with ClusterTagPropHandler {
 
   var appMetaData: Option[AppMetaData] = None
 
@@ -485,6 +486,27 @@ abstract class AppBase(
     processEventsInternal()
     postCompletion()
   }
+
+  /**
+   * Returns the SparkRuntime environment in which the application is being executed.
+   * This is calculated based on other cached properties.
+   *
+   * If the platform is provided, and it does not support the parsed runtime,
+   * the method will log a warning and fall back to the platform's default runtime.
+   */
+  override def getSparkRuntime: SparkRuntime.SparkRuntime = {
+    val parsedRuntime = super.getSparkRuntime
+    platform.map { p =>
+      if (p.isRuntimeSupported(parsedRuntime)) {
+        parsedRuntime
+      } else {
+        logWarning(s"Application $appId: Platform '${p.platformName}' does not support " +
+          s"the parsed runtime '$parsedRuntime'. Falling back to default runtime - " +
+          s"'${p.defaultRuntime}'.")
+        p.defaultRuntime
+      }
+    }.getOrElse(parsedRuntime)
+  }
 }
 
 object AppBase {
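The override's contract is easiest to see in isolation: trust the runtime parsed from the event log unless a platform is attached and rejects it, in which case warn and use the platform's default. A self-contained sketch of that resolution, again with hypothetical stand-in types:

```scala
object SparkRuntime extends Enumeration {
  val SPARK, SPARK_RAPIDS, PHOTON = Value
}

// Hypothetical stand-in for the Platform pieces the override relies on.
case class PlatformStub(
    platformName: String,
    supported: Set[SparkRuntime.Value],
    defaultRuntime: SparkRuntime.Value = SparkRuntime.SPARK) {
  def isRuntimeSupported(r: SparkRuntime.Value): Boolean = supported.contains(r)
}

object RuntimeFallbackDemo extends App {
  def resolveRuntime(parsed: SparkRuntime.Value,
      platform: Option[PlatformStub]): SparkRuntime.Value = {
    platform.map { p =>
      if (p.isRuntimeSupported(parsed)) {
        parsed
      } else {
        // The real override calls logWarning here.
        println(s"Platform '${p.platformName}' does not support '$parsed'; " +
          s"falling back to '${p.defaultRuntime}'")
        p.defaultRuntime
      }
    }.getOrElse(parsed) // no platform attached: keep the parsed runtime
  }

  val onprem = PlatformStub("onprem", Set(SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS))
  // PHOTON parsed from an event log falls back to SPARK on onprem.
  assert(resolveRuntime(SparkRuntime.PHOTON, Some(onprem)) == SparkRuntime.SPARK)
  // With no platform, the parsed runtime is returned unchanged.
  assert(resolveRuntime(SparkRuntime.PHOTON, None) == SparkRuntime.PHOTON)
}
```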
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala
index 83a3cbc0b..6fbf2bb68 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.tool.profiling
 
 import scala.collection.Map
 
-import com.nvidia.spark.rapids.tool.EventLogInfo
+import com.nvidia.spark.rapids.tool.{EventLogInfo, Platform, PlatformFactory}
 import com.nvidia.spark.rapids.tool.analysis.AppSQLPlanAnalyzer
 import org.apache.hadoop.conf.Configuration
 
@@ -184,8 +184,9 @@ object SparkPlanInfoWithStage {
 class ApplicationInfo(
     hadoopConf: Configuration,
     eLogInfo: EventLogInfo,
-    val index: Int)
-  extends AppBase(Some(eLogInfo), Some(hadoopConf)) with Logging {
+    val index: Int,
+    platform: Platform = PlatformFactory.createInstance())
+  extends AppBase(Some(eLogInfo), Some(hadoopConf), Some(platform)) with Logging {
 
   private lazy val eventProcessor = new EventsProcessor(this)
 
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala
index 887075c8c..d2ac79ea2 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala
@@ -41,7 +41,7 @@ class QualificationAppInfo(
     mlOpsEnabled: Boolean = false,
     penalizeTransitions: Boolean = true,
     platform: Platform)
-  extends AppBase(eventLogInfo, hadoopConf) with Logging {
+  extends AppBase(eventLogInfo, hadoopConf, Some(platform)) with Logging {
 
   var lastJobEndTime: Option[Long] = None
   var lastSQLEndTime: Option[Long] = None
diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/ToolTestUtils.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/ToolTestUtils.scala
index bd5e7bf25..011e5010e 100644
--- a/core/src/test/scala/com/nvidia/spark/rapids/tool/ToolTestUtils.scala
+++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/ToolTestUtils.scala
@@ -144,12 +144,13 @@ object ToolTestUtils extends Logging {
     val apps: ArrayBuffer[ApplicationInfo] = ArrayBuffer[ApplicationInfo]()
     val appArgs = new ProfileArgs(logs)
     var index: Int = 1
+    val platform = PlatformFactory.createInstance(appArgs.platform())
     for (path <- appArgs.eventlog()) {
       val eventLogInfo = EventLogPathProcessor
         .getEventLogInfo(path, RapidsToolsConfUtil.newHadoopConf())
-      assert(eventLogInfo.size >= 1, s"event log not parsed as expected $path")
+      assert(eventLogInfo.nonEmpty, s"event log not parsed as expected $path")
       apps += new ApplicationInfo(RapidsToolsConfUtil.newHadoopConf(),
-        eventLogInfo.head._1, index)
+        eventLogInfo.head._1, index, platform)
       index += 1
     }
     apps
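Note the default argument on the new `ApplicationInfo` parameter: call sites that never pass a platform keep compiling, while callers that do pass one get it threaded through to `AppBase`. A tiny sketch of that source compatibility, with stand-in types rather than the real constructors:

```scala
// Stand-in types to show the source-compatibility of the default argument.
case class Platform(name: String)
object PlatformFactory {
  def createInstance(name: String = "onprem"): Platform = Platform(name)
}

class ApplicationInfo(val index: Int,
    val platform: Platform = PlatformFactory.createInstance())

object DefaultArgDemo extends App {
  val legacy = new ApplicationInfo(1) // pre-existing call sites compile unchanged
  val explicit = new ApplicationInfo(2, PlatformFactory.createInstance("dataproc"))
  println(s"${legacy.platform.name}, ${explicit.platform.name}") // onprem, dataproc
}
```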
diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala
index de9921cec..ec8fb7ab8 100644
--- a/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala
+++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala
@@ -22,7 +22,7 @@
 import java.nio.file.{Files, Paths, StandardOpenOption}
 
 import scala.collection.mutable.ArrayBuffer
 
-import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, StatusReportCounts, ToolTestUtils}
+import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformNames, StatusReportCounts, ToolTestUtils}
 import com.nvidia.spark.rapids.tool.views.RawMetricProfilerView
 import org.apache.hadoop.io.IOUtils
 import org.scalatest.FunSuite
@@ -1116,17 +1116,37 @@ class ApplicationInfoSuite extends FunSuite with Logging {
     }
   }
 
-  val sparkRuntimeTestCases: Seq[(SparkRuntime.Value, String)] = Seq(
-    SparkRuntime.SPARK -> s"$qualLogDir/nds_q86_test",
-    SparkRuntime.SPARK_RAPIDS -> s"$logDir/nds_q66_gpu.zstd",
-    SparkRuntime.PHOTON -> s"$qualLogDir/nds_q88_photon_db_13_3.zstd"
+  // scalastyle:off line.size.limit
+  val sparkRuntimeTestCases: Map[String, Seq[(String, SparkRuntime.Value)]] = Map(
+    // tests for standard Spark runtime
+    s"$qualLogDir/nds_q86_test" -> Seq(
+      (PlatformNames.DATABRICKS_AWS, SparkRuntime.SPARK), // Expected: SPARK on Databricks AWS
+      (PlatformNames.ONPREM, SparkRuntime.SPARK) // Expected: SPARK on Onprem
+    ),
+    // tests for Spark Rapids runtime
+    s"$logDir/nds_q66_gpu.zstd" -> Seq(
+      (PlatformNames.DATABRICKS_AWS, SparkRuntime.SPARK_RAPIDS), // Expected: SPARK_RAPIDS on Databricks AWS
+      (PlatformNames.ONPREM, SparkRuntime.SPARK_RAPIDS) // Expected: SPARK_RAPIDS on Onprem
+    ),
+    // tests for Photon runtime with fallback to SPARK for unsupported platforms
+    s"$qualLogDir/nds_q88_photon_db_13_3.zstd" -> Seq(
+      (PlatformNames.DATABRICKS_AWS, SparkRuntime.PHOTON), // Expected: PHOTON on Databricks AWS
+      (PlatformNames.DATABRICKS_AZURE, SparkRuntime.PHOTON), // Expected: PHOTON on Databricks Azure
+      (PlatformNames.ONPREM, SparkRuntime.SPARK), // Expected: Fallback to SPARK on Onprem
+      (PlatformNames.DATAPROC, SparkRuntime.SPARK) // Expected: Fallback to SPARK on Dataproc
+    )
   )
-
-  sparkRuntimeTestCases.foreach { case (expectedSparkRuntime, eventLog) =>
-    test(s"test spark runtime property for ${expectedSparkRuntime.toString} eventlog") {
-      val apps = ToolTestUtils.processProfileApps(Array(eventLog), sparkSession)
-      assert(apps.size == 1)
-      assert(apps.head.getSparkRuntime == expectedSparkRuntime)
+  // scalastyle:on line.size.limit
+
+  sparkRuntimeTestCases.foreach { case (logPath, platformRuntimeCases) =>
+    val baseFileName = logPath.split("/").last
+    platformRuntimeCases.foreach { case (platform, expectedRuntime) =>
+      test(s"test eventlog $baseFileName on $platform has runtime: $expectedRuntime") {
+        val args = Array("--platform", platform, logPath)
+        val apps = ToolTestUtils.processProfileApps(args, sparkSession)
+        assert(apps.size == 1)
+        assert(apps.head.getSparkRuntime == expectedRuntime)
+      }
     }
   }
 }
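One note on the rewritten test block: because ScalaTest's `FunSuite` registers a test each time `test(...)` runs during suite construction, the nested `foreach` generates one named test per (event log, platform) pair. A minimal sketch of that registration pattern; the suite and its data are illustrative, not from the repo:

```scala
import org.scalatest.FunSuite

// Illustrative only: shows how loop-generated test names come out.
class GeneratedTestsSketch extends FunSuite {
  val cases: Map[String, Seq[(String, String)]] = Map(
    "nds_q88_photon_db_13_3.zstd" -> Seq(
      ("databricks-aws", "PHOTON"), // supported: parsed runtime kept
      ("onprem", "SPARK")           // unsupported: falls back to default
    )
  )

  cases.foreach { case (log, perPlatform) =>
    perPlatform.foreach { case (platform, expected) =>
      // Each call to test(...) at construction time registers a distinct case,
      // e.g. "test eventlog nds_q88_photon_db_13_3.zstd on onprem has runtime: SPARK".
      test(s"test eventlog $log on $platform has runtime: $expected") {
        assert(expected.nonEmpty) // the real suite asserts on getSparkRuntime
      }
    }
  }
}
```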