Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve shuffle manager recommendation in AutoTuner with version validation #1483

Merged
merged 9 commits into from
Jan 7, 2025
51 changes: 50 additions & 1 deletion core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -121,6 +121,7 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
val clusterProperties: Option[ClusterProperties]) extends Logging {
val platformName: String
val defaultGpuDevice: GpuDevice
val sparkVersionLabel: String = "Spark version"

// It's not ideal to use vars here but to minimize changes and
// keep backwards compatibility we put them here for now and hopefully
Expand All @@ -139,6 +140,46 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS
)

// scalastyle:off line.size.limit
// Spark releases for which a RapidsShuffleManager can be recommended, each paired with its
// RapidsShuffleManager version tag (the Spark version with the dots removed).
// Reference: https://docs.nvidia.com/spark-rapids/user-guide/latest/additional-functionality/rapids-shuffle.html#rapids-shuffle-manager
// scalastyle:on line.size.limit
val supportedShuffleManagerVersionMap: Array[(String, String)] = Array(
  "3.2.0", "3.2.1", "3.2.2", "3.2.3", "3.2.4",
  "3.3.0", "3.3.1", "3.3.2", "3.3.3", "3.3.4",
  "3.4.0", "3.4.1", "3.4.2", "3.4.3",
  "3.5.0", "3.5.1"
).map(release => release -> release.replace(".", ""))

/**
 * Determine the appropriate RapidsShuffleManager version based on the
 * provided spark version.
 *
 * A supported version matches only when it appears in `sparkVersion` as a whole version
 * token: it must not be preceded by a digit or a dot, and must not be followed by a digit.
 * Suffixed versions such as "3.2.0-amzn-1" or "11.3.x-gpu-ml-scala2.12" therefore still
 * match, while "3.5.10" or "13.3.0" are not mistaken for "3.5.1" or "3.3.0".
 *
 * @param sparkVersion version string reported for the application
 * @return the RapidsShuffleManager version of the first matching entry in
 *         `supportedShuffleManagerVersionMap`, or None when no entry matches
 */
def getShuffleManagerVersion(sparkVersion: String): Option[String] = {
  // True when `supportedVersion` occurs in sparkVersion delimited as a full version token.
  def isVersionToken(supportedVersion: String): Boolean = {
    val quoted = java.util.regex.Pattern.quote(supportedVersion)
    java.util.regex.Pattern.compile(s"(?<![\\d.])${quoted}(?!\\d)")
      .matcher(sparkVersion)
      .find()
  }
  supportedShuffleManagerVersionMap.collectFirst {
    case (supportedVersion, smVersion) if isVersionToken(supportedVersion) => smVersion
  }
}

/**
 * Identify the latest supported Spark and RapidsShuffleManager version for the platform.
 *
 * Versions are ranked by their numeric components instead of as plain strings, so that
 * "13.3" ranks above "9.1" and "3.10.0" above "3.5.1". Only the first three dot-separated
 * components are compared; a missing or non-numeric component counts as 0.
 *
 * Evaluation fails with the exception raised by `maxBy` if
 * `supportedShuffleManagerVersionMap` is empty.
 */
lazy val latestSupportedShuffleManagerInfo: (String, String) = {
  // Numeric sort key for a dotted version string, e.g. "11.3" -> (11, 3, 0).
  def versionKey(version: String): (Int, Int, Int) = {
    val parts = version.split('.')
      .map(part => scala.util.Try(part.toInt).getOrElse(0))
      .padTo(3, 0)
    (parts(0), parts(1), parts(2))
  }
  supportedShuffleManagerVersionMap.maxBy { case (version, _) => versionKey(version) }
}

