Improve shuffle manager recommendation in AutoTuner with version validation (#1483)

* Add checks for supported RapidsShuffleManager versions

Signed-off-by: Partho Sarthi <[email protected]>

* Minor refactoring

Signed-off-by: Partho Sarthi <[email protected]>

* Add doc url in error comment

Signed-off-by: Partho Sarthi <[email protected]>

* Refactor to make shuffle manager validation platform specific

Signed-off-by: Partho Sarthi <[email protected]>

* Fix line length

Signed-off-by: Partho Sarthi <[email protected]>

* Update comment for unsupported version

Signed-off-by: Partho Sarthi <[email protected]>

* Refactor and update comments

Signed-off-by: Partho Sarthi <[email protected]>

* Revert "Refactor and update comments"

This reverts commit 2a84eda.

* Update comment to include an example config

Signed-off-by: Partho Sarthi <[email protected]>

---------

Signed-off-by: Partho Sarthi <[email protected]>
parthosa authored Jan 7, 2025
1 parent 21c0ab9 commit dccf8b8
Showing 3 changed files with 219 additions and 40 deletions.
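
In short, the AutoTuner now looks up the RapidsShuffleManager build through a platform-specific version map and returns an Either: a class name to recommend, or a comment explaining why no recommendation can be made. A minimal sketch of that flow (names taken from the diff below; the surrounding AutoTuner context is assumed):

// Sketch only (not part of this commit); mirrors the new logic inside
// AutoTuner.calculateJobLevelRecommendations() shown in the diff below.
getShuffleManagerClassName match {
  case Right(smClassName) =>
    // e.g. "com.nvidia.spark.rapids.spark330.RapidsShuffleManager" for Spark 3.3.0
    appendRecommendation("spark.shuffle.manager", smClassName)
  case Left(comment) =>
    // unsupported or missing Spark version: explain why instead of recommending
    appendComment(comment)
}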
51 changes: 50 additions & 1 deletion core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -121,6 +121,7 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
val clusterProperties: Option[ClusterProperties]) extends Logging {
val platformName: String
val defaultGpuDevice: GpuDevice
val sparkVersionLabel: String = "Spark version"

// It's not ideal to use vars here but to minimize changes and
// keep backwards compatibility we put them here for now and hopefully
@@ -139,6 +140,46 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS
)

// scalastyle:off line.size.limit
// Supported Spark version to RapidsShuffleManager version mapping.
// Reference: https://docs.nvidia.com/spark-rapids/user-guide/latest/additional-functionality/rapids-shuffle.html#rapids-shuffle-manager
// scalastyle:on line.size.limit
val supportedShuffleManagerVersionMap: Array[(String, String)] = Array(
"3.2.0" -> "320",
"3.2.1" -> "321",
"3.2.2" -> "322",
"3.2.3" -> "323",
"3.2.4" -> "324",
"3.3.0" -> "330",
"3.3.1" -> "331",
"3.3.2" -> "332",
"3.3.3" -> "333",
"3.3.4" -> "334",
"3.4.0" -> "340",
"3.4.1" -> "341",
"3.4.2" -> "342",
"3.4.3" -> "343",
"3.5.0" -> "350",
"3.5.1" -> "351"
)

/**
* Determine the appropriate RapidsShuffleManager version based on the
* provided spark version.
*/
def getShuffleManagerVersion(sparkVersion: String): Option[String] = {
supportedShuffleManagerVersionMap.collectFirst {
case (supportedVersion, smVersion) if sparkVersion.contains(supportedVersion) => smVersion
}
}
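
// Illustration (not part of this commit): expected behavior of the lookup above,
// assuming the base mapping defined in this class.
//   getShuffleManagerVersion("3.3.0-amzn-1") == Some("330")  // contains supported "3.3.0"
//   getShuffleManagerVersion("3.1.2")        == None         // not in the supported map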

/**
* Identify the latest supported Spark and RapidsShuffleManager version for the platform.
*/
lazy val latestSupportedShuffleManagerInfo: (String, String) = {
supportedShuffleManagerVersionMap.maxBy(_._1)
}

/**
* Checks if the given runtime is supported by the platform.
*/
@@ -522,6 +563,7 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
abstract class DatabricksPlatform(gpuDevice: Option[GpuDevice],
clusterProperties: Option[ClusterProperties]) extends Platform(gpuDevice, clusterProperties) {
override val defaultGpuDevice: GpuDevice = T4Gpu
override val sparkVersionLabel: String = "Databricks runtime"
override def isPlatformCSP: Boolean = true

override val supportedRuntimes: Set[SparkRuntime.SparkRuntime] = Set(
@@ -538,6 +580,13 @@ abstract class DatabricksPlatform(gpuDevice: Option[GpuDevice],
"spark.executor.memoryOverhead"
)

// Supported Databricks version to RapidsShuffleManager version mapping.
override val supportedShuffleManagerVersionMap: Array[(String, String)] = Array(
"11.3" -> "330db",
"12.2" -> "332db",
"13.3" -> "341db"
)
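
// Illustration (not part of this commit): a Databricks runtime tag such as
// "11.3.x-gpu-ml-scala2.12" contains the supported "11.3" entry, so the lookup
// resolves to "330db"; an older tag like "9.1.x-gpu-ml-scala2.12" yields None.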

override def createClusterInfo(coresPerExecutor: Int,
numExecsPerNode: Int,
numExecs: Int,
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,7 +27,6 @@ import scala.util.control.NonFatal
import scala.util.matching.Regex

import com.nvidia.spark.rapids.tool.{AppSummaryInfoBaseProvider, GpuDevice, Platform, PlatformFactory}
import com.nvidia.spark.rapids.tool.planparser.DatabricksParseHelper
import com.nvidia.spark.rapids.tool.profiling._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path}
@@ -717,9 +716,9 @@ class AutoTuner(
def calculateJobLevelRecommendations(): Unit = {
// TODO - do we do anything with 200 shuffle partitions or maybe if its close
// set the Spark config spark.shuffle.sort.bypassMergeThreshold
getShuffleManagerClassName match {
case Some(smClassName) => appendRecommendation("spark.shuffle.manager", smClassName)
case None => appendComment("Could not define the Spark Version")
getShuffleManagerClassName match {
case Right(smClassName) => appendRecommendation("spark.shuffle.manager", smClassName)
case Left(comment) => appendComment(comment)
}
appendComment(autoTunerConfigsProvider.classPathComments("rapids.shuffle.jars"))
recommendFileCache()
@@ -752,30 +751,31 @@
}
}

def getShuffleManagerClassName() : Option[String] = {
appInfoProvider.getSparkVersion.map { sparkVersion =>
val shuffleManagerVersion = sparkVersion.filterNot("().".toSet)
val dbVersion = getPropertyValue(
DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY).getOrElse("")
val finalShuffleVersion : String = if (dbVersion.nonEmpty) {
dbVersion match {
case ver if ver.contains("10.4") => "321db"
case ver if ver.contains("11.3") => "330db"
case _ => "332db"
}
} else if (sparkVersion.contains("amzn")) {
sparkVersion match {
case ver if ver.contains("3.5.2") => "352"
case ver if ver.contains("3.5.1") => "351"
case ver if ver.contains("3.5.0") => "350"
case ver if ver.contains("3.4.1") => "341"
case ver if ver.contains("3.4.0") => "340"
case _ => "332"
/**
* Resolves the RapidsShuffleManager class name based on the Spark version.
* If a valid class name is not found, an error message is returned.
*
* Example:
* sparkVersion: "3.2.0-amzn-1"
* return: Right("com.nvidia.spark.rapids.spark320.RapidsShuffleManager")
*
* sparkVersion: "3.1.2"
* return: Left("Cannot recommend RAPIDS Shuffle Manager for unsupported '3.1.2' version.")
*
* @return Either an error message (Left) or the RapidsShuffleManager class name (Right)
*/
def getShuffleManagerClassName: Either[String, String] = {
appInfoProvider.getSparkVersion match {
case Some(sparkVersion) =>
platform.getShuffleManagerVersion(sparkVersion) match {
case Some(smVersion) =>
Right(autoTunerConfigsProvider.buildShuffleManagerClassName(smVersion))
case None =>
Left(autoTunerConfigsProvider.shuffleManagerCommentForUnsupportedVersion(
sparkVersion, platform))
}
} else {
shuffleManagerVersion
}
"com.nvidia.spark.rapids.spark" + finalShuffleVersion + ".RapidsShuffleManager"
case None =>
Left(autoTunerConfigsProvider.shuffleManagerCommentForMissingVersion)
}
}

@@ -1344,6 +1344,9 @@ trait AutoTunerConfigsProvider extends Logging {
// the plugin jar is in the form of rapids-4-spark_scala_binary-(version)-*.jar
val pluginJarRegEx: Regex = "rapids-4-spark_\\d\\.\\d+-(\\d{2}\\.\\d{2}\\.\\d+).*\\.jar".r

private val shuffleManagerDocUrl = "https://docs.nvidia.com/spark-rapids/user-guide/latest/" +
"additional-functionality/rapids-shuffle.html#rapids-shuffle-manager"

/**
* Abstract method to create an instance of the AutoTuner.
*/
@@ -1460,6 +1463,27 @@ trait AutoTunerConfigsProvider extends Logging {
case _ => true
}
}

def buildShuffleManagerClassName(smVersion: String): String = {
s"com.nvidia.spark.rapids.spark$smVersion.RapidsShuffleManager"
}

def shuffleManagerCommentForUnsupportedVersion(
sparkVersion: String, platform: Platform): String = {
val (latestSparkVersion, latestSmVersion) = platform.latestSupportedShuffleManagerInfo
// scalastyle:off line.size.limit
s"""
|Cannot recommend RAPIDS Shuffle Manager for unsupported ${platform.sparkVersionLabel}: '$sparkVersion'.
|To enable RAPIDS Shuffle Manager, use a supported ${platform.sparkVersionLabel} (e.g., '$latestSparkVersion')
|and set: '--conf spark.shuffle.manager=com.nvidia.spark.rapids.spark$latestSmVersion.RapidsShuffleManager'.
|See supported versions: $shuffleManagerDocUrl.
|""".stripMargin.trim.replaceAll("\n", "\n ")
// scalastyle:on line.size.limit
}
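
// Illustration (not part of this commit): for a platform whose latest supported
// entry is ("3.5.1", "351"), an unsupported version "3.1.2" would produce a
// comment roughly like:
//   Cannot recommend RAPIDS Shuffle Manager for unsupported Spark version: '3.1.2'.
//   To enable RAPIDS Shuffle Manager, use a supported Spark version (e.g., '3.5.1')
//   and set: '--conf spark.shuffle.manager=com.nvidia.spark.rapids.spark351.RapidsShuffleManager'.
//   See supported versions: https://docs.nvidia.com/spark-rapids/user-guide/latest/additional-functionality/rapids-shuffle.html#rapids-shuffle-manager.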

def shuffleManagerCommentForMissingVersion: String = {
"Could not recommend RapidsShuffleManager as Spark version cannot be determined."
}
}

/**
@@ -2183,33 +2183,139 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
assert(expectedResults == autoTunerOutput)
}

test("test shuffle manager version for databricks") {
/**
* Helper method to verify that the recommended shuffle manager version matches the
* expected version.
*/
private def verifyRecommendedShuffleManagerVersion(
autoTuner: AutoTuner,
expectedSmVersion: String): Unit = {
autoTuner.getShuffleManagerClassName match {
case Right(smClassName) =>
assert(smClassName == ProfilingAutoTunerConfigsProvider
.buildShuffleManagerClassName(expectedSmVersion))
case Left(comment) =>
fail(s"Expected valid RapidsShuffleManager but got comment: $comment")
}
}

test("test shuffle manager version for supported databricks version") {
val databricksVersion = "11.3.x-gpu-ml-scala2.12"
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin",
DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY -> "11.3.x-gpu-ml-scala2.12"),
Some("3.3.0"), Seq())
DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY -> databricksVersion),
Some(databricksVersion), Seq())
// Do not set the platform as DB to see if it can work correctly irrespective
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(databricksWorkerInfo,
infoProvider, PlatformFactory.createInstance())
val smVersion = autoTuner.getShuffleManagerClassName()
infoProvider, PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS))
// Assert shuffle manager string for DB 11.3 tag
assert(smVersion.get == "com.nvidia.spark.rapids.spark330db.RapidsShuffleManager")
verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330db")
}

test("test shuffle manager version for non-databricks") {
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
test("test shuffle manager version for supported spark version") {
val sparkVersion = "3.3.0"
val workerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
Some("3.3.0"), Seq())
Some(sparkVersion), Seq())
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(workerInfo,
infoProvider, PlatformFactory.createInstance())
// Assert shuffle manager string for supported Spark v3.3.0
verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330")
}

test("test shuffle manager version for supported custom spark version") {
val customSparkVersion = "3.3.0-custom"
val workerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
Some(customSparkVersion), Seq())
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(workerInfo,
infoProvider, PlatformFactory.createInstance())
// Assert shuffle manager string for supported custom Spark v3.3.0
verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330")
}

/**
* Helper method to verify that the shuffle manager version is not recommended
* for the unsupported Spark version.
*/
private def verifyUnsupportedSparkVersionForShuffleManager(
autoTuner: AutoTuner,
sparkVersion: String): Unit = {
autoTuner.getShuffleManagerClassName match {
case Right(smClassName) =>
fail(s"Expected error comment but got valid RapidsShuffleManager: $smClassName")
case Left(comment) =>
assert(comment == ProfilingAutoTunerConfigsProvider
.shuffleManagerCommentForUnsupportedVersion(sparkVersion, autoTuner.platform))
}
}

test("test shuffle manager version for unsupported databricks version") {
val databricksVersion = "9.1.x-gpu-ml-scala2.12"
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin",
DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY -> databricksVersion),
Some(databricksVersion), Seq())
// Do not set the platform as DB to see if it can work correctly irrespective
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(databricksWorkerInfo,
infoProvider, PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS))
verifyUnsupportedSparkVersionForShuffleManager(autoTuner, databricksVersion)
}

test("test shuffle manager version for unsupported spark version") {
val sparkVersion = "3.1.2"
val workerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
Some(sparkVersion), Seq())
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(workerInfo,
infoProvider, PlatformFactory.createInstance())
val smVersion = autoTuner.getShuffleManagerClassName()
assert(smVersion.get == "com.nvidia.spark.rapids.spark330.RapidsShuffleManager")
verifyUnsupportedSparkVersionForShuffleManager(autoTuner, sparkVersion)
}

test("test shuffle manager version for unsupported custom spark version") {
val customSparkVersion = "3.1.2-custom"
val workerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
Some(customSparkVersion), Seq())
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(workerInfo,
infoProvider, PlatformFactory.createInstance())
verifyUnsupportedSparkVersionForShuffleManager(autoTuner, customSparkVersion)
}

test("test shuffle manager version for missing spark version") {
val workerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
None, Seq())
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(workerInfo,
infoProvider, PlatformFactory.createInstance())
// Verify that the shuffle manager is not recommended for missing Spark version
autoTuner.getShuffleManagerClassName match {
case Right(smClassName) =>
fail(s"Expected error comment but got valid RapidsShuffleManager: $smClassName")
case Left(comment) =>
assert(comment == ProfilingAutoTunerConfigsProvider.shuffleManagerCommentForMissingVersion)
}
}

test("Test spilling occurred in shuffle stages") {