diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 1891dd3befa40..a30234d17c4ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -1424,12 +1424,6 @@ object StateStore extends Logging { } } - // Pause maintenance for testing purposes only. - @volatile private var maintenancePaused: Boolean = false - private[spark] def setMaintenancePaused(maintPaused: Boolean): Unit = { - maintenancePaused = maintPaused - } - /** * Execute background maintenance task in all the loaded store providers if they are still * the active instances according to the coordinator. @@ -1439,10 +1433,6 @@ object StateStore extends Logging { if (SparkEnv.get == null) { throw new IllegalStateException("SparkEnv not active, cannot do maintenance on StateStores") } - if (maintenancePaused) { - logDebug("Maintenance paused") - return - } // Providers that couldn't be processed now and need to be added back to the queue val providersToRequeue = new ArrayBuffer[(StateStoreProviderId, StateStoreProvider)]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 6bb64315e356b..d2f621d1c961a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -253,6 +253,28 @@ private object FakeStateStoreProviderWithMaintenanceError { val errorOnMaintenance = new AtomicBoolean(false) } +/** + * A fake StateStoreProvider for testing maintenance before unload. + * Extends HDFSBackedStateStoreProvider to get actual store functionality, + * but tracks the number of times doMaintenance is called. + */ +class MaintenanceCountingStateStoreProvider extends HDFSBackedStateStoreProvider { + import MaintenanceCountingStateStoreProvider._ + + override def doMaintenance(): Unit = { + maintenanceCallCount.incrementAndGet() + super.doMaintenance() + } +} + +private object MaintenanceCountingStateStoreProvider { + val maintenanceCallCount = new java.util.concurrent.atomic.AtomicInteger(0) + + def reset(): Unit = { + maintenanceCallCount.set(0) + } +} + @ExtendedSQLTest class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] with BeforeAndAfter { @@ -1010,32 +1032,25 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] // Ensure that maintenance is called before unloading test("SPARK-40492: maintenance before unload") { + // Reset the maintenance call counter + MaintenanceCountingStateStoreProvider.reset() + val conf = new SparkConf() .setMaster("local") .setAppName("SPARK-40492") val opId = 0 val dir1 = newDir() val storeProviderId1 = StateStoreProviderId(StateStoreId(dir1, opId, 0), UUID.randomUUID) - val sqlConf = getDefaultSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, - SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get) - sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 10) - sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L) + val sqlConf = getDefaultSQLConf( + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get + ) + sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 5.seconds.toMillis) + sqlConf.setConf(SQLConf.STATE_STORE_PROVIDER_CLASS, + classOf[MaintenanceCountingStateStoreProvider].getName) val storeConf = StateStoreConf(sqlConf) val hadoopConf = new Configuration() - var latestStoreVersion = 0 - - def generateStoreVersions(): Unit = { - for (i <- 1 to 20) { - val store = StateStore.get(storeProviderId1, keySchema, valueSchema, - NoPrefixKeyStateEncoderSpec(keySchema), - latestStoreVersion, None, None, useColumnFamilies = false, storeConf, hadoopConf) - put(store, "a", 0, i) - store.commit() - latestStoreVersion += 1 - } - } - val timeoutDuration = 1.minute quietly { @@ -1043,44 +1058,39 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] withCoordinatorRef(sc) { coordinatorRef => require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running") - // Generate sufficient versions of store for snapshots - generateStoreVersions() + // Load the store + StateStore.get(storeProviderId1, keySchema, valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + 0, None, None, useColumnFamilies = false, storeConf, hadoopConf) + + // Ensure the store is loaded eventually(timeout(timeoutDuration)) { - // Store should have been reported to the coordinator assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty, "active instance was not reported") - // Background maintenance should clean up and generate snapshots - assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") - // Some snapshots should have been generated - tryWithProviderResource(newStoreProvider(storeProviderId1.storeId)) { provider => - val snapshotVersions = (1 to latestStoreVersion).filter { version => - fileExists(provider, version, isSnapshot = true) - } - assert(snapshotVersions.nonEmpty, "no snapshot file found") - } + assert(StateStore.isLoaded(storeProviderId1), "Store is not loaded") } - // Pause maintenance - StateStore.setMaintenancePaused(true) - // Generate more versions such that there is another snapshot. - generateStoreVersions() + // Record the current maintenance call count before deactivation + val maintenanceCountBeforeDeactivate = + MaintenanceCountingStateStoreProvider.maintenanceCallCount.get() - // If driver decides to deactivate all stores related to a query run, - // then this instance should be unloaded. + // Deactivate the store instance - this should trigger maintenance before unload coordinatorRef.deactivateInstances(storeProviderId1.queryRunId) - // Resume maintenance which should unload the deactivated store - StateStore.setMaintenancePaused(false) + // Wait for the store to be unloaded eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeProviderId1)) + assert(!StateStore.isLoaded(storeProviderId1), "Store was not unloaded") } - // Ensure the earliest delta file should be cleaned up during unload. - tryWithProviderResource(newStoreProvider(storeProviderId1.storeId)) { provider => - eventually(timeout(timeoutDuration)) { - assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") - } - } + // Get the maintenance count after unload + val maintenanceCountAfterUnload = + MaintenanceCountingStateStoreProvider.maintenanceCallCount.get() + + // Ensure that maintenance was called at least one more time during unload + assert(maintenanceCountAfterUnload > maintenanceCountBeforeDeactivate, + s"Maintenance should be called before unload. " + + s"Before: $maintenanceCountBeforeDeactivate, " + + s"After: $maintenanceCountAfterUnload") } } }