[SPARK-47764][CORE][SQL] Cleanup shuffle dependencies based on ShuffleCleanupMode

### What changes were proposed in this pull request?
This change adds a new trait, `ShuffleCleanupMode`, under `QueryExecution`, and two new configs, `spark.sql.shuffleDependency.skipMigration.enabled` and `spark.sql.shuffleDependency.fileCleanup.enabled`.

For Spark Connect query executions, `ShuffleCleanupMode` is controlled by the two new configs, and shuffle dependency cleanup is performed accordingly.

When `spark.sql.shuffleDependency.fileCleanup.enabled` is `true`, shuffle dependency files will be cleaned up at the end of each query execution.

When `spark.sql.shuffleDependency.skipMigration.enabled` is `true`, shuffle dependencies will be skipped during shuffle data migration when nodes are decommissioned.
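
To make the new knobs concrete, here is a minimal, hypothetical sketch (not part of this PR) of turning them on for a session. Only the config keys come from this change; the session setup and the sample query are assumptions for illustration, and per this change the cleanup mode is derived from these configs only for Spark Connect query executions.

```scala
// Hedged sketch: enabling the new cleanup configs added by this PR.
// The session setup and query below are illustrative assumptions only.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("shuffle-cleanup-sketch")
  .getOrCreate()

// Remove shuffle dependency files once a query execution finishes.
spark.conf.set("spark.sql.shuffleDependency.fileCleanup.enabled", "true")
// Or, less aggressively, only exclude these shuffles from decommission migration.
spark.conf.set("spark.sql.shuffleDependency.skipMigration.enabled", "true")

// Any shuffle-producing query; its shuffle dependencies become eligible for cleanup.
spark.range(100).repartition(10).collect()
```

When both configs are enabled, file cleanup takes precedence over skipping migration, matching the mode selection added to `SparkConnectPlanExecution` below.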

### Why are the changes needed?
This is to: 1. speed up shuffle data migration during node decommissioning, and 2. when file cleanup mode is enabled, release disk space occupied by unused shuffle files sooner.

### Does this PR introduce _any_ user-facing change?
Yes. This change adds two new configs, `spark.sql.shuffleDependency.skipMigration.enabled` and `spark.sql.shuffleDependency.fileCleanup.enabled`, to control the cleanup behavior.

### How was this patch tested?
Existing tests, plus the new `QueryExecutionSuite` tests added in this PR for the three cleanup modes.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#45930 from bozhang2820/spark-47764.

Authored-by: Bo Zhang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
bozhang2820 authored and cloud-fan committed Apr 24, 2024
1 parent 461ffa1 commit c44493d
Showing 13 changed files with 169 additions and 11 deletions.
@@ -35,8 +35,9 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.service.ExecuteHolder
import org.apache.spark.sql.connect.utils.MetricGenerator
import org.apache.spark.sql.execution.{LocalTableScanExec, SQLExecution}
import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, RemoveShuffleFiles, SkipMigration, SQLExecution}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils

@@ -58,11 +59,21 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
}
val planner = new SparkConnectPlanner(executeHolder)
val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
val conf = session.sessionState.conf
val shuffleCleanupMode =
if (conf.getConf(SQLConf.SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED)) {
RemoveShuffleFiles
} else if (conf.getConf(SQLConf.SHUFFLE_DEPENDENCY_SKIP_MIGRATION_ENABLED)) {
SkipMigration
} else {
DoNotCleanup
}
val dataframe =
Dataset.ofRows(
sessionHolder.session,
planner.transformRelation(request.getPlan.getRoot),
tracker)
tracker,
shuffleCleanupMode)
responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
processAsArrowBatches(dataframe, responseObserver, executeHolder)
responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
@@ -24,6 +24,8 @@ import java.nio.file.Files

import scala.collection.mutable.ArrayBuffer

import com.google.common.cache.CacheBuilder

import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException}
import org.apache.spark.errors.SparkCoreErrors
import org.apache.spark.internal.{config, Logging, MDC}
@@ -76,13 +78,21 @@ private[spark] class IndexShuffleBlockResolver(
override def getStoredShuffles(): Seq[ShuffleBlockInfo] = {
val allBlocks = blockManager.diskBlockManager.getAllBlocks()
allBlocks.flatMap {
case ShuffleIndexBlockId(shuffleId, mapId, _) =>
case ShuffleIndexBlockId(shuffleId, mapId, _)
if Option(shuffleIdsToSkip.getIfPresent(shuffleId)).isEmpty =>
Some(ShuffleBlockInfo(shuffleId, mapId))
case _ =>
None
}
}

private val shuffleIdsToSkip =
CacheBuilder.newBuilder().maximumSize(1000).build[java.lang.Integer, java.lang.Boolean]()

override def addShuffleToSkip(shuffleId: ShuffleId): Unit = {
shuffleIdsToSkip.put(shuffleId, true)
}

private def getShuffleBytesStored(): Long = {
val shuffleFiles: Seq[File] = getStoredShuffles().map {
si => getDataFile(si.shuffleId, si.mapId)
@@ -35,6 +35,11 @@ trait MigratableResolver {
*/
def getStoredShuffles(): Seq[ShuffleBlockInfo]

/**
* Mark a shuffle that should not be migrated.
*/
def addShuffleToSkip(shuffleId: Int): Unit = {}

/**
* Write a provided shuffle block as a stream. Used for block migrations.
* Up to the implementation to support STORAGE_REMOTE_SHUFFLE_MAX_DISK
@@ -305,7 +305,7 @@ private[spark] class BlockManager(

// This is a lazy val so someone can migrate RDDs even if they don't have a MigratableResolver
// for shuffles. Used in BlockManagerDecommissioner & block puts.
private[storage] lazy val migratableResolver: MigratableResolver = {
lazy val migratableResolver: MigratableResolver = {
shuffleManager.shuffleBlockResolver.asInstanceOf[MigratableResolver]
}

7 changes: 5 additions & 2 deletions project/MimaExcludes.scala
@@ -15,7 +15,8 @@
* limitations under the License.
*/

import com.typesafe.tools.mima.core._
import com.typesafe.tools.mima.core
import com.typesafe.tools.mima.core.*

/**
* Additional excludes for checking of Spark's binary compatibility.
@@ -93,7 +94,9 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.TestWritable"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.TestWritable$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.WriteInputFormatTestDataGenerator"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.WriteInputFormatTestDataGenerator$")
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.WriteInputFormatTestDataGenerator$"),
// SPARK-47764: Cleanup shuffle dependencies based on ShuffleCleanupMode
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.MigratableResolver.addShuffleToSkip")
)

// Default exclude rules
@@ -2874,6 +2874,22 @@ object SQLConf {
.intConf
.createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)

val SHUFFLE_DEPENDENCY_SKIP_MIGRATION_ENABLED =
buildConf("spark.sql.shuffleDependency.skipMigration.enabled")
.doc("When enabled, shuffle dependencies for a Spark Connect SQL execution are marked at " +
"the end of the execution, and they will not be migrated during decommissions.")
.version("4.0.0")
.booleanConf
.createWithDefault(Utils.isTesting)

val SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED =
buildConf("spark.sql.shuffleDependency.fileCleanup.enabled")
.doc("When enabled, shuffle files will be cleaned up at the end of Spark Connect " +
"SQL executions.")
.version("4.0.0")
.booleanConf
.createWithDefault(Utils.isTesting)

val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold")
.internal()
20 changes: 18 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -95,10 +95,26 @@ private[sql] object Dataset {
new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
}

def ofRows(
sparkSession: SparkSession,
logicalPlan: LogicalPlan,
shuffleCleanupMode: ShuffleCleanupMode): DataFrame =
sparkSession.withActive {
val qe = new QueryExecution(
sparkSession, logicalPlan, shuffleCleanupMode = shuffleCleanupMode)
qe.assertAnalyzed()
new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
}

/** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */
def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker)
def ofRows(
sparkSession: SparkSession,
logicalPlan: LogicalPlan,
tracker: QueryPlanningTracker,
shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup)
: DataFrame = sparkSession.withActive {
val qe = new QueryExecution(sparkSession, logicalPlan, tracker)
val qe = new QueryExecution(
sparkSession, logicalPlan, tracker, shuffleCleanupMode = shuffleCleanupMode)
qe.assertAnalyzed()
new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
}
@@ -60,7 +60,8 @@ class QueryExecution(
val sparkSession: SparkSession,
val logical: LogicalPlan,
val tracker: QueryPlanningTracker = new QueryPlanningTracker,
val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL) extends Logging {
val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL,
val shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup) extends Logging {

val id: Long = QueryExecution.nextExecutionId

@@ -459,6 +460,22 @@ object CommandExecutionMode extends Enumeration {
val SKIP, NON_ROOT, ALL = Value
}

/**
* Modes for shuffle dependency cleanup.
*
* DoNotCleanup: Do not perform any cleanup.
* SkipMigration: Shuffle dependencies will not be migrated at node decommissions.
* RemoveShuffleFiles: Shuffle dependency files are removed at the end of SQL executions.
*/
sealed trait ShuffleCleanupMode

case object DoNotCleanup extends ShuffleCleanupMode

case object SkipMigration extends ShuffleCleanupMode

case object RemoveShuffleFiles extends ShuffleCleanupMode


object QueryExecution {
private val _nextExecutionId = new AtomicLong(0)

@@ -20,14 +20,16 @@ package org.apache.spark.sql.execution
import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
import java.util.concurrent.atomic.AtomicLong

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, SparkException, SparkThrowable, SparkThrowableHelper}
import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, SparkEnv, SparkException, SparkThrowable, SparkThrowableHelper}
import org.apache.spark.SparkContext.{SPARK_JOB_DESCRIPTION, SPARK_JOB_INTERRUPT_ON_CANCEL}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{SPARK_DRIVER_PREFIX, SPARK_EXECUTOR_PREFIX}
import org.apache.spark.internal.config.Tests.IS_TESTING
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.SQL_EVENT_TRUNCATE_LENGTH
@@ -115,6 +117,7 @@ object SQLExecution extends Logging {

withSQLConfPropagated(sparkSession) {
var ex: Option[Throwable] = None
var isExecutedPlanAvailable = false
val startTime = System.nanoTime()
val startEvent = SparkListenerSQLExecutionStart(
executionId = executionId,
@@ -147,6 +150,7 @@
}
sc.listenerBus.post(
startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo))
isExecutedPlanAvailable = true
f()
}
} catch {
@@ -161,6 +165,24 @@
case e =>
Utils.exceptionString(e)
}
if (queryExecution.shuffleCleanupMode != DoNotCleanup
&& isExecutedPlanAvailable) {
val shuffleIds = queryExecution.executedPlan match {
case ae: AdaptiveSparkPlanExec =>
ae.context.shuffleIds.asScala.keys
case _ =>
Iterable.empty
}
shuffleIds.foreach { shuffleId =>
queryExecution.shuffleCleanupMode match {
case RemoveShuffleFiles =>
SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
case SkipMigration =>
SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)
case _ => // this should not happen
}
}
}
val event = SparkListenerSQLExecutionEnd(
executionId,
System.currentTimeMillis(),
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.adaptive

import java.util
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

import scala.collection.concurrent.TrieMap
import scala.collection.mutable
@@ -302,6 +302,11 @@ case class AdaptiveSparkPlanExec(
try {
stage.materialize().onComplete { res =>
if (res.isSuccess) {
// record shuffle IDs for successful stages for cleanup
stage.plan.collect {
case s: ShuffleExchangeLike =>
context.shuffleIds.put(s.shuffleId, true)
}
events.offer(StageSuccess(stage, res.get))
} else {
events.offer(StageFailure(stage, res.failed.get))
@@ -869,6 +874,8 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
*/
val stageCache: TrieMap[SparkPlan, ExchangeQueryStageExec] =
new TrieMap[SparkPlan, ExchangeQueryStageExec]()

val shuffleIds: ConcurrentHashMap[Int, Boolean] = new ConcurrentHashMap[Int, Boolean]()
}

/**
@@ -86,6 +86,11 @@ trait ShuffleExchangeLike extends Exchange {
* Returns the runtime statistics after shuffle materialization.
*/
def runtimeStatistics: Statistics

/**
* The shuffle ID.
*/
def shuffleId: Int
}

// Describes where the shuffle operator comes from.
Expand Down Expand Up @@ -166,6 +171,8 @@ case class ShuffleExchangeExec(
Statistics(dataSize, Some(rowCount))
}

override def shuffleId: Int = shuffleDependency.shuffleId

/**
* A [[ShuffleDependency]] that will partition rows of its child based on
* the partitioning scheme defined in `newPartitioning`. Those partitions of
@@ -1014,6 +1014,7 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE
val attributeStats = AttributeMap(Seq((child.output.head, columnStats)))
Statistics(stats.sizeInBytes, stats.rowCount, attributeStats)
}
override def shuffleId: Int = delegate.shuffleId
override def child: SparkPlan = delegate.child
override protected def doExecute(): RDD[InternalRow] = delegate.execute()
override def outputPartitioning: Partitioning = delegate.outputPartitioning
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.datasources.v2.ShowTablesExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.storage.ShuffleIndexBlockId
import org.apache.spark.util.Utils

case class QueryExecutionTestRecord(
@@ -314,6 +315,48 @@ class QueryExecutionSuite extends SharedSparkSession {
mockCallback.assertExecutedPlanPrepared()
}

private def cleanupShuffles(): Unit = {
val blockManager = spark.sparkContext.env.blockManager
blockManager.diskBlockManager.getAllBlocks().foreach {
case ShuffleIndexBlockId(shuffleId, _, _) =>
spark.sparkContext.env.shuffleManager.unregisterShuffle(shuffleId)
case _ =>
}
}

test("SPARK-47764: Cleanup shuffle dependencies - DoNotCleanup mode") {
val plan = spark.range(100).repartition(10).logicalPlan
val df = Dataset.ofRows(spark, plan, DoNotCleanup)
df.collect()

val blockManager = spark.sparkContext.env.blockManager
assert(blockManager.migratableResolver.getStoredShuffles().nonEmpty)
assert(blockManager.diskBlockManager.getAllBlocks().nonEmpty)
cleanupShuffles()
}

test("SPARK-47764: Cleanup shuffle dependencies - SkipMigration mode") {
val plan = spark.range(100).repartition(10).logicalPlan
val df = Dataset.ofRows(spark, plan, SkipMigration)
df.collect()

val blockManager = spark.sparkContext.env.blockManager
assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
assert(blockManager.diskBlockManager.getAllBlocks().nonEmpty)
cleanupShuffles()
}

test("SPARK-47764: Cleanup shuffle dependencies - RemoveShuffleFiles mode") {
val plan = spark.range(100).repartition(10).logicalPlan
val df = Dataset.ofRows(spark, plan, RemoveShuffleFiles)
df.collect()

val blockManager = spark.sparkContext.env.blockManager
assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
assert(blockManager.diskBlockManager.getAllBlocks().isEmpty)
cleanupShuffles()
}

test("SPARK-35378: Return UnsafeRow in CommandResultExecCheck execute methods") {
val plan = spark.sql("SHOW FUNCTIONS").queryExecution.executedPlan
assert(plan.isInstanceOf[CommandResultExec])