/**
* Checks if the given runtime is supported by the platform.
*/
Expand Down Expand Up @@ -522,6 +563,7 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
abstract class DatabricksPlatform(gpuDevice: Option[GpuDevice],
clusterProperties: Option[ClusterProperties]) extends Platform(gpuDevice, clusterProperties) {
override val defaultGpuDevice: GpuDevice = T4Gpu
override val sparkVersionLabel: String = "Databricks runtime"
override def isPlatformCSP: Boolean = true

override val supportedRuntimes: Set[SparkRuntime.SparkRuntime] = Set(
Expand All @@ -538,6 +580,13 @@ abstract class DatabricksPlatform(gpuDevice: Option[GpuDevice],
"spark.executor.memoryOverhead"
)

// Databricks runtime versions for which a RapidsShuffleManager can be recommended,
// each paired with its RapidsShuffleManager version tag.
override val supportedShuffleManagerVersionMap: Array[(String, String)] =
  Array(("11.3", "330db"), ("12.2", "332db"), ("13.3", "341db"))

override def createClusterInfo(coresPerExecutor: Int,
numExecsPerNode: Int,
numExecs: Int,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,7 +27,6 @@ import scala.util.control.NonFatal
import scala.util.matching.Regex

import com.nvidia.spark.rapids.tool.{AppSummaryInfoBaseProvider, GpuDevice, Platform, PlatformFactory}
import com.nvidia.spark.rapids.tool.planparser.DatabricksParseHelper
import com.nvidia.spark.rapids.tool.profiling._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path}
Expand Down Expand Up @@ -717,9 +716,9 @@ class AutoTuner(
def calculateJobLevelRecommendations(): Unit = {
// TODO - do we do anything with 200 shuffle partitions or maybe if its close
// set the Spark config spark.shuffle.sort.bypassMergeThreshold
getShuffleManagerClassName match {
case Some(smClassName) => appendRecommendation("spark.shuffle.manager", smClassName)
case None => appendComment("Could not define the Spark Version")
getShuffleManagerClassName match {
case Right(smClassName) => appendRecommendation("spark.shuffle.manager", smClassName)
case Left(comment) => appendComment(comment)
}
appendComment(autoTunerConfigsProvider.classPathComments("rapids.shuffle.jars"))
recommendFileCache()
Expand Down Expand Up @@ -752,30 +751,31 @@ class AutoTuner(
}
}

def getShuffleManagerClassName() : Option[String] = {
appInfoProvider.getSparkVersion.map { sparkVersion =>
val shuffleManagerVersion = sparkVersion.filterNot("().".toSet)
val dbVersion = getPropertyValue(
DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY).getOrElse("")
val finalShuffleVersion : String = if (dbVersion.nonEmpty) {
dbVersion match {
case ver if ver.contains("10.4") => "321db"
case ver if ver.contains("11.3") => "330db"
case _ => "332db"
}
} else if (sparkVersion.contains("amzn")) {
sparkVersion match {
case ver if ver.contains("3.5.2") => "352"
case ver if ver.contains("3.5.1") => "351"
case ver if ver.contains("3.5.0") => "350"
case ver if ver.contains("3.4.1") => "341"
case ver if ver.contains("3.4.0") => "340"
case _ => "332"
/**
* Resolves the RapidsShuffleManager class name based on the Spark version.
* If a valid class name is not found, an error message is returned.
*
* Example:
* sparkVersion: "3.2.0-amzn-1"
* return: Right("com.nvidia.spark.rapids.spark320.RapidsShuffleManager")
*
* sparkVersion: "3.1.2"
* return: Left("Cannot recommend RAPIDS Shuffle Manager for unsupported '3.1.2' version.")
*
* @return Either an error message (Left) or the RapidsShuffleManager class name (Right)
*/
def getShuffleManagerClassName: Either[String, String] = {
appInfoProvider.getSparkVersion match {
case Some(sparkVersion) =>
platform.getShuffleManagerVersion(sparkVersion) match {
case Some(smVersion) =>
Right(autoTunerConfigsProvider.buildShuffleManagerClassName(smVersion))
case None =>
Left(autoTunerConfigsProvider.shuffleManagerCommentForUnsupportedVersion(
sparkVersion, platform))
}
} else {
shuffleManagerVersion
}
"com.nvidia.spark.rapids.spark" + finalShuffleVersion + ".RapidsShuffleManager"
case None =>
Left(autoTunerConfigsProvider.shuffleManagerCommentForMissingVersion)
}
}

Expand Down Expand Up @@ -1344,6 +1344,9 @@ trait AutoTunerConfigsProvider extends Logging {
// the plugin jar is in the form of rapids-4-spark_scala_binary-(version)-*.jar
val pluginJarRegEx: Regex = "rapids-4-spark_\\d\\.\\d+-(\\d{2}\\.\\d{2}\\.\\d+).*\\.jar".r

private val shuffleManagerDocUrl = "https://docs.nvidia.com/spark-rapids/user-guide/latest/" +
"additional-functionality/rapids-shuffle.html#rapids-shuffle-manager"

/**
* Abstract method to create an instance of the AutoTuner.
*/
Expand Down Expand Up @@ -1460,6 +1463,27 @@ trait AutoTunerConfigsProvider extends Logging {
case _ => true
}
}

/**
 * Builds the fully qualified RapidsShuffleManager class name for a shuffle manager version.
 *
 * Example: "330" yields "com.nvidia.spark.rapids.spark330.RapidsShuffleManager".
 *
 * @param smVersion RapidsShuffleManager version tag embedded in the package name
 * @return the class name to use for `spark.shuffle.manager`
 */
def buildShuffleManagerClassName(smVersion: String): String =
  s"com.nvidia.spark.rapids.spark$smVersion.RapidsShuffleManager"

/**
 * Builds the comment emitted when no RapidsShuffleManager can be recommended because the
 * application's version is not supported on the given platform.
 *
 * The comment names the unsupported version and suggests the platform's latest supported
 * version together with the matching `spark.shuffle.manager` setting.
 *
 * @param sparkVersion the unsupported version reported for the application
 * @param platform platform providing the version label and the latest supported versions
 * @return a multi-line comment; every line after the first is prefixed with a space
 */
def shuffleManagerCommentForUnsupportedVersion(
    sparkVersion: String, platform: Platform): String = {
  val (latestSparkVersion, latestSmVersion) = platform.latestSupportedShuffleManagerInfo
  // Use the shared builder so the suggested class name cannot drift from the one
  // that is actually recommended for supported versions.
  val latestSmClassName = buildShuffleManagerClassName(latestSmVersion)
  // scalastyle:off line.size.limit
  s"""
     |Cannot recommend RAPIDS Shuffle Manager for unsupported ${platform.sparkVersionLabel}: '$sparkVersion'.
     |To enable RAPIDS Shuffle Manager, use a supported ${platform.sparkVersionLabel} (e.g., '$latestSparkVersion')
     |and set: '--conf spark.shuffle.manager=$latestSmClassName'.
     |See supported versions: $shuffleManagerDocUrl.
     |""".stripMargin.trim.replace("\n", "\n ")
  // scalastyle:on line.size.limit
}

/**
 * Comment emitted when no RapidsShuffleManager can be recommended because the
 * Spark version of the application is unknown.
 */
def shuffleManagerCommentForMissingVersion: String =
  "Could not recommend RapidsShuffleManager as Spark version cannot be determined."
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2183,33 +2183,139 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
assert(expectedResults == autoTunerOutput)
}

test("test shuffle manager version for databricks") {
/**
 * Asserts that the AutoTuner recommends the RapidsShuffleManager class built from
 * `expectedSmVersion`. Fails the test if the AutoTuner returns a comment instead.
 */
private def verifyRecommendedShuffleManagerVersion(
    autoTuner: AutoTuner,
    expectedSmVersion: String): Unit = {
  autoTuner.getShuffleManagerClassName match {
    case Left(comment) =>
      fail(s"Expected valid RapidsShuffleManager but got comment: $comment")
    case Right(actualClassName) =>
      val expectedClassName =
        ProfilingAutoTunerConfigsProvider.buildShuffleManagerClassName(expectedSmVersion)
      assert(actualClassName == expectedClassName)
  }
}

test("test shuffle manager version for supported databricks version") {
val databricksVersion = "11.3.x-gpu-ml-scala2.12"
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin",
DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY -> "11.3.x-gpu-ml-scala2.12"),
Some("3.3.0"), Seq())
DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY -> databricksVersion),
Some(databricksVersion), Seq())
// Do not set the platform as DB to see if it can work correctly irrespective
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(databricksWorkerInfo,
infoProvider, PlatformFactory.createInstance())
val smVersion = autoTuner.getShuffleManagerClassName()
infoProvider, PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS))
// Assert shuffle manager string for DB 11.3 tag
assert(smVersion.get == "com.nvidia.spark.rapids.spark330db.RapidsShuffleManager")
verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330db")
}

test("test shuffle manager version for non-databricks") {
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
test("test shuffle manager version for supported spark version") {
val sparkVersion = "3.3.0"
val workerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
Some("3.3.0"), Seq())
Some(sparkVersion), Seq())
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(workerInfo,
infoProvider, PlatformFactory.createInstance())
// Assert shuffle manager string for supported Spark v3.3.0
verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330")
}

