[FEA] Profiler autotuner should only specify standard Spark versions for shuffle manager setting (#662)
kuhushukla authored Dec 28, 2023
1 parent 40b1b9e commit 42a0fa5
Showing 2 changed files with 47 additions and 5 deletions.
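
For orientation before the diff: the change teaches the auto-tuner to map a Databricks runtime tag (spark.databricks.clusterUsageTags.sparkVersion) onto one of the Databricks-specific RapidsShuffleManager builds, and to fall back to the digits of the standard Spark version otherwise. A minimal standalone sketch of that mapping, assuming nothing beyond what the patch itself shows (the object and method names below are illustrative, not part of the commit; the real logic is AutoTuner.getShuffleManagerClassName in the diff):

// Sketch of the version-to-class-name mapping introduced by this commit.
// Names here are illustrative; see getShuffleManagerClassName in the diff below.
object ShuffleManagerNameSketch {
  def classNameFor(sparkVersion: String, dbVersionTag: String): String = {
    val suffix = if (dbVersionTag.nonEmpty) {
      // Databricks runtime tags look like "11.3.x-gpu-ml-scala2.12"
      dbVersionTag match {
        case v if v.contains("10.4") => "321db"
        case v if v.contains("11.3") => "330db"
        case _ => "332db"
      }
    } else {
      // Strip dots (and any stray parentheses) from a plain Spark version, e.g. "3.3.0" -> "330"
      sparkVersion.filterNot("().".toSet)
    }
    "com.nvidia.spark.rapids.spark" + suffix + ".RapidsShuffleManager"
  }

  def main(args: Array[String]): Unit = {
    println(classNameFor("3.3.0", "11.3.x-gpu-ml-scala2.12")) // ...spark330db.RapidsShuffleManager
    println(classNameFor("3.3.0", ""))                        // ...spark330.RapidsShuffleManager
  }
}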
@@ -591,9 +591,8 @@ class AutoTuner(
   }
 
   def calculateJobLevelRecommendations(): Unit = {
-    val shuffleManagerVersion = appInfoProvider.getSparkVersion.get.filterNot("().".toSet)
-    appendRecommendation("spark.shuffle.manager",
-      "com.nvidia.spark.rapids.spark" + shuffleManagerVersion + ".RapidsShuffleManager")
+    val smClassName = getShuffleManagerClassName
+    appendRecommendation("spark.shuffle.manager", smClassName)
     appendComment(classPathComments("rapids.shuffle.jars"))
 
     recommendFileCache()
@@ -603,6 +602,22 @@ class AutoTuner(
     recommendClassPathEntries()
   }
 
+  def getShuffleManagerClassName() : String = {
+    val shuffleManagerVersion = appInfoProvider.getSparkVersion.get.filterNot("().".toSet)
+    val dbVersion = appInfoProvider.getProperty(
+      "spark.databricks.clusterUsageTags.sparkVersion").getOrElse("")
+    val finalShuffleVersion : String = if (dbVersion.nonEmpty) {
+      dbVersion match {
+        case ver if ver.contains("10.4") => "321db"
+        case ver if ver.contains("11.3") => "330db"
+        case _ => "332db"
+      }
+    } else {
+      shuffleManagerVersion
+    }
+    "com.nvidia.spark.rapids.spark" + finalShuffleVersion + ".RapidsShuffleManager"
+  }
+
   /**
    * Checks whether the cluster properties are valid.
    * If the cluster worker-info is missing entries (i.e., CPU and GPU count), it sets the entries
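
With the call site above now doing appendRecommendation("spark.shuffle.manager", smClassName) alongside the rapids.shuffle.jars classpath comment, the property recommended for a Databricks 11.3 runtime carries the db-specific class name. A hedged sketch of that end result; the "--conf key=value" rendering is an assumption about the tuner's report format, not something shown in this diff:

// Illustration only: the recommendation produced for a DBR 11.3 cluster after this change.
// The "--conf" line format is assumed, not taken from this commit.
val recommendation = ("spark.shuffle.manager",
  "com.nvidia.spark.rapids.spark330db.RapidsShuffleManager")
println(s"--conf ${recommendation._1}=${recommendation._2}")
// prints: --conf spark.shuffle.manager=com.nvidia.spark.rapids.spark330db.RapidsShuffleManager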
@@ -88,7 +88,7 @@ class AutoTunerSuite extends FunSuite with BeforeAndAfterEach with Logging {
     val systemProperties = customProps match {
       case None => mutable.Map[String, String]()
       case Some(newProps) => newProps
-    }
+    }
     val convertedMap = new util.LinkedHashMap[String, String](systemProperties.asJava)
     val clusterProps = new ClusterProperties(cpuSystem, gpuWorkerProps, convertedMap)
     // set the options to convert the object into formatted yaml content
@@ -654,7 +654,7 @@ class AutoTunerSuite extends FunSuite with BeforeAndAfterEach with Logging {
     assert(expectedResults == autoTunerOutput)
   }
 
-  test("test AutoTuner with empty sparkProperties" ) {
+  test("test AutoTuner with empty sparkProperties") {
     val dataprocWorkerInfo = buildWorkerInfoAsString(None)
     val expectedResults =
       s"""|
@@ -1507,4 +1507,31 @@ class AutoTunerSuite extends FunSuite with BeforeAndAfterEach with Logging {
     // scalastyle:on line.size.limit
     assert(expectedResults == autoTunerOutput)
   }
+
+  test("test shuffle manager version for databricks") {
+    val databricksWorkerInfo = buildWorkerInfoAsString(None)
+    val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
+      mutable.Map("spark.rapids.sql.enabled" -> "true",
+        "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin",
+        "spark.databricks.clusterUsageTags.sparkVersion" -> "11.3.x-gpu-ml-scala2.12"),
+      Some("3.3.0"), Seq())
+    // Do not set the platform as DB to see if it can work correctly irrespective
+    val autoTuner = AutoTuner.buildAutoTunerFromProps(databricksWorkerInfo,
+      infoProvider, PlatformFactory.createInstance())
+    val smVersion = autoTuner.getShuffleManagerClassName()
+    // Assert shuffle manager string for DB 11.3 tag
+    assert(smVersion == "com.nvidia.spark.rapids.spark330db.RapidsShuffleManager")
+  }
+
+  test("test shuffle manager version for non-databricks") {
+    val databricksWorkerInfo = buildWorkerInfoAsString(None)
+    val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0),
+      mutable.Map("spark.rapids.sql.enabled" -> "true",
+        "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"),
+      Some("3.3.0"), Seq())
+    val autoTuner = AutoTuner.buildAutoTunerFromProps(databricksWorkerInfo,
+      infoProvider, PlatformFactory.createInstance())
+    val smVersion = autoTuner.getShuffleManagerClassName()
+    assert(smVersion == "com.nvidia.spark.rapids.spark330.RapidsShuffleManager")
+  }
 }
