diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala index 866e1fbbd..24cfbfeba 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2023-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -121,6 +121,7 @@ abstract class Platform(var gpuDevice: Option[GpuDevice], val clusterProperties: Option[ClusterProperties]) extends Logging { val platformName: String val defaultGpuDevice: GpuDevice + val sparkVersionLabel: String = "Spark version" // It's not deal to use vars here but to minimize changes and // keep backwards compatibility we put them here for now and hopefully @@ -139,6 +140,46 @@ abstract class Platform(var gpuDevice: Option[GpuDevice], SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS ) + // scalastyle:off line.size.limit + // Supported Spark version to RapidsShuffleManager version mapping. + // Reference: https://docs.nvidia.com/spark-rapids/user-guide/latest/additional-functionality/rapids-shuffle.html#rapids-shuffle-manager + // scalastyle:on line.size.limit + val supportedShuffleManagerVersionMap: Array[(String, String)] = Array( + "3.2.0" -> "320", + "3.2.1" -> "321", + "3.2.2" -> "322", + "3.2.3" -> "323", + "3.2.4" -> "324", + "3.3.0" -> "330", + "3.3.1" -> "331", + "3.3.2" -> "332", + "3.3.3" -> "333", + "3.3.4" -> "334", + "3.4.0" -> "340", + "3.4.1" -> "341", + "3.4.2" -> "342", + "3.4.3" -> "343", + "3.5.0" -> "350", + "3.5.1" -> "351" + ) + + /** + * Determine the appropriate RapidsShuffleManager version based on the + * provided spark version. + */ + def getShuffleManagerVersion(sparkVersion: String): Option[String] = { + supportedShuffleManagerVersionMap.collectFirst { + case (supportedVersion, smVersion) if sparkVersion.contains(supportedVersion) => smVersion + } + } + + /** + * Identify the latest supported Spark and RapidsShuffleManager version for the platform. + */ + lazy val latestSupportedShuffleManagerInfo: (String, String) = { + supportedShuffleManagerVersionMap.maxBy(_._1) + } + /** * Checks if the given runtime is supported by the platform. */ @@ -522,6 +563,7 @@ abstract class Platform(var gpuDevice: Option[GpuDevice], abstract class DatabricksPlatform(gpuDevice: Option[GpuDevice], clusterProperties: Option[ClusterProperties]) extends Platform(gpuDevice, clusterProperties) { override val defaultGpuDevice: GpuDevice = T4Gpu + override val sparkVersionLabel: String = "Databricks runtime" override def isPlatformCSP: Boolean = true override val supportedRuntimes: Set[SparkRuntime.SparkRuntime] = Set( @@ -538,6 +580,13 @@ abstract class DatabricksPlatform(gpuDevice: Option[GpuDevice], "spark.executor.memoryOverhead" ) + // Supported Databricks version to RapidsShuffleManager version mapping. + override val supportedShuffleManagerVersionMap: Array[(String, String)] = Array( + "11.3" -> "330db", + "12.2" -> "332db", + "13.3" -> "341db" + ) + override def createClusterInfo(coresPerExecutor: Int, numExecsPerNode: Int, numExecs: Int, diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/tuning/AutoTuner.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/tuning/AutoTuner.scala index 7471491e4..048e71ada 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/tuning/AutoTuner.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/tuning/AutoTuner.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2024-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,6 @@ import scala.util.control.NonFatal import scala.util.matching.Regex import com.nvidia.spark.rapids.tool.{AppSummaryInfoBaseProvider, GpuDevice, Platform, PlatformFactory} -import com.nvidia.spark.rapids.tool.planparser.DatabricksParseHelper import com.nvidia.spark.rapids.tool.profiling._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} @@ -717,9 +716,9 @@ class AutoTuner( def calculateJobLevelRecommendations(): Unit = { // TODO - do we do anything with 200 shuffle partitions or maybe if its close // set the Spark config spark.shuffle.sort.bypassMergeThreshold - getShuffleManagerClassName match { - case Some(smClassName) => appendRecommendation("spark.shuffle.manager", smClassName) - case None => appendComment("Could not define the Spark Version") + getShuffleManagerClassName match { + case Right(smClassName) => appendRecommendation("spark.shuffle.manager", smClassName) + case Left(comment) => appendComment(comment) } appendComment(autoTunerConfigsProvider.classPathComments("rapids.shuffle.jars")) recommendFileCache() @@ -752,30 +751,31 @@ class AutoTuner( } } - def getShuffleManagerClassName() : Option[String] = { - appInfoProvider.getSparkVersion.map { sparkVersion => - val shuffleManagerVersion = sparkVersion.filterNot("().".toSet) - val dbVersion = getPropertyValue( - DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY).getOrElse("") - val finalShuffleVersion : String = if (dbVersion.nonEmpty) { - dbVersion match { - case ver if ver.contains("10.4") => "321db" - case ver if ver.contains("11.3") => "330db" - case _ => "332db" - } - } else if (sparkVersion.contains("amzn")) { - sparkVersion match { - case ver if ver.contains("3.5.2") => "352" - case ver if ver.contains("3.5.1") => "351" - case ver if ver.contains("3.5.0") => "350" - case ver if ver.contains("3.4.1") => "341" - case ver if ver.contains("3.4.0") => "340" - case _ => "332" + /** + * Resolves the RapidsShuffleManager class name based on the Spark version. + * If a valid class name is not found, an error message is returned. + * + * Example: + * sparkVersion: "3.2.0-amzn-1" + * return: Right("com.nvidia.spark.rapids.spark320.RapidsShuffleManager") + * + * sparkVersion: "3.1.2" + * return: Left("Cannot recommend RAPIDS Shuffle Manager for unsupported '3.1.2' version.") + * + * @return Either an error message (Left) or the RapidsShuffleManager class name (Right) + */ + def getShuffleManagerClassName: Either[String, String] = { + appInfoProvider.getSparkVersion match { + case Some(sparkVersion) => + platform.getShuffleManagerVersion(sparkVersion) match { + case Some(smVersion) => + Right(autoTunerConfigsProvider.buildShuffleManagerClassName(smVersion)) + case None => + Left(autoTunerConfigsProvider.shuffleManagerCommentForUnsupportedVersion( + sparkVersion, platform)) } - } else { - shuffleManagerVersion - } - "com.nvidia.spark.rapids.spark" + finalShuffleVersion + ".RapidsShuffleManager" + case None => + Left(autoTunerConfigsProvider.shuffleManagerCommentForMissingVersion) } } @@ -1344,6 +1344,9 @@ trait AutoTunerConfigsProvider extends Logging { // the plugin jar is in the form of rapids-4-spark_scala_binary-(version)-*.jar val pluginJarRegEx: Regex = "rapids-4-spark_\\d\\.\\d+-(\\d{2}\\.\\d{2}\\.\\d+).*\\.jar".r + private val shuffleManagerDocUrl = "https://docs.nvidia.com/spark-rapids/user-guide/latest/" + + "additional-functionality/rapids-shuffle.html#rapids-shuffle-manager" + /** * Abstract method to create an instance of the AutoTuner. */ @@ -1460,6 +1463,27 @@ trait AutoTunerConfigsProvider extends Logging { case _ => true } } + + def buildShuffleManagerClassName(smVersion: String): String = { + s"com.nvidia.spark.rapids.spark$smVersion.RapidsShuffleManager" + } + + def shuffleManagerCommentForUnsupportedVersion( + sparkVersion: String, platform: Platform): String = { + val (latestSparkVersion, latestSmVersion) = platform.latestSupportedShuffleManagerInfo + // scalastyle:off line.size.limit + s""" + |Cannot recommend RAPIDS Shuffle Manager for unsupported ${platform.sparkVersionLabel}: '$sparkVersion'. + |To enable RAPIDS Shuffle Manager, use a supported ${platform.sparkVersionLabel} (e.g., '$latestSparkVersion') + |and set: '--conf spark.shuffle.manager=com.nvidia.spark.rapids.spark$latestSmVersion.RapidsShuffleManager'. + |See supported versions: $shuffleManagerDocUrl. + |""".stripMargin.trim.replaceAll("\n", "\n ") + // scalastyle:on line.size.limit + } + + def shuffleManagerCommentForMissingVersion: String = { + "Could not recommend RapidsShuffleManager as Spark version cannot be determined." + } } /** diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/tuning/ProfilingAutoTunerSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/tuning/ProfilingAutoTunerSuite.scala index bf3f6c61b..8b1007a9c 100644 --- a/core/src/test/scala/com/nvidia/spark/rapids/tool/tuning/ProfilingAutoTunerSuite.scala +++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/tuning/ProfilingAutoTunerSuite.scala @@ -2183,33 +2183,139 @@ We recommend using nodes/workers with more memory. Need at least 7796MB memory." assert(expectedResults == autoTunerOutput) } - test("test shuffle manager version for databricks") { + /** + * Helper method to verify that the recommended shuffle manager version matches the + * expected version. + */ + private def verifyRecommendedShuffleManagerVersion( + autoTuner: AutoTuner, + expectedSmVersion: String): Unit = { + autoTuner.getShuffleManagerClassName match { + case Right(smClassName) => + assert(smClassName == ProfilingAutoTunerConfigsProvider + .buildShuffleManagerClassName(expectedSmVersion)) + case Left(comment) => + fail(s"Expected valid RapidsShuffleManager but got comment: $comment") + } + } + + test("test shuffle manager version for supported databricks version") { + val databricksVersion = "11.3.x-gpu-ml-scala2.12" val databricksWorkerInfo = buildGpuWorkerInfoAsString(None) val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0), mutable.Map("spark.rapids.sql.enabled" -> "true", "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin", - DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY -> "11.3.x-gpu-ml-scala2.12"), - Some("3.3.0"), Seq()) + DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY -> databricksVersion), + Some(databricksVersion), Seq()) // Do not set the platform as DB to see if it can work correctly irrespective val autoTuner = ProfilingAutoTunerConfigsProvider .buildAutoTunerFromProps(databricksWorkerInfo, - infoProvider, PlatformFactory.createInstance()) - val smVersion = autoTuner.getShuffleManagerClassName() + infoProvider, PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS)) // Assert shuffle manager string for DB 11.3 tag - assert(smVersion.get == "com.nvidia.spark.rapids.spark330db.RapidsShuffleManager") + verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330db") } - test("test shuffle manager version for non-databricks") { - val databricksWorkerInfo = buildGpuWorkerInfoAsString(None) + test("test shuffle manager version for supported spark version") { + val sparkVersion = "3.3.0" + val workerInfo = buildGpuWorkerInfoAsString(None) val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0), mutable.Map("spark.rapids.sql.enabled" -> "true", "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"), - Some("3.3.0"), Seq()) + Some(sparkVersion), Seq()) + val autoTuner = ProfilingAutoTunerConfigsProvider + .buildAutoTunerFromProps(workerInfo, + infoProvider, PlatformFactory.createInstance()) + // Assert shuffle manager string for supported Spark v3.3.0 + verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330") + } + + test("test shuffle manager version for supported custom spark version") { + val customSparkVersion = "3.3.0-custom" + val workerInfo = buildGpuWorkerInfoAsString(None) + val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0), + mutable.Map("spark.rapids.sql.enabled" -> "true", + "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"), + Some(customSparkVersion), Seq()) + val autoTuner = ProfilingAutoTunerConfigsProvider + .buildAutoTunerFromProps(workerInfo, + infoProvider, PlatformFactory.createInstance()) + // Assert shuffle manager string for supported custom Spark v3.3.0 + verifyRecommendedShuffleManagerVersion(autoTuner, expectedSmVersion="330") + } + + /** + * Helper method to verify that the shuffle manager version is not recommended + * for the unsupported Spark version. + */ + private def verifyUnsupportedSparkVersionForShuffleManager( + autoTuner: AutoTuner, + sparkVersion: String): Unit = { + autoTuner.getShuffleManagerClassName match { + case Right(smClassName) => + fail(s"Expected error comment but got valid RapidsShuffleManager: $smClassName") + case Left(comment) => + assert(comment == ProfilingAutoTunerConfigsProvider + .shuffleManagerCommentForUnsupportedVersion(sparkVersion, autoTuner.platform)) + } + } + + test("test shuffle manager version for unsupported databricks version") { + val databricksVersion = "9.1.x-gpu-ml-scala2.12" + val databricksWorkerInfo = buildGpuWorkerInfoAsString(None) + val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0), + mutable.Map("spark.rapids.sql.enabled" -> "true", + "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin", + DatabricksParseHelper.PROP_TAG_CLUSTER_SPARK_VERSION_KEY -> databricksVersion), + Some(databricksVersion), Seq()) + // Do not set the platform as DB to see if it can work correctly irrespective val autoTuner = ProfilingAutoTunerConfigsProvider .buildAutoTunerFromProps(databricksWorkerInfo, + infoProvider, PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS)) + verifyUnsupportedSparkVersionForShuffleManager(autoTuner, databricksVersion) + } + + test("test shuffle manager version for unsupported spark version") { + val sparkVersion = "3.1.2" + val workerInfo = buildGpuWorkerInfoAsString(None) + val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0), + mutable.Map("spark.rapids.sql.enabled" -> "true", + "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"), + Some(sparkVersion), Seq()) + val autoTuner = ProfilingAutoTunerConfigsProvider + .buildAutoTunerFromProps(workerInfo, infoProvider, PlatformFactory.createInstance()) - val smVersion = autoTuner.getShuffleManagerClassName() - assert(smVersion.get == "com.nvidia.spark.rapids.spark330.RapidsShuffleManager") + verifyUnsupportedSparkVersionForShuffleManager(autoTuner, sparkVersion) + } + + test("test shuffle manager version for unsupported custom spark version") { + val customSparkVersion = "3.1.2-custom" + val workerInfo = buildGpuWorkerInfoAsString(None) + val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0), + mutable.Map("spark.rapids.sql.enabled" -> "true", + "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"), + Some(customSparkVersion), Seq()) + val autoTuner = ProfilingAutoTunerConfigsProvider + .buildAutoTunerFromProps(workerInfo, + infoProvider, PlatformFactory.createInstance()) + verifyUnsupportedSparkVersionForShuffleManager(autoTuner, customSparkVersion) + } + + test("test shuffle manager version for missing spark version") { + val workerInfo = buildGpuWorkerInfoAsString(None) + val infoProvider = getMockInfoProvider(0, Seq(0), Seq(0.0), + mutable.Map("spark.rapids.sql.enabled" -> "true", + "spark.plugins" -> "com.nvidia.spark.AnotherPlugin, com.nvidia.spark.SQLPlugin"), + None, Seq()) + val autoTuner = ProfilingAutoTunerConfigsProvider + .buildAutoTunerFromProps(workerInfo, + infoProvider, PlatformFactory.createInstance()) + // Verify that the shuffle manager is not recommended for missing Spark version + autoTuner.getShuffleManagerClassName match { + case Right(smClassName) => + fail(s"Expected error comment but got valid RapidsShuffleManager: $smClassName") + case Left(comment) => + assert(comment == ProfilingAutoTunerConfigsProvider.shuffleManagerCommentForMissingVersion) + } } test("Test spilling occurred in shuffle stages") {