test("test shuffle manager version for supported custom spark version") {
  // A suffixed build of a supported release should resolve to that release's shuffle manager.
  val gpuWorkerInfo = buildGpuWorkerInfoAsString(None)
  val mockInfoProvider = getMockInfoProvider(
    0,
    Seq(0),
    Seq(0.0),
    mutable.Map(
      "spark.rapids.sql.enabled" -> "true",
      "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
    Some("3.3.0-custom"),
    Seq())
  val tuner = ProfilingAutoTunerConfigsProvider.buildAutoTunerFromProps(
    gpuWorkerInfo, mockInfoProvider, PlatformFactory.createInstance())
  verifyRecommendedShuffleManagerVersion(tuner, expectedSmVersion = "330")
}

/**
 * Asserts that the AutoTuner declines to recommend a RapidsShuffleManager and instead
 * returns the "unsupported version" comment for `sparkVersion` on the tuner's platform.
 */
private def verifyUnsupportedSparkVersionForShuffleManager(
    autoTuner: AutoTuner,
    sparkVersion: String): Unit = {
  autoTuner.getShuffleManagerClassName match {
    case Left(actualComment) =>
      val expectedComment = ProfilingAutoTunerConfigsProvider
        .shuffleManagerCommentForUnsupportedVersion(sparkVersion, autoTuner.platform)
      assert(actualComment == expectedComment)
    case Right(unexpectedClassName) =>
      fail(s"Expected error comment but got valid RapidsShuffleManager: $unexpectedClassName")
  }
}

test("test shuffle manager version for unsupported databricks version") {
// The Databricks runtime string is used both as the cluster tag and as the reported version.
val databricksVersion = "9.1.x-gpu-ml-scala2.12"
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin",
DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY -> databricksVersion),
Some(databricksVersion), Seq())
// The platform is explicitly created as Databricks (AWS); presumably the version is then
// validated against the Databricks-specific mapping - confirm against PlatformFactory.
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(databricksWorkerInfo,
infoProvider, PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS))
// Expect the "unsupported version" comment instead of a shuffle manager class name.
verifyUnsupportedSparkVersionForShuffleManager(autoTuner, databricksVersion)
}

