Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve shuffle manager recommendation in AutoTuner with version validation #1483

Merged
merged 9 commits into from
Jan 7, 2025
144 changes: 115 additions & 29 deletions core/src/main/scala/com/nvidia/spark/rapids/tool/tuning/AutoTuner.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -264,6 +264,110 @@ class RecommendationEntry(val name: String,
}
}

// scalastyle:off line.size.limit
/**
* Resolves the appropriate RapidsShuffleManager class name based on Spark or Databricks version.
*
* Note:
* - Supported RapidsShuffleManagers: https://docs.nvidia.com/spark-rapids/user-guide/latest/additional-functionality/rapids-shuffle.html#rapids-shuffle-manager
* - Version mappings need to be updated as new versions are supported.
* - This can be extended to support more version mappings (e.g. Cloudera).
*/
// scalastyle:on line.size.limit
object ShuffleManagerResolver {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is kinda platform specific thing, right. Why cannot we have it in the Platform class?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I will refactor this to make it platform specific.

// Supported Databricks version to RapidsShuffleManager version mapping.
// NOTE(review): keys are major.minor only and are substring-matched against the full
// Databricks runtime tag (e.g. "13.3.x-gpu-ml-scala2.12") — confirm when adding entries.
private val supportedDatabricksVersionMap = Array(
"11.3" -> "330db",
"12.3" -> "332db",
"13.3" -> "341db"
)

// Supported Spark version to RapidsShuffleManager version mapping.
// Values are the version suffix used in the RapidsShuffleManager package name
// (e.g. "3.2.0" -> "320" -> com.nvidia.spark.rapids.spark320.RapidsShuffleManager).
private val supportedSparkVersionMap = Array(
"3.2.0" -> "320",
"3.2.1" -> "321",
"3.2.2" -> "322",
"3.2.3" -> "323",
"3.2.4" -> "324",
"3.3.0" -> "330",
"3.3.1" -> "331",
"3.3.2" -> "332",
"3.3.3" -> "333",
"3.3.4" -> "334",
"3.4.0" -> "340",
"3.4.1" -> "341",
"3.4.2" -> "342",
"3.4.3" -> "343",
"3.5.0" -> "350",
"3.5.1" -> "351"
)

// User-guide page listing the supported RapidsShuffleManager versions; included in the
// comment emitted for unsupported versions.
private val shuffleManagerDocUrl = "https://docs.nvidia.com/spark-rapids/user-guide/latest/" +
"additional-functionality/rapids-shuffle.html#rapids-shuffle-manager"

/**
 * Constructs the fully qualified RapidsShuffleManager class name for the given shuffle
 * manager version suffix (e.g. "320" -> "com.nvidia.spark.rapids.spark320.RapidsShuffleManager").
 */
def buildShuffleManagerClassName(smVersion: String): String =
  "com.nvidia.spark.rapids.spark" + smVersion + ".RapidsShuffleManager"

/**
 * Builds the AutoTuner comment emitted when the given Spark/Databricks version has no
 * matching RapidsShuffleManager release, pointing the user at the supported-versions doc.
 */
def commentForUnsupportedVersion(sparkVersion: String): String = {
  val headline = s"Cannot recommend RAPIDS Shuffle Manager for unsupported '$sparkVersion' version."
  val docPointer = s" See supported versions: $shuffleManagerDocUrl."
  headline + "\n" + docPointer
}

/** Comment emitted when neither a Spark nor a Databricks version is available to resolve against. */
def commentForMissingVersion: String =
  "Could not recommend RapidsShuffleManager as neither Spark nor Databricks version is provided."

/**
 * Resolves the RapidsShuffleManager class name for the given version string by scanning
 * the supported-version mapping for the first key occurring within the provided version.
 * Substring matching lets vendor-suffixed versions resolve, e.g.:
 *
 *   "3.2.0-amzn-1"            with ["3.2.0" -> "320", ...]  => Right(...spark320.RapidsShuffleManager)
 *   "13.3.x-gpu-ml-scala2.12" with [..., "13.3" -> "341db"] => Right(...spark341db.RapidsShuffleManager)
 *   "3.1.2" (unsupported)                                   => Left(<unsupported-version comment>)
 *
 * NOTE(review): `contains` matching means a key that is a substring of a longer version
 * (e.g. "3.3.1" inside a hypothetical "3.3.10") would also match — fine for the current
 * version set, but worth confirming when new versions are added.
 *
 * @param supportedVersionsMap ordered (version key -> shuffle manager suffix) pairs
 * @param sparkVersion the raw Spark or Databricks version string to resolve
 * @return Either an error message (Left) or the RapidsShuffleManager class name (Right)
 */
private def getClassNameInternal(
    supportedVersionsMap: Array[(String, String)],
    sparkVersion: String): Either[String, String] = {
  supportedVersionsMap
    .find { case (supportedVersion, _) => sparkVersion.contains(supportedVersion) }
    .map { case (_, smVersion) => buildShuffleManagerClassName(smVersion) }
    .toRight(commentForUnsupportedVersion(sparkVersion))
}

/**
 * Resolves the RapidsShuffleManager class name from the provided Databricks or Spark
 * version. A present Databricks version always takes precedence over the Spark version;
 * when neither is available, a "missing version" comment is returned.
 *
 * @param dbVersion Databricks version, if any.
 * @param sparkVersion Spark version, if any.
 * @return Either an error message (Left) or the RapidsShuffleManager class name (Right)
 */
def getClassName(
    dbVersion: Option[String], sparkVersion: Option[String]): Either[String, String] = {
  dbVersion
    .map(getClassNameInternal(supportedDatabricksVersionMap, _))
    .orElse(sparkVersion.map(getClassNameInternal(supportedSparkVersionMap, _)))
    .getOrElse(Left(commentForMissingVersion))
}
}

