Skip to content

Commit

Permalink
Add test and address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Kuhu Shukla <[email protected]>
  • Loading branch information
kuhushukla committed Dec 26, 2023
1 parent e8f3788 commit deea588
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -591,20 +591,8 @@ class AutoTuner(
}

def calculateJobLevelRecommendations(): Unit = {
val shuffleManagerVersion = appInfoProvider.getSparkVersion.get.filterNot("().".toSet)
val finalShuffleVersion = if (platform.contains("databricks")) {
val dbVersion = appInfoProvider.getProperty(
"spark.databricks.clusterUsageTags.sparkVersion").getOrElse("")
if (dbVersion.contains("10.4")) {
"321db"
} else if (dbVersion.contains("11.3")) {
"330db"
} else {
"332db"
}
} else shuffleManagerVersion
appendRecommendation("spark.shuffle.manager",
"com.nvidia.spark.rapids.spark" + finalShuffleVersion + ".RapidsShuffleManager")
val smClassName = getShuffleManagerClassName
appendRecommendation("spark.shuffle.manager", smClassName)
appendComment(classPathComments("rapids.shuffle.jars"))

recommendFileCache()
Expand All @@ -614,6 +602,20 @@ class AutoTuner(
recommendClassPathEntries()
}

def getShuffleManagerClassName() : String = {
val shuffleManagerVersion = appInfoProvider.getSparkVersion.get.filterNot("().".toSet)
val finalShuffleVersion : String = if (platform.contains("databricks")) {
val dbVersion = appInfoProvider.getProperty(
"spark.databricks.clusterUsageTags.sparkVersion").getOrElse("")
dbVersion match {
case ver if ver.contains("10.4") => "321db"
case ver if ver.contains("11.3") => "330db"
case _ => "332db"
}
} else shuffleManagerVersion
"com.nvidia.spark.rapids.spark" + finalShuffleVersion + ".RapidsShuffleManager"
}

/**
* Checks whether the cluster properties are valid.
* If the cluster worker-info is missing entries (i.e., CPU and GPU count), it sets the entries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1507,4 +1507,19 @@ class AutoTunerSuite extends FunSuite with BeforeAndAfterEach with Logging {
// scalastyle:on line.size.limit
assert(expectedResults == autoTunerOutput)
}

test("test shuffle manager version for databricks") {
val customProps = mutable.LinkedHashMap(
"spark.databricks.clusterUsageTags.sparkVersion" -> "11.3.x-gpu-ml-scala2.12")
val databricksWorkerInfo = buildWorkerInfoAsString(Some(customProps))
val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
mutable.Map("spark.rapids.sql.enabled" -> "true",
"spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin",
"spark.databricks.clusterUsageTags.sparkVersion" -> "11.3.x-gpu-ml-scala2.12"),
Some("3.3.0"), Seq())
val autoTuner = AutoTuner.buildAutoTunerFromProps(databricksWorkerInfo,
infoProvider, "databricks")
val smVersion = autoTuner.getShuffleManagerClassName()
// Assert shuffle manager string for DB 11.3 tag
assert(smVersion == "com.nvidia.spark.rapids.spark330db.RapidsShuffleManager")
}

0 comments on commit deea588

Please sign in to comment.