test("test shuffle manager version for unsupported spark version") {
val sparkVersion = "3.1.2"
val workerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
Some(sparkVersion), Seq())
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(workerInfo,
infoProvider, PlatformFactory.createInstance())
val smVersion = autoTuner.getShuffleManagerClassName()
assert(smVersion.get == "com.nvidia.spark.rapids.spark330.RapidsShuffleManager")
verifyUnsupportedSparkVersionForShuffleManager(autoTuner, sparkVersion)
}

test("test shuffle manager version for unsupported custom spark version") {
  // A suffixed build of an unsupported release must not yield a recommendation.
  val unsupportedVersion = "3.1.2-custom"
  val gpuWorkerInfo = buildGpuWorkerInfoAsString(None)
  val mockInfoProvider = getMockInfoProvider(
    0,
    Seq(0),
    Seq(0.0),
    mutable.Map(
      "spark.rapids.sql.enabled" -> "true",
      "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
    Some(unsupportedVersion),
    Seq())
  val tuner = ProfilingAutoTunerConfigsProvider.buildAutoTunerFromProps(
    gpuWorkerInfo, mockInfoProvider, PlatformFactory.createInstance())
  verifyUnsupportedSparkVersionForShuffleManager(tuner, unsupportedVersion)
}

test("test shuffle manager version for missing spark version") {
  val gpuWorkerInfo = buildGpuWorkerInfoAsString(None)
  // The provider reports no Spark version at all.
  val mockInfoProvider = getMockInfoProvider(
    0,
    Seq(0),
    Seq(0.0),
    mutable.Map(
      "spark.rapids.sql.enabled" -> "true",
      "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
    None,
    Seq())
  val tuner = ProfilingAutoTunerConfigsProvider.buildAutoTunerFromProps(
    gpuWorkerInfo, mockInfoProvider, PlatformFactory.createInstance())
  // Without a version, the AutoTuner must return the "missing version" comment.
  tuner.getShuffleManagerClassName match {
    case Left(actualComment) =>
      assert(
        actualComment == ProfilingAutoTunerConfigsProvider.shuffleManagerCommentForMissingVersion)
    case Right(unexpectedClassName) =>
      fail(s"Expected error comment but got valid RapidsShuffleManager: $unexpectedClassName")
  }
}

test("Test spilling occurred in shuffle stages") {
Expand Down
Loading