/**
* AutoTuner module that uses event logs and worker's system properties to recommend Spark
* RAPIDS configuration based on heuristics.
Expand Down Expand Up @@ -717,9 +821,9 @@ class AutoTuner(
def calculateJobLevelRecommendations(): Unit = {
// TODO - do we do anything with 200 shuffle partitions or maybe if its close
// set the Spark config spark.shuffle.sort.bypassMergeThreshold
getShuffleManagerClassName match {
case Some(smClassName) => appendRecommendation("spark.shuffle.manager", smClassName)
case None => appendComment("Could not define the Spark Version")
getShuffleManagerClassName match {
case Right(smClassName) => appendRecommendation("spark.shuffle.manager", smClassName)
case Left(comment) => appendComment(comment)
}
appendComment(autoTunerConfigsProvider.classPathComments("rapids.shuffle.jars"))
recommendFileCache()
Expand Down Expand Up @@ -752,31 +856,13 @@ class AutoTuner(
}
}

def getShuffleManagerClassName() : Option[String] = {
appInfoProvider.getSparkVersion.map { sparkVersion =>
val shuffleManagerVersion = sparkVersion.filterNot("().".toSet)
val dbVersion = getPropertyValue(
DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY).getOrElse("")
val finalShuffleVersion : String = if (dbVersion.nonEmpty) {
dbVersion match {
case ver if ver.contains("10.4") => "321db"
case ver if ver.contains("11.3") => "330db"
case _ => "332db"
}
} else if (sparkVersion.contains("amzn")) {
sparkVersion match {
case ver if ver.contains("3.5.2") => "352"
case ver if ver.contains("3.5.1") => "351"
case ver if ver.contains("3.5.0") => "350"
case ver if ver.contains("3.4.1") => "341"
case ver if ver.contains("3.4.0") => "340"
case _ => "332"
}
} else {
shuffleManagerVersion
}
"com.nvidia.spark.rapids.spark" + finalShuffleVersion + ".RapidsShuffleManager"
}
/**
 * Resolves the RapidsShuffleManager class name based on the Databricks or Spark version.
 * The Databricks cluster tag property (when present) takes precedence over the Spark
 * version. If a valid class name is not found, an error message is returned.
 */
def getShuffleManagerClassName: Either[String, String] = {
  ShuffleManagerResolver.getClassName(
    getPropertyValue(DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY),
    appInfoProvider.getSparkVersion)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2183,7 +2183,23 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
assert(expectedResults == autoTunerOutput)
}

