@@ -254,6 +254,28 @@ private object FakeStateStoreProviderWithMaintenanceError {
254254 val errorOnMaintenance = new AtomicBoolean (false )
255255}
256256
257+ /**
258+ * Test-only StateStoreProvider used for testing that maintenance happens before unload.
259+ * Extends HDFSBackedStateStoreProvider to keep actual store functionality, and increments the
260+ * counter held in the companion object on every doMaintenance call before delegating to super.
261+ */
262+ class MaintenanceCountingStateStoreProvider extends HDFSBackedStateStoreProvider {
263+ import MaintenanceCountingStateStoreProvider ._ // brings maintenanceCallCount into scope
264+
265+ override def doMaintenance (): Unit = {
266+ maintenanceCallCount.incrementAndGet() // counted first, so the call is recorded even if super throws
267+ super .doMaintenance()
268+ }
269+ }
270+
271+ private object MaintenanceCountingStateStoreProvider { // companion holding the doMaintenance call counter
272+ val maintenanceCallCount = new java.util.concurrent.atomic.AtomicInteger (0 ) // object-level counter, starts at 0
273+
274+ def reset (): Unit = { // puts the counter back to 0
275+ maintenanceCallCount.set(0 )
276+ }
277+ }
278+
257279@ ExtendedSQLTest
258280class StateStoreSuite extends StateStoreSuiteBase [HDFSBackedStateStoreProvider ]
259281 with SharedSparkSession
@@ -1013,77 +1035,65 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
10131035
10141036 // Ensure that maintenance is called before unloading
10151037 test(" SPARK-40492: maintenance before unload" ) {
1038+ // Reset the maintenance call counter
1039+ MaintenanceCountingStateStoreProvider .reset()
1040+
10161041 val conf = new SparkConf ()
10171042 .setMaster(" local" )
10181043 .setAppName(" SPARK-40492" )
10191044 val opId = 0
10201045 val dir1 = newDir()
10211046 val storeProviderId1 = StateStoreProviderId (StateStoreId (dir1, opId, 0 ), UUID .randomUUID)
1022- val sqlConf = getDefaultSQLConf(SQLConf .STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT .defaultValue.get,
1023- SQLConf .MAX_BATCHES_TO_RETAIN_IN_MEMORY .defaultValue.get)
1024- sqlConf.setConf(SQLConf .MIN_BATCHES_TO_RETAIN , 10 )
1025- sqlConf.setConf(SQLConf .STREAMING_MAINTENANCE_INTERVAL , 10L )
1047+ val sqlConf = getDefaultSQLConf(
1048+ SQLConf .STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT .defaultValue.get,
1049+ SQLConf .MAX_BATCHES_TO_RETAIN_IN_MEMORY .defaultValue.get
1050+ )
1051+ sqlConf.setConf(SQLConf .STREAMING_MAINTENANCE_INTERVAL , 5 .seconds.toMillis)
1052+ sqlConf.setConf(SQLConf .STATE_STORE_PROVIDER_CLASS ,
1053+ classOf [MaintenanceCountingStateStoreProvider ].getName)
10261054 val storeConf = StateStoreConf (sqlConf)
10271055 val hadoopConf = new Configuration ()
10281056
1029- var latestStoreVersion = 0
1030-
1031- def generateStoreVersions (): Unit = {
1032- for (i <- 1 to 20 ) {
1033- val store = StateStore .get(storeProviderId1, keySchema, valueSchema,
1034- NoPrefixKeyStateEncoderSpec (keySchema),
1035- latestStoreVersion, None , None , useColumnFamilies = false , storeConf, hadoopConf)
1036- put(store, " a" , 0 , i)
1037- store.commit()
1038- latestStoreVersion += 1
1039- }
1040- }
1041-
10421057 val timeoutDuration = 1 .minute
10431058
10441059 quietly {
10451060 withSpark(SparkContext .getOrCreate(conf)) { sc =>
10461061 withCoordinatorRef(sc) { coordinatorRef =>
10471062 require(! StateStore .isMaintenanceRunning, " StateStore is unexpectedly running" )
10481063
1049- // Generate sufficient versions of store for snapshots
1050- generateStoreVersions()
1064+ // Load the store
1065+ StateStore .get(storeProviderId1, keySchema, valueSchema,
1066+ NoPrefixKeyStateEncoderSpec (keySchema),
1067+ 0 , None , None , useColumnFamilies = false , storeConf, hadoopConf)
1068+
1069+ // Ensure the store is loaded
10511070 eventually(timeout(timeoutDuration)) {
1052- // Store should have been reported to the coordinator
10531071 assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty,
10541072 " active instance was not reported" )
1055- // Background maintenance should clean up and generate snapshots
1056- assert(StateStore .isMaintenanceRunning, " Maintenance task is not running" )
1057- // Some snapshots should have been generated
1058- tryWithProviderResource(newStoreProvider(storeProviderId1.storeId)) { provider =>
1059- val snapshotVersions = (1 to latestStoreVersion).filter { version =>
1060- fileExists(provider, version, isSnapshot = true )
1061- }
1062- assert(snapshotVersions.nonEmpty, " no snapshot file found" )
1063- }
1073+ assert(StateStore .isLoaded(storeProviderId1), " Store is not loaded" )
10641074 }
1065- // Pause maintenance
1066- StateStore .setMaintenancePaused(true )
10671075
1068- // Generate more versions such that there is another snapshot.
1069- generateStoreVersions()
1076+ // Record the current maintenance call count before deactivation
1077+ val maintenanceCountBeforeDeactivate =
1078+ MaintenanceCountingStateStoreProvider .maintenanceCallCount.get()
10701079
1071- // If driver decides to deactivate all stores related to a query run,
1072- // then this instance should be unloaded.
1080+ // Deactivate the store instance - this should trigger maintenance before unload
10731081 coordinatorRef.deactivateInstances(storeProviderId1.queryRunId)
10741082
1075- // Resume maintenance which should unload the deactivated store
1076- StateStore .setMaintenancePaused(false )
1083+ // Wait for the store to be unloaded
10771084 eventually(timeout(timeoutDuration)) {
1078- assert(! StateStore .isLoaded(storeProviderId1))
1085+ assert(! StateStore .isLoaded(storeProviderId1), " Store was not unloaded " )
10791086 }
10801087
1081- // Ensure the earliest delta file should be cleaned up during unload.
1082- tryWithProviderResource(newStoreProvider(storeProviderId1.storeId)) { provider =>
1083- eventually(timeout(timeoutDuration)) {
1084- assert(! fileExists(provider, 1 , isSnapshot = false ), " earliest file not deleted" )
1085- }
1086- }
1088+ // Get the maintenance count after unload
1089+ val maintenanceCountAfterUnload =
1090+ MaintenanceCountingStateStoreProvider .maintenanceCallCount.get()
1091+
1092+ // Ensure that maintenance was called at least one more time during unload
1093+ assert(maintenanceCountAfterUnload > maintenanceCountBeforeDeactivate,
1094+ s " Maintenance should be called before unload. " +
1095+ s " Before: $maintenanceCountBeforeDeactivate, " +
1096+ s " After: $maintenanceCountAfterUnload" )
10871097 }
10881098 }
10891099 }
0 commit comments