Minor refactoring
Signed-off-by: Partho Sarthi <[email protected]>
parthosa committed Jan 2, 2025
1 parent 526a073 commit 1259534
Showing 2 changed files with 44 additions and 50 deletions.
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -275,15 +275,15 @@ class RecommendationEntry(val name: String,
*/
// scalastyle:on line.size.limit
object ShuffleManagerResolver {
-  // Databricks version to RapidsShuffleManager version mapping.
-  private val DatabricksVersionMap = Map(
+  // Supported Databricks version to RapidsShuffleManager version mapping.
+  private val supportedDatabricksVersionMap = Array(
"11.3" -> "330db",
"12.3" -> "332db",
"13.3" -> "341db"
)

-  // Spark version to RapidsShuffleManager version mapping.
-  private val SparkVersionMap = Map(
+  // Supported Spark version to RapidsShuffleManager version mapping.
+  private val supportedSparkVersionMap = Array(
"3.2.0" -> "320",
"3.2.1" -> "321",
"3.2.2" -> "322",
@@ -299,8 +299,7 @@ object ShuffleManagerResolver {
"3.4.2" -> "342",
"3.4.3" -> "343",
"3.5.0" -> "350",
"3.5.1" -> "351",
"4.0.0" -> "400"
"3.5.1" -> "351"
)
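A plausible reading of this part of the refactor (hedged; the commit message only says "Minor refactoring") is that an Array of pairs, unlike a general Map, is iterated in declaration order, so the collectFirst lookup used further down picks a deterministic first match. A minimal sketch with illustration-only names:

// Sketch, not part of the patch: declaration order decides which prefix wins.
val orderedVersions: Array[(String, String)] = Array("3.2.0" -> "320", "3.2.1" -> "321")
val matched: Option[String] = orderedVersions.collectFirst {
  case (supported, smVersion) if "3.2.1-amzn-0".contains(supported) => smVersion
}
// matched == Some("321"); a large immutable Map gives no such ordering guarantee.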

def buildShuffleManagerClassName(smVersion: String): String = {
@@ -322,36 +321,35 @@ object ShuffleManagerResolver {
*
* Example:
* sparkVersion: "3.2.0-amzn-1"
-   * versionMap: {"3.2.0" -> "320", "3.2.1" -> "321"}
-   * Then, smVersion: "320"
+   * supportedVersionsMap: ["3.2.0" -> "320", "3.2.1" -> "321"]
+   * return: Right("com.nvidia.spark.rapids.spark320.RapidsShuffleManager")
*
-   * sparkVersion: "13.3-ml-1"
-   * versionMap: {"11.3" -> "330db", "12.3" -> "332db", "13.3" -> "341db"}
-   * Then, smVersion: "341db"
+   * sparkVersion: "13.3.x-gpu-ml-scala2.12"
+   * supportedVersionsMap: ["11.3" -> "330db", "12.3" -> "332db", "13.3" -> "341db"]
+   * return: Right("com.nvidia.spark.rapids.spark341db.RapidsShuffleManager")
*
* sparkVersion: "3.1.2"
-   * versionMap: {"3.2.0" -> "320", "3.2.1" -> "321"}
-   * Then, smVersion: None
+   * supportedVersionsMap: ["3.2.0" -> "320", "3.2.1" -> "321"]
+   * return: Left("Could not recommend RapidsShuffleManager as the provided version
+   *         3.1.2 is not supported.")
*
* @return Either an error message (Left) or the RapidsShuffleManager class name (Right)
*/
private def getClassNameInternal(
-      versionMap: Map[String, String], sparkVersion: String): Either[String, String] = {
-    val smVersionOpt = versionMap.collectFirst {
-      case (key, value) if sparkVersion.contains(key) => value
-    }
-    smVersionOpt match {
-      case Some(smVersion) =>
-        Right(buildShuffleManagerClassName(smVersion))
-      case None =>
-        Left(commentForUnsupportedVersion(sparkVersion))
+      supportedVersionsMap: Array[(String, String)],
+      sparkVersion: String): Either[String, String] = {
+    supportedVersionsMap.collectFirst {
+      case (supportedVersion, smVersion) if sparkVersion.contains(supportedVersion) => smVersion
+    } match {
+      case Some(smVersion) => Right(buildShuffleManagerClassName(smVersion))
+      case None => Left(commentForUnsupportedVersion(sparkVersion))
}
}
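The body of buildShuffleManagerClassName is collapsed in this view, but the scaladoc examples above imply the class-name pattern sketched below; this is a reconstruction from those examples, not the patch's own code, and the helper name is hypothetical:

// Reconstruction consistent with the examples above, e.g. "320" maps to
// "com.nvidia.spark.rapids.spark320.RapidsShuffleManager".
def buildShuffleManagerClassNameSketch(smVersion: String): String =
  s"com.nvidia.spark.rapids.spark${smVersion}.RapidsShuffleManager"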

/**
-   * Determines the appropriate RapidsShuffleManager class name based on the provided versions.
-   * Databricks version takes precedence over Spark version. If a valid class name is not found,
-   * an error message is returned.
+   * Determines the appropriate RapidsShuffleManager class name based on the provided Databricks or
+   * Spark version. Databricks version takes precedence over Spark version. If a valid class name
+   * is not found, an error message is returned.
*
* @param dbVersion Databricks version.
* @param sparkVersion Spark version.
@@ -360,12 +358,9 @@ object ShuffleManagerResolver {
def getClassName(
dbVersion: Option[String], sparkVersion: Option[String]): Either[String, String] = {
(dbVersion, sparkVersion) match {
-      case (Some(dbVer), _) =>
-        getClassNameInternal(DatabricksVersionMap, dbVer)
-      case (None, Some(sparkVer)) =>
-        getClassNameInternal(SparkVersionMap, sparkVer)
-      case _ =>
-        Left(commentForMissingVersion)
+      case (Some(dbVer), _) => getClassNameInternal(supportedDatabricksVersionMap, dbVer)
+      case (None, Some(sparkVer)) => getClassNameInternal(supportedSparkVersionMap, sparkVer)
+      case _ => Left(commentForMissingVersion)
}
}
}
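For orientation, a usage sketch of the public entry point (values reuse the scaladoc examples; this snippet is not part of the patch):

// Databricks version, when present, takes precedence over the Spark version.
ShuffleManagerResolver.getClassName(Some("13.3.x-gpu-ml-scala2.12"), Some("3.2.0"))
// -> Right("com.nvidia.spark.rapids.spark341db.RapidsShuffleManager")
ShuffleManagerResolver.getClassName(None, Some("3.1.2"))
// -> Left(ShuffleManagerResolver.commentForUnsupportedVersion("3.1.2"))
ShuffleManagerResolver.getClassName(None, None)
// -> Left(ShuffleManagerResolver.commentForMissingVersion)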
@@ -824,10 +819,8 @@ class AutoTuner(
// TODO - do we do anything with 200 shuffle partitions or maybe if its close
// set the Spark config spark.shuffle.sort.bypassMergeThreshold
getShuffleManagerClassName match {
-      case Right(smClassName) =>
-        appendRecommendation("spark.shuffle.manager", smClassName)
-      case Left(errMessage) =>
-        appendComment(errMessage)
+      case Right(smClassName) => appendRecommendation("spark.shuffle.manager", smClassName)
+      case Left(comment) => appendComment(comment)
}
appendComment(autoTunerConfigsProvider.classPathComments("rapids.shuffle.jars"))
recommendFileCache()
@@ -861,8 +854,8 @@ class AutoTuner(
}

/**
-   * Resolves the RapidsShuffleManager class name based on the Spark or Databricks version.
-   * If a valid class name is not found an error message is appended as a comment.
+   * Resolves the RapidsShuffleManager class name based on the Databricks or Spark version.
+   * If a valid class name is not found, an error message is returned.
*/
def getShuffleManagerClassName: Either[String, String] = {
val dbVersion = getPropertyValue(DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY)
@@ -2191,14 +2191,15 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
autoTuner: AutoTuner,
expectedSmVersion: String): Unit = {
autoTuner.getShuffleManagerClassName match {
-      case Right(smVersion) =>
-        assert(smVersion == ShuffleManagerResolver.buildShuffleManagerClassName(expectedSmVersion))
+      case Right(smClassName) =>
+        assert(smClassName ==
+          ShuffleManagerResolver.buildShuffleManagerClassName(expectedSmVersion))
case Left(comment) =>
fail(s"Expected valid RapidsShuffleManager but got comment: $comment")
}
}

test("test shuffle manager version for supported databricks") {
test("test shuffle manager version for supported databricks version") {
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
@@ -2213,7 +2214,7 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330db")
}

test("test shuffle manager version for supported non-databricks") {
test("test shuffle manager version for supported spark version") {
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
@@ -2226,7 +2227,7 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330")
}

test("test shuffle manager version for supported custom version") {
test("test shuffle manager version for supported custom spark version") {
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
@@ -2247,14 +2248,14 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
autoTuner: AutoTuner,
sparkVersion: String): Unit = {
autoTuner.getShuffleManagerClassName match {
-      case Right(smVersion) =>
-        fail(s"Expected error comment but got valid RapidsShuffleManager with version $smVersion")
+      case Right(smClassName) =>
+        fail(s"Expected error comment but got valid RapidsShuffleManager: $smClassName")
case Left(comment) =>
assert(comment == ShuffleManagerResolver.commentForUnsupportedVersion(sparkVersion))
}
}

test("test shuffle manager version for unsupported databricks") {
test("test shuffle manager version for unsupported databricks version") {
val databricksVersion = "9.1.x-gpu-ml-scala2.12"
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
@@ -2269,7 +2270,7 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
verifyUnsupportedSparkVersionForShuffleManager(autoTuner, databricksVersion)
}

test("test shuffle manager version for unsupported non-databricks") {
test("test shuffle manager version for unsupported spark version") {
val sparkVersion = "3.1.2"
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
@@ -2282,7 +2283,7 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
verifyUnsupportedSparkVersionForShuffleManager(autoTuner, sparkVersion)
}

test("test shuffle manager version for unsupported custom version") {
test("test shuffle manager version for unsupported custom spark version") {
val customSparkVersion = "3.1.2-custom"
val databricksWorkerInfo = buildGpuWorkerInfoAsString(None)
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
@@ -2306,8 +2307,8 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory."
infoProvider, PlatformFactory.createInstance())
// Verify that the shuffle manager is not recommended for missing Spark version
autoTuner.getShuffleManagerClassName match {
-      case Right(smVersion) =>
-        fail(s"Expected error comment but got valid RapidsShuffleManager with version $smVersion")
+      case Right(smClassName) =>
+        fail(s"Expected error comment but got valid RapidsShuffleManager: $smClassName")
case Left(comment) =>
assert(comment == ShuffleManagerResolver.commentForMissingVersion)
}
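The new tests above all exercise the resolver through the AutoTuner; as a complement, here is a hedged sketch of a direct unit test against ShuffleManagerResolver (assumes a ScalaTest AnyFunSuite setup and that the resolver is importable in this module; the suite and test names are illustrative only):

import org.scalatest.funsuite.AnyFunSuite

class ShuffleManagerResolverSketchSuite extends AnyFunSuite {

  test("databricks version takes precedence over spark version") {
    // Both versions supplied; the Databricks mapping ("13.3" -> "341db") should win.
    val result = ShuffleManagerResolver.getClassName(
      Some("13.3.x-gpu-ml-scala2.12"), Some("3.2.0"))
    assert(result == Right(ShuffleManagerResolver.buildShuffleManagerClassName("341db")))
  }

  test("unsupported spark version yields an explanatory comment") {
    val result = ShuffleManagerResolver.getClassName(None, Some("3.1.2"))
    assert(result == Left(ShuffleManagerResolver.commentForUnsupportedVersion("3.1.2")))
  }
}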