test("test shuffle manager version for databricks") {
/**
 * Asserts that the AutoTuner resolves a RapidsShuffleManager class name and that it
 * matches the class name built from the expected shuffle manager version.
 */
private def verifyRecommendedShuffleManagerVersion(
    autoTuner: AutoTuner,
    expectedSmVersion: String): Unit = {
  autoTuner.getShuffleManagerClassName.fold(
    comment => fail(s"Expected valid RapidsShuffleManager but got comment: $comment"),
    smClassName => assert(
      smClassName == ShuffleManagerResolver.buildShuffleManagerClassName(expectedSmVersion)))
}

test("test shuffle manager version for supported databricks version") {
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
Expand All @@ -2194,12 +2210,11 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(databricksWorkerInfo,
infoProvider, PlatformFactory.createInstance())
val smVersion = autoTuner.getShuffleManagerClassName()
// Assert shuffle manager string for DB 11.3 tag
assert(smVersion.get == "com.nvidia.spark.rapids.spark330db.RapidsShuffleManager")
verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330db")
}

test("test shuffle manager version for non-databricks") {
test("test shuffle manager version for supported spark version") {
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
Expand All @@ -2208,8 +2223,95 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
val autoTuner = ProfilingAutoTunerConfigsProvider
.buildAutoTunerFromProps(databricksWorkerInfo,
infoProvider, PlatformFactory.createInstance())
val smVersion = autoTuner.getShuffleManagerClassName()
assert(smVersion.get == "com.nvidia.spark.rapids.spark330.RapidsShuffleManager")
// Assert shuffle manager string for supported Spark v3.3.0
verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330")
}

test("test shuffle manager version for supported custom spark version") {
  // A vendor-suffixed Spark version ("3.3.0-custom") should still resolve to the
  // shuffle manager for its base version via substring matching.
  val workerInfo = buildGpuWorkerInfoAsString(None)
  val mockInfoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
    mutable.Map("spark.rapids.sql.enabled" -> "true",
      "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
    Some("3.3.0-custom"), Seq())
  val autoTuner = ProfilingAutoTunerConfigsProvider.buildAutoTunerFromProps(
    workerInfo, mockInfoProvider, PlatformFactory.createInstance())
  verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion = "330")
}

/**
 * Asserts that the AutoTuner does not resolve a RapidsShuffleManager for the given
 * unsupported version and instead returns the matching explanatory comment.
 */
private def verifyUnsupportedSparkVersionForShuffleManager(
    autoTuner: AutoTuner,
    sparkVersion: String): Unit = {
  autoTuner.getShuffleManagerClassName.fold(
    comment =>
      assert(comment == ShuffleManagerResolver.commentForUnsupportedVersion(sparkVersion)),
    smClassName =>
      fail(s"Expected error comment but got valid RapidsShuffleManager: $smClassName"))
}

test("test shuffle manager version for unsupported databricks version") {
  val databricksVersion = "9.1.x-gpu-ml-scala2.12"
  val workerInfo = buildGpuWorkerInfoAsString(None)
  val mockInfoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
    mutable.Map("spark.rapids.sql.enabled" -> "true",
      "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin",
      DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY -> databricksVersion),
    Some(databricksVersion), Seq())
  // Deliberately not a Databricks platform instance: resolution should work off the
  // cluster tag property alone, irrespective of the platform.
  val autoTuner = ProfilingAutoTunerConfigsProvider.buildAutoTunerFromProps(
    workerInfo, mockInfoProvider, PlatformFactory.createInstance())
  verifyUnsupportedSparkVersionForShuffleManager(autoTuner, databricksVersion)
}

test("test shuffle manager version for unsupported spark version") {
  // Spark 3.1.2 predates the earliest supported RapidsShuffleManager release.
  val sparkVersion = "3.1.2"
  val workerInfo = buildGpuWorkerInfoAsString(None)
  val mockInfoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
    mutable.Map("spark.rapids.sql.enabled" -> "true",
      "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
    Some(sparkVersion), Seq())
  val autoTuner = ProfilingAutoTunerConfigsProvider.buildAutoTunerFromProps(
    workerInfo, mockInfoProvider, PlatformFactory.createInstance())
  verifyUnsupportedSparkVersionForShuffleManager(autoTuner, sparkVersion)
}

test("test shuffle manager version for unsupported custom spark version") {
  // A vendor suffix must not make an unsupported base version resolve.
  val customSparkVersion = "3.1.2-custom"
  val workerInfo = buildGpuWorkerInfoAsString(None)
  val mockInfoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
    mutable.Map("spark.rapids.sql.enabled" -> "true",
      "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
    Some(customSparkVersion), Seq())
  val autoTuner = ProfilingAutoTunerConfigsProvider.buildAutoTunerFromProps(
    workerInfo, mockInfoProvider, PlatformFactory.createInstance())
  verifyUnsupportedSparkVersionForShuffleManager(autoTuner, customSparkVersion)
}

test("test shuffle manager version for missing spark version") {
  val workerInfo = buildGpuWorkerInfoAsString(None)
  val mockInfoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
    mutable.Map("spark.rapids.sql.enabled" -> "true",
      "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
    None, Seq())
  val autoTuner = ProfilingAutoTunerConfigsProvider.buildAutoTunerFromProps(
    workerInfo, mockInfoProvider, PlatformFactory.createInstance())
  // With neither a Spark version nor a Databricks cluster tag available, no shuffle
  // manager should be recommended — only the "missing version" comment.
  autoTuner.getShuffleManagerClassName.fold(
    comment => assert(comment == ShuffleManagerResolver.commentForMissingVersion),
    smClassName =>
      fail(s"Expected error comment but got valid RapidsShuffleManager: $smClassName"))
}

test("Test spilling occurred in shuffle stages") {
Expand Down
Loading