Add platform specific runtime check
Signed-off-by: Partho Sarthi <[email protected]>
parthosa committed Nov 13, 2024
1 parent 4e783d9 commit d55430f
Showing 7 changed files with 87 additions and 24 deletions.
core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala (18 additions, 1 deletion)
@@ -22,7 +22,7 @@ import com.nvidia.spark.rapids.tool.profiling.ClusterProperties
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.rapids.tool.{ExistingClusterInfo, RecommendedClusterInfo}
-import org.apache.spark.sql.rapids.tool.util.StringUtils
+import org.apache.spark.sql.rapids.tool.util.{SparkRuntime, StringUtils}
 
 /**
  * Utility object containing constants for various platform names.
@@ -132,6 +132,19 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
   var recommendedClusterInfo: Option[RecommendedClusterInfo] = None
   // the number of GPUs to use, this might be updated as we handle different cases
   var numGpus: Int = 1
+  // Default runtime for the platform
+  val defaultRuntime: SparkRuntime.SparkRuntime = SparkRuntime.SPARK
+  // Set of supported runtimes for the platform
+  protected val supportedRuntimes: Set[SparkRuntime.SparkRuntime] = Set(
+    SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS
+  )
+
+  /**
+   * Checks if the given runtime is supported by the platform.
+   */
+  def isRuntimeSupported(runtime: SparkRuntime.SparkRuntime): Boolean = {
+    supportedRuntimes.contains(runtime)
+  }
 
   // This function allows us to have one gpu type used by the auto
   // tuner recommendations but have a different GPU used for speedup
@@ -506,6 +519,10 @@ abstract class DatabricksPlatform(gpuDevice: Option[GpuDevice],
   override val defaultGpuDevice: GpuDevice = T4Gpu
   override def isPlatformCSP: Boolean = true
 
+  override val supportedRuntimes: Set[SparkRuntime.SparkRuntime] = Set(
+    SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS, SparkRuntime.PHOTON
+  )
+
   // note that Databricks generally sets the spark.executor.memory for the user. Our
   // auto tuner heuristics generally set it lower than Databricks, so go ahead and
   // allow our auto tuner to take effect for this in anticipation that we will use more
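Taken together, the Platform.scala changes let each platform declare which runtimes it accepts, with PHOTON added only by the Databricks platforms. A minimal usage sketch (all names are taken from this commit; the snippet itself is illustrative, not part of the diff):

    val onprem = PlatformFactory.createInstance(PlatformNames.ONPREM)
    onprem.isRuntimeSupported(SparkRuntime.SPARK_RAPIDS) // true: in the base supported set
    onprem.isRuntimeSupported(SparkRuntime.PHOTON)       // false: PHOTON is not in the base set

    val dbAws = PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS)
    dbAws.isRuntimeSupported(SparkRuntime.PHOTON)        // true: DatabricksPlatform overrides supportedRuntimes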
core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.util.control.NonFatal
 
-import com.nvidia.spark.rapids.tool.{AppSummaryInfoBaseProvider, EventLogInfo, EventLogPathProcessor, FailedEventLog, PlatformFactory, ToolBase}
+import com.nvidia.spark.rapids.tool.{AppSummaryInfoBaseProvider, EventLogInfo, EventLogPathProcessor, FailedEventLog, Platform, PlatformFactory, ToolBase}
 import com.nvidia.spark.rapids.tool.profiling.AutoTuner.loadClusterProps
 import com.nvidia.spark.rapids.tool.views._
 import org.apache.hadoop.conf.Configuration
@@ -43,6 +43,8 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea
   private val outputCombined: Boolean = appArgs.combined()
   private val useAutoTuner: Boolean = appArgs.autoTuner()
   private val outputAlignedSQLIds: Boolean = appArgs.outputSqlIdsAligned()
+  // Unlike the qualification tool, the profiler does not create a platform per app
+  private val platform: Platform = PlatformFactory.createInstance(appArgs.platform())
 
   override def getNumThreads: Int = appArgs.numThreads.getOrElse(
     Math.ceil(Runtime.getRuntime.availableProcessors() / 4f).toInt)
@@ -286,9 +288,9 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea
   private def createApp(path: EventLogInfo, index: Int,
       hadoopConf: Configuration): Either[FailureApp, ApplicationInfo] = {
     try {
-      // This apps only contains 1 app in each loop.
+      // These apps only contain 1 app in each loop.
       val startTime = System.currentTimeMillis()
-      val app = new ApplicationInfo(hadoopConf, path, index)
+      val app = new ApplicationInfo(hadoopConf, path, index, platform)
       EventLogPathProcessor.logApplicationInfo(app)
       val endTime = System.currentTimeMillis()
       if (!app.isAppMetaDefined) {
core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala (25 additions, 3 deletions)
@@ -23,7 +23,7 @@ import scala.collection.immutable
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashSet, Map}
 
 import com.nvidia.spark.rapids.SparkRapidsBuildInfoEvent
-import com.nvidia.spark.rapids.tool.{DatabricksEventLog, DatabricksRollingEventLogFilesFileReader, EventLogInfo}
+import com.nvidia.spark.rapids.tool.{DatabricksEventLog, DatabricksRollingEventLogFilesFileReader, EventLogInfo, Platform}
 import com.nvidia.spark.rapids.tool.planparser.{HiveParseHelper, ReadParser}
 import com.nvidia.spark.rapids.tool.planparser.HiveParseHelper.isHiveTableScanNode
 import com.nvidia.spark.rapids.tool.profiling.{BlockManagerRemovedCase, DriverAccumCase, JobInfoClass, ResourceProfileInfoCase, SQLExecutionInfoClass, SQLPlanMetricsCase}
@@ -37,12 +37,13 @@ import org.apache.spark.scheduler.{SparkListenerEvent, StageInfo}
 import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
 import org.apache.spark.sql.rapids.tool.store.{AccumManager, DataSourceRecord, SQLPlanModelManager, StageModel, StageModelManager, TaskModelManager}
-import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph, UTF8Source}
+import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, SparkRuntime, ToolsPlanGraph, UTF8Source}
 import org.apache.spark.util.Utils
 
 abstract class AppBase(
     val eventLogInfo: Option[EventLogInfo],
-    val hadoopConf: Option[Configuration]) extends Logging with ClusterTagPropHandler {
+    val hadoopConf: Option[Configuration],
+    val platform: Option[Platform] = None) extends Logging with ClusterTagPropHandler {
 
   var appMetaData: Option[AppMetaData] = None
 
@@ -485,6 +486,27 @@ abstract class AppBase(
     processEventsInternal()
     postCompletion()
   }
+
+  /**
+   * Returns the SparkRuntime environment in which the application is being executed.
+   * This is calculated based on other cached properties.
+   *
+   * If the platform is provided, and it does not support the parsed runtime,
+   * the method will log a warning and fall back to the platform's default runtime.
+   */
+  override def getSparkRuntime: SparkRuntime.SparkRuntime = {
+    val parsedRuntime = super.getSparkRuntime
+    platform.map { p =>
+      if (p.isRuntimeSupported(parsedRuntime)) {
+        parsedRuntime
+      } else {
+        logWarning(s"Application $appId: Platform '${p.platformName}' does not support " +
+          s"the parsed runtime '$parsedRuntime'. Falling back to default runtime - " +
+          s"'${p.defaultRuntime}'.")
+        p.defaultRuntime
+      }
+    }.getOrElse(parsedRuntime)
+  }
 }
 
 object AppBase {
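The net effect of the override: a parsed runtime the platform cannot run is downgraded to the platform default. A sketch of the fallback path (hadoopConf and eLogInfo are hypothetical stand-ins for values built by the callers elsewhere in this commit):

    // A Photon event log parsed under a platform without PHOTON support:
    val platform = PlatformFactory.createInstance(PlatformNames.ONPREM)
    val app = new ApplicationInfo(hadoopConf, eLogInfo, 1, platform)
    // super.getSparkRuntime parses PHOTON, the platform rejects it, a warning
    // is logged, and the platform's default runtime is returned instead.
    assert(app.getSparkRuntime == SparkRuntime.SPARK)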
core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.tool.profiling
 
 import scala.collection.Map
 
-import com.nvidia.spark.rapids.tool.EventLogInfo
+import com.nvidia.spark.rapids.tool.{EventLogInfo, Platform, PlatformFactory}
 import com.nvidia.spark.rapids.tool.analysis.AppSQLPlanAnalyzer
 import org.apache.hadoop.conf.Configuration
 
@@ -184,8 +184,9 @@ object SparkPlanInfoWithStage {
 class ApplicationInfo(
     hadoopConf: Configuration,
     eLogInfo: EventLogInfo,
-    val index: Int)
-  extends AppBase(Some(eLogInfo), Some(hadoopConf)) with Logging {
+    val index: Int,
+    platform: Platform = PlatformFactory.createInstance())
+  extends AppBase(Some(eLogInfo), Some(hadoopConf), Some(platform)) with Logging {
 
   private lazy val eventProcessor = new EventsProcessor(this)
 
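Because the new constructor parameter defaults to PlatformFactory.createInstance(), existing call sites compile unchanged; only callers that need platform-aware runtime validation pass one explicitly (hadoopConf, eLogInfo, and platform are again stand-in values):

    val app1 = new ApplicationInfo(hadoopConf, eLogInfo, 1)           // default platform
    val app2 = new ApplicationInfo(hadoopConf, eLogInfo, 1, platform) // explicit platform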
core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala
@@ -41,7 +41,7 @@ class QualificationAppInfo(
     mlOpsEnabled: Boolean = false,
     penalizeTransitions: Boolean = true,
     platform: Platform)
-  extends AppBase(eventLogInfo, hadoopConf) with Logging {
+  extends AppBase(eventLogInfo, hadoopConf, Some(platform)) with Logging {
 
   var lastJobEndTime: Option[Long] = None
   var lastSQLEndTime: Option[Long] = None
core/src/test/scala/com/nvidia/spark/rapids/tool/ToolTestUtils.scala
@@ -144,12 +144,13 @@ object ToolTestUtils extends Logging {
     val apps: ArrayBuffer[ApplicationInfo] = ArrayBuffer[ApplicationInfo]()
     val appArgs = new ProfileArgs(logs)
     var index: Int = 1
+    val platform = PlatformFactory.createInstance(appArgs.platform())
     for (path <- appArgs.eventlog()) {
       val eventLogInfo = EventLogPathProcessor
         .getEventLogInfo(path, RapidsToolsConfUtil.newHadoopConf())
-      assert(eventLogInfo.size >= 1, s"event log not parsed as expected $path")
+      assert(eventLogInfo.nonEmpty, s"event log not parsed as expected $path")
       apps += new ApplicationInfo(RapidsToolsConfUtil.newHadoopConf(),
-        eventLogInfo.head._1, index)
+        eventLogInfo.head._1, index, platform)
       index += 1
     }
     apps
core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala
@@ -22,7 +22,7 @@ import java.nio.file.{Files, Paths, StandardOpenOption}
 
 import scala.collection.mutable.ArrayBuffer
 
-import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, StatusReportCounts, ToolTestUtils}
+import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformNames, StatusReportCounts, ToolTestUtils}
 import com.nvidia.spark.rapids.tool.views.RawMetricProfilerView
 import org.apache.hadoop.io.IOUtils
 import org.scalatest.FunSuite
@@ -1116,17 +1116,37 @@ class ApplicationInfoSuite extends FunSuite with Logging {
     }
   }
 
-  val sparkRuntimeTestCases: Seq[(SparkRuntime.Value, String)] = Seq(
-    SparkRuntime.SPARK -> s"$qualLogDir/nds_q86_test",
-    SparkRuntime.SPARK_RAPIDS -> s"$logDir/nds_q66_gpu.zstd",
-    SparkRuntime.PHOTON -> s"$qualLogDir/nds_q88_photon_db_13_3.zstd"
-  )
+  // scalastyle:off line.size.limit
+  val sparkRuntimeTestCases: Map[String, Seq[(String, SparkRuntime.Value)]] = Map(
+    // tests for standard Spark runtime
+    s"$qualLogDir/nds_q86_test" -> Seq(
+      (PlatformNames.DATABRICKS_AWS, SparkRuntime.SPARK), // Expected: SPARK on Databricks AWS
+      (PlatformNames.ONPREM, SparkRuntime.SPARK) // Expected: SPARK on Onprem
+    ),
+    // tests for Spark Rapids runtime
+    s"$logDir/nds_q66_gpu.zstd" -> Seq(
+      (PlatformNames.DATABRICKS_AWS, SparkRuntime.SPARK_RAPIDS), // Expected: SPARK_RAPIDS on Databricks AWS
+      (PlatformNames.ONPREM, SparkRuntime.SPARK_RAPIDS) // Expected: SPARK_RAPIDS on Onprem
+    ),
+    // tests for Photon runtime with fallback to SPARK for unsupported platforms
+    s"$qualLogDir/nds_q88_photon_db_13_3.zstd" -> Seq(
+      (PlatformNames.DATABRICKS_AWS, SparkRuntime.PHOTON), // Expected: PHOTON on Databricks AWS
+      (PlatformNames.DATABRICKS_AZURE, SparkRuntime.PHOTON), // Expected: PHOTON on Databricks Azure
+      (PlatformNames.ONPREM, SparkRuntime.SPARK), // Expected: Fallback to SPARK on Onprem
+      (PlatformNames.DATAPROC, SparkRuntime.SPARK) // Expected: Fallback to SPARK on Dataproc
+    )
+  )
+  // scalastyle:on line.size.limit
 
-  sparkRuntimeTestCases.foreach { case (expectedSparkRuntime, eventLog) =>
-    test(s"test spark runtime property for ${expectedSparkRuntime.toString} eventlog") {
-      val apps = ToolTestUtils.processProfileApps(Array(eventLog), sparkSession)
-      assert(apps.size == 1)
-      assert(apps.head.getSparkRuntime == expectedSparkRuntime)
+  sparkRuntimeTestCases.foreach { case (logPath, platformRuntimeCases) =>
+    val baseFileName = logPath.split("/").last
+    platformRuntimeCases.foreach { case (platform, expectedRuntime) =>
+      test(s"test eventlog $baseFileName on $platform has runtime: $expectedRuntime") {
+        val args = Array("--platform", platform, logPath)
+        val apps = ToolTestUtils.processProfileApps(args, sparkSession)
+        assert(apps.size == 1)
+        assert(apps.head.getSparkRuntime == expectedRuntime)
+      }
     }
   }
 }
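The nested foreach expands into one test per (event log, platform) pair, so a single suite run covers both the supported and the fallback paths. Assuming the PlatformNames constants resolve to their CLI spellings (e.g. "onprem", "databricks-aws" — not confirmed in this diff), the Photon log alone yields test names such as:

    test eventlog nds_q88_photon_db_13_3.zstd on databricks-aws has runtime: PHOTON
    test eventlog nds_q88_photon_db_13_3.zstd on onprem has runtime: SPARK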
