From 78eaf75baac2b420cc7e546566f8e537924c9e29 Mon Sep 17 00:00:00 2001 From: liviazhu Date: Thu, 25 Jun 2026 18:04:48 +0000 Subject: [PATCH 1/5] [SS] Split state store maintenance into separate snapshot and cleanup operations Prep work for decoupled state store maintenance: - Introduce a MaintenanceOpRequest enum (All/Snapshot/Cleanup) used to tag entries in the unloadedProvidersToClose queue, which is widened from a (id, provider) pair to a (id, provider, opRequest) triple. - Split doMaintenance into doSnapshotMaintenance and doCleanupMaintenance in the StateStoreProvider trait, HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, and RocksDB, so snapshot upload and old-file cleanup can be invoked independently. doMaintenance now calls both, preserving existing behavior. Ported from databricks-eng/runtime#200656 (SC-221356). Co-authored-by: Isaac --- .../state/HDFSBackedStateStoreProvider.scala | 17 ++++++++- .../execution/streaming/state/RocksDB.scala | 12 ++++++- .../state/RocksDBStateStoreProvider.scala | 27 +++++++++++--- .../streaming/state/StateStore.scala | 32 ++++++++++++++--- .../streaming/state/RocksDBSuite.scala | 36 +++++++++++++++++++ .../streaming/state/StateStoreSuite.scala | 33 +++++++++++++++-- 6 files changed, 142 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 2562f1ff3304e..cedf31adf9462 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -478,12 +478,27 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ override def doMaintenance(): Unit = { + doSnapshotMaintenance() + doCleanupMaintenance() + } + + /** Run only the snapshot upload portion of maintenance. */ + override def doSnapshotMaintenance(): Unit = { try { doSnapshot("maintenance") + } catch { + case NonFatal(e) => + logWarning(log"Error performing snapshot maintenance", e) + } + } + + /** Run only the cleanup portion of maintenance. */ + override def doCleanupMaintenance(): Unit = { + try { cleanup() } catch { case NonFatal(e) => - logWarning(log"Error performing snapshot and cleaning up") + logWarning(log"Error performing cleanup maintenance", e) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index bd479ffc6d2a8..0445163710391 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -2045,7 +2045,8 @@ class RocksDB( logInfo(log"Rolled back to ${MDC(LogKeys.VERSION_NUM, loadedVersion)}") } - def doMaintenance(): Unit = { + /** Run only the snapshot upload portion of maintenance. */ + def doSnapshotMaintenance(): Unit = { if (enableChangelogCheckpointing) { var mostRecentSnapshot: Option[RocksDBSnapshot] = None @@ -2076,6 +2077,10 @@ class RocksDB( uploadSnapshot(snapshotToUpload) } } + } + + /** Run only the cleanup portion of maintenance. */ + def doCleanupMaintenance(): Unit = { val cleanupTime = timeTakenMs { fileManager.deleteOldVersions( numVersionsToRetain = conf.minVersionsToRetain, @@ -2085,6 +2090,11 @@ class RocksDB( logInfo(log"Cleaned old data, time taken: ${MDC(LogKeys.TIME_UNITS, cleanupTime)} ms") } + def doMaintenance(): Unit = { + doSnapshotMaintenance() + doCleanupMaintenance() + } + /** * This replaces stale reused files in the snapshot with new ones to be uploaded. * Stale means they are potential candidates for deletion by another diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index c181130eec94b..a435ecbaca601 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -1047,15 +1047,32 @@ private[sql] class RocksDBStateStoreProvider } override def doMaintenance(): Unit = { + doSnapshotMaintenance() + doCleanupMaintenance() + } + + /** Run only the snapshot upload portion of maintenance. */ + override def doSnapshotMaintenance(): Unit = { + doMaintenanceOp(rocksDB.doSnapshotMaintenance(), "snapshot maintenance") + } + + /** Run only the cleanup portion of maintenance. */ + override def doCleanupMaintenance(): Unit = { + doMaintenanceOp(rocksDB.doCleanupMaintenance(), "cleanup maintenance") + } + + /** + * Common wrapper for maintenance operations: verifies the state machine and swallows non-fatal + * exceptions (SPARK-46547) to avoid deadlock between the maintenance thread and the streaming + * aggregation operator. + */ + private def doMaintenanceOp(op: => Unit, opName: String): Unit = { stateMachine.verifyForMaintenance() try { - rocksDB.doMaintenance() + op } catch { - // SPARK-46547 - Swallow non-fatal exception in maintenance task to avoid deadlock between - // maintenance thread and streaming aggregation operator case NonFatal(ex) => - logWarning(s"Ignoring error while performing maintenance operations with exception=", - ex) + logWarning(s"Ignoring error while performing $opName with exception=", ex) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ad067b8edcc3a..775b8745442ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -86,6 +86,21 @@ object MaintenanceTaskType { case object FromLoadedProviders extends MaintenanceTaskType } +/** + * Tracks which maintenance operations still need to run before a provider can be closed. + * Used as a tag on queue entries in `unloadedProvidersToClose`. + */ +sealed trait MaintenanceOpRequest + +object MaintenanceOpRequest { + /** All maintenance operations still need to run (e.g. query-thread-initiated unload). */ + case object All extends MaintenanceOpRequest + /** Only snapshot still needs to run (cleanup already ran as the triggering op). */ + case object Snapshot extends MaintenanceOpRequest + /** Only cleanup still needs to run (snapshot already ran as the triggering op). */ + case object Cleanup extends MaintenanceOpRequest +} + /** * Base trait for a versioned key-value store which provides read operations. Each instance of a * `ReadStateStore` represents a specific version of state data, and such instances are created @@ -891,6 +906,12 @@ trait StateStoreProvider { /** Optional method for providers to allow for background maintenance (e.g. compactions) */ def doMaintenance(): Unit = { } + /** Run only the snapshot upload portion of maintenance. */ + def doSnapshotMaintenance(): Unit = { } + + /** Run only the cleanup portion of maintenance. */ + def doCleanupMaintenance(): Unit = { } + /** * Optional custom metrics that the implementation may want to report. * @note The StateStore objects created by this provider must report the same custom metrics @@ -1238,7 +1259,7 @@ object StateStore extends Logging { private val maintenanceThreadPoolLock = new Object private val unloadedProvidersToClose = - new ConcurrentLinkedQueue[(StateStoreProviderId, StateStoreProvider)] + new ConcurrentLinkedQueue[(StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)] // This set is to keep track of the partitions that are queued // for maintenance or currently have maintenance running on them @@ -1652,19 +1673,20 @@ object StateStore extends Logging { } // Providers that couldn't be processed now and need to be added back to the queue - val providersToRequeue = new ArrayBuffer[(StateStoreProviderId, StateStoreProvider)]() + val providersToRequeue = + new ArrayBuffer[(StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)]() // unloadedProvidersToClose are StateStoreProviders that have been removed from // loadedProviders, and can now be processed for maintenance. This queue contains // providers for which we weren't able to process for maintenance on the previous iteration while (!unloadedProvidersToClose.isEmpty) { - val (providerId, provider) = unloadedProvidersToClose.poll() + val (providerId, provider, opRequest) = unloadedProvidersToClose.poll() if (processThisPartition(providerId)) { submitMaintenanceWorkForProvider( providerId, provider, storeConf, MaintenanceTaskType.FromUnloadedProvidersQueue) } else { - providersToRequeue += ((providerId, provider)) + providersToRequeue += ((providerId, provider, opRequest)) } } @@ -1713,7 +1735,7 @@ object StateStore extends Logging { if (!ableToProcessNow) { // Add to queue for later processing if we can't process now // This will be resubmitted for maintenance later by the background maintenance task - unloadedProvidersToClose.add((id, provider)) + unloadedProvidersToClose.add((id, provider, MaintenanceOpRequest.All)) } ableToProcessNow diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index dc697f5b99dc5..39e3f8c07af63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -1054,6 +1054,42 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession } } + test("RocksDB: split maintenance methods upload snapshots and clean up separately") { + val remoteDir = Utils.createTempDir().toString + new File(remoteDir).delete() + val conf = dbConf.copy(enableChangelogCheckpointing = true, + minVersionsToRetain = 3, minDeltasForSnapshot = 1, minVersionsToDelete = 3) + withDB(remoteDir, conf = conf) { db => + // Commit 5 versions, uploading snapshots after each via doSnapshotMaintenance. + for (version <- 0 to 4) { + db.load(version) + db.put(version.toString, version.toString) + db.commit() + db.doSnapshotMaintenance() + } + assert(snapshotVersionsPresent(remoteDir) == (1 to 5)) + assert(changelogVersionsPresent(remoteDir) == (1 to 5)) + + // Commit 1 more version without maintenance. + // stale versions: (1, 2, 3), keep versions: (4, 5, 6) + db.load(5) + db.put("5", "5") + db.commit() + assert(snapshotVersionsPresent(remoteDir) == (1 to 5)) + assert(changelogVersionsPresent(remoteDir) == (1 to 6)) + + // doSnapshotMaintenance should upload version 6 and not clean up. + db.doSnapshotMaintenance() + assert(snapshotVersionsPresent(remoteDir) == (1 to 6)) + assert(changelogVersionsPresent(remoteDir) == (1 to 6)) + + // doCleanupMaintenance should delete stale versions (1, 2, 3) and not upload a snapshot. + db.doCleanupMaintenance() + assert(snapshotVersionsPresent(remoteDir) == Seq(4, 5, 6)) + assert(changelogVersionsPresent(remoteDir) == Seq(4, 5, 6)) + } + } + testWithStateStoreCheckpointIdsAndColumnFamilies( "RocksDB: minDeltasForSnapshot", TestWithChangelogCheckpointingEnabled) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 65009a2c1c6bf..14464fee4576d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -342,7 +342,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] // Access the queue via reflection for verification val queueField = PrivateMethod[ConcurrentLinkedQueue[ - (StateStoreProviderId, StateStoreProvider)]]( + (StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)]]( Symbol("unloadedProvidersToClose")) val queue = StateStore invokePrivate queueField() assert(queue.isEmpty, "Queue should start empty") @@ -398,7 +398,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] eventually(timeout(5.seconds)) { assert(queue.size() == 1, "Provider should be queued after timeout") } - val (queuedId, _) = queue.peek() + // TODO: Assert opRequest value once decoupled maintenance is enabled. + val (queuedId, _, _) = queue.peek() assert(queuedId == providerId, "Queued provider has wrong ID") // Now allow the first maintenance to complete @@ -496,7 +497,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] // Get the partitionsForMaintenance field to check the queue is empty val partitionsField = PrivateMethod[ - ConcurrentLinkedQueue[StateStoreProviderId]](Symbol("unloadedProvidersToClose")) + ConcurrentLinkedQueue[(StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)]]( + Symbol("unloadedProvidersToClose")) val queue = StateStore invokePrivate partitionsField() assert(queue.isEmpty, "Maintenance queue should be empty after processing queued tasks") } @@ -620,6 +622,31 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } + test("HDFS: split maintenance methods upload snapshots and clean up old files separately") { + tryWithProviderResource(newStoreProvider(opId = Random.nextInt(), partition = 0, + minDeltasForSnapshot = 5)) { provider => + for (i <- 1 to 21) { + val store = provider.getStore(i - 1) + put(store, "a", 0, i) + store.commit() + // Snapshot and cleanup run as independent operations. + provider.doSnapshotMaintenance() + provider.doCleanupMaintenance() + } + + // Snapshots are uploaded by doSnapshotMaintenance (at versions 6, 12, 18 given + // minDeltasForSnapshot = 5) and doCleanupMaintenance removes old files, retaining only + // the last numVersionsToRetain (default 2) versions anchored on the latest snapshot (18). + val basePath = provider.stateStoreId.storeCheckpointLocation() + val remainingFiles = new File(basePath.toString) + .listFiles().filter(f => f.isFile && !f.getName.startsWith(".")) + .map(_.getName).filterNot(_.endsWith(".crc")).toSet + assert(remainingFiles === + Set("18.snapshot", "18.delta", "19.delta", "20.delta", "21.delta"), + s"Unexpected remaining files: $remainingFiles") + } + } + test("get, put, remove etc operations on non-default col family should fail") { tryWithProviderResource(newStoreProvider(opId = Random.nextInt(), partition = 0, minDeltasForSnapshot = 5)) { provider => From 3c294bdab33f4d7bffccb11919bbc9fae3e2f1dd Mon Sep 17 00:00:00 2001 From: liviazhu Date: Thu, 25 Jun 2026 22:00:52 +0000 Subject: [PATCH 2/5] [SS] Decouple state store maintenance scheduler into snapshot and cleanup tasks Core of decoupled state store maintenance, building on the split maintenance operations: - Submit separate snapshot and cleanup tasks per provider, tracked by two independent partition sets (snapshotPartitions/cleanupPartitions) so the two operation types never block each other. - Route the unloadedProvidersToClose queue by MaintenanceOpRequest: All submits one op as FromTaskThread which re-queues the other op; Snapshot/Cleanup submit that op as FromUnloadedProvidersQueue and close the provider afterward. - Rewrite the pool task body around the three source paths, replacing the await/timeout (awaitProcessThisPartition) and single maintenancePartitions set with tryClaimPartition/tryClaimAndSubmit. - Have the query thread queue providers for close instead of submitting directly. - Add an `unloaded` flag (setUnloaded) so maintenance skips providers being torn down; split removeFromLoadedProvidersAndClose into closeProvider + remove-by-key. Ported from databricks-eng/runtime#201011 (SC-225784). Co-authored-by: Isaac --- .../org/apache/spark/internal/LogKeys.java | 1 + .../streaming/state/StateStore.scala | 398 +++++---- .../state/StateStoreInstanceMetricSuite.scala | 28 +- .../streaming/state/StateStoreSuite.scala | 801 +++++++++++------- 4 files changed, 782 insertions(+), 446 deletions(-) diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java index 37064bf776312..0e911a684634c 100644 --- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java +++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java @@ -366,6 +366,7 @@ public enum LogKeys implements LogKey { LOG_TYPE, LOSSES, LOWER_BOUND, + MAINTENANCE_TASK_TYPE, MALFORMATTED_STRING, MAP_ID, MASTER_URL, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 775b8745442ff..85de1f85a33d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -101,6 +101,18 @@ object MaintenanceOpRequest { case object Cleanup extends MaintenanceOpRequest } +/** + * Specifies which maintenance operation a single pool task should perform. + * Unlike MaintenanceOpRequest (which tracks remaining ops before close), + * this is the concrete op assigned to one pool thread submission. + */ +sealed trait MaintenanceOpType + +object MaintenanceOpType { + case object Snapshot extends MaintenanceOpType + case object Cleanup extends MaintenanceOpType +} + /** * Base trait for a versioned key-value store which provides read operations. Each instance of a * `ReadStateStore` represents a specific version of state data, and such instances are created @@ -814,6 +826,12 @@ case class TimestampAsPostfixKeyStateEncoderSpec(keySchema: StructType) */ trait StateStoreProvider { + // Whether this provider has been unloaded from the executor. It is read on the maintenance + // thread and set when the provider is unloaded, so maintenance does not run on a provider that + // is already being torn down. Volatile because it can be set on a query execution thread while + // being read on a maintenance thread. + @volatile var unloaded: Boolean = false + /** * Initialize the provide with more contextual information from the SQL operator. * This method will be called first after creating an instance of the StateStoreProvider by @@ -856,6 +874,11 @@ trait StateStoreProvider { */ def close(): Unit + /** Marks this provider as unloaded so maintenance threads stop processing it. */ + def setUnloaded(): Unit = { + unloaded = true + } + /** * Return an instance of [[StateStore]] representing state data of the given version. * If `stateStoreCkptId` is provided, the instance also needs to match the ID. @@ -1256,16 +1279,18 @@ object StateStore extends Logging { @GuardedBy("loadedProviders") private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]() - private val maintenanceThreadPoolLock = new Object - private val unloadedProvidersToClose = new ConcurrentLinkedQueue[(StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)] - // This set is to keep track of the partitions that are queued - // for maintenance or currently have maintenance running on them - // to prevent the same partition from being processed concurrently. - @GuardedBy("maintenanceThreadPoolLock") - private val maintenancePartitions = new mutable.HashSet[StateStoreProviderId] + // These sets track which providers currently have maintenance tasks in-flight, + // one per operation type, to prevent concurrent same-type operations on the same + // provider. Each set has its own lock. + private val snapshotPartitionsLock = new Object + @GuardedBy("snapshotPartitionsLock") + private val snapshotPartitions = new mutable.HashSet[StateStoreProviderId] + private val cleanupPartitionsLock = new Object + @GuardedBy("cleanupPartitionsLock") + private val cleanupPartitions = new mutable.HashSet[StateStoreProviderId] /** Reports to the coordinator that a StateStore has committed */ def reportCommitToCoordinator( @@ -1514,14 +1539,15 @@ object StateStore extends Logging { }.getOrElse(log"") providerStatus.providerIdsToUnload.foreach(id => { loadedProviders.remove(id).foreach( provider => { - // Trigger maintenance thread to immediately do maintenance on and close the provider. - // Doing maintenance first allows us to do maintenance for a constantly-moving state - // store. - logInfo(log"Submitted maintenance from task thread to close " + + // Queue the provider for maintenance and close. The maintenance scheduler will + // drain the queue and submit tasks. remove() returning non-null ensures only one + // queuer for this provider instance. + // TODO: trigger a scheduler cycle immediately so queued providers are processed + // without waiting for the next periodic tick. + logInfo(log"Queuing provider from task thread for maintenance and close " + log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}." + taskContextIdLogLine + log"Removed provider from loadedProviders") - submitMaintenanceWorkForProvider( - id, provider, storeConf, MaintenanceTaskType.FromTaskThread) + unloadedProvidersToClose.add((id, provider, MaintenanceOpRequest.All)) }) }) providerStatus.shouldForceSnapshotUpload @@ -1544,29 +1570,36 @@ object StateStore extends Logging { } /** - * Unload a state store provider. - * If alreadyRemovedFromLoadedProviders is None, provider will be - * removed from loadedProviders and closed. - * If alreadyRemovedFromLoadedProviders is Some, provider will be closed - * using passed in provider. + * Close a provider and release its resources. * WARNING: CAN ONLY BE CALLED FROM MAINTENANCE THREAD! */ - def removeFromLoadedProvidersAndClose( + def closeProvider( storeProviderId: StateStoreProviderId, - alreadyRemovedProvider: Option[StateStoreProvider] = None): Unit = { - val providerToClose = alreadyRemovedProvider.orElse { - loadedProviders.synchronized { - loadedProviders.remove(storeProviderId) - } - } - providerToClose.foreach { provider => + provider: StateStoreProvider): Unit = { + logInfo(log"Closing ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, storeProviderId)}") + try { provider.close() + } finally { + provider.setUnloaded() + } + logInfo(log"Closed ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, storeProviderId)}") + } + + /** + * Remove a provider from loadedProviders by key and close it. + * WARNING: CAN ONLY BE CALLED FROM MAINTENANCE THREAD! + */ + def removeFromLoadedProvidersAndClose(storeProviderId: StateStoreProviderId): Unit = { + loadedProviders.synchronized { + loadedProviders.remove(storeProviderId) + }.foreach { provider => + closeProvider(storeProviderId, provider) } } /** Unload all state store providers: unit test purpose */ private[sql] def unloadAll(): Unit = loadedProviders.synchronized { - loadedProviders.keySet.foreach { key => removeFromLoadedProvidersAndClose(key) } + loadedProviders.foreach { case (id, provider) => closeProvider(id, provider) } loadedProviders.clear() } @@ -1591,11 +1624,10 @@ object StateStore extends Logging { * */ private[streaming] def stopMaintenanceTaskWithoutLock(): Unit = { if (maintenanceThreadPool != null) { - maintenanceThreadPoolLock.synchronized { - maintenancePartitions.clear() - } maintenanceThreadPool.stop() maintenanceThreadPool = null + snapshotPartitionsLock.synchronized { snapshotPartitions.clear() } + cleanupPartitionsLock.synchronized { cleanupPartitions.clear() } } if (maintenanceTask != null) { maintenanceTask.stop() @@ -1605,10 +1637,15 @@ object StateStore extends Logging { /** Unload and stop all state store providers */ def stop(): Unit = loadedProviders.synchronized { - loadedProviders.keySet.foreach { key => removeFromLoadedProvidersAndClose(key) } + loadedProviders.foreach { case (id, provider) => closeProvider(id, provider) } loadedProviders.clear() _coordRef = null stopMaintenanceTask() + // Drain after stopping the pool to catch anything queued during shutdown. + while (!unloadedProvidersToClose.isEmpty) { + val (id, provider, _) = unloadedProvidersToClose.poll() + closeProvider(id, provider) + } logInfo("StateStore stopped") } @@ -1623,6 +1660,8 @@ object StateStore extends Logging { storeConf.maintenanceInterval, task = { doMaintenance(storeConf) } ) + // TODO: split into separate snapshot and cleanup pools so one operation type + // cannot starve the other when pool threads are saturated. maintenanceThreadPool = new MaintenanceThreadPool(numMaintenanceThreads, maintenanceShutdownTimeout, maintenanceForceShutdownTimeout) logInfo("State Store maintenance task started") @@ -1630,31 +1669,19 @@ object StateStore extends Logging { } } - // Wait until this partition can be processed - private def awaitProcessThisPartition( - id: StateStoreProviderId, - timeoutMs: Long): Boolean = maintenanceThreadPoolLock synchronized { - val startTime = System.currentTimeMillis() - val endTime = startTime + timeoutMs - - // If immediate processing fails, wait with timeout - var canProcessThisPartition = processThisPartition(id) - while (!canProcessThisPartition && System.currentTimeMillis() < endTime) { - maintenanceThreadPoolLock.wait(timeoutMs) - canProcessThisPartition = processThisPartition(id) - } - val elapsedTime = System.currentTimeMillis() - startTime - logInfo(log"Waited for ${MDC(LogKeys.TOTAL_TIME, elapsedTime)} ms to be able to process " + - log"maintenance for partition ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}") - canProcessThisPartition - } - private def doMaintenance(): Unit = doMaintenance(StateStoreConf.empty) - private def processThisPartition(id: StateStoreProviderId): Boolean = { - maintenanceThreadPoolLock.synchronized { - if (!maintenancePartitions.contains(id)) { - maintenancePartitions.add(id) + /** Claim a single partition set slot for the given op type. Returns true if claimed. */ + private[streaming] def tryClaimPartition( + id: StateStoreProviderId, + opType: MaintenanceOpType): Boolean = { + val (partitionSet, lock) = opType match { + case MaintenanceOpType.Snapshot => (snapshotPartitions, snapshotPartitionsLock) + case MaintenanceOpType.Cleanup => (cleanupPartitions, cleanupPartitionsLock) + } + lock.synchronized { + if (!partitionSet.contains(id)) { + partitionSet.add(id) true } else { false @@ -1676,16 +1703,41 @@ object StateStore extends Logging { val providersToRequeue = new ArrayBuffer[(StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)]() - // unloadedProvidersToClose are StateStoreProviders that have been removed from - // loadedProviders, and can now be processed for maintenance. This queue contains - // providers for which we weren't able to process for maintenance on the previous iteration + // Phase 1: Drain the unloadedProvidersToClose queue. These are providers that have been + // removed from loadedProviders and need maintenance before close. opRequest determines + // which task to submit and which source type to use: + // All: pick one available op, submit as FromTaskThread (first of two ticks) + // Snapshot: submit snapshot as FromUnloadedProvidersQueue (closes after) + // Cleanup: submit cleanup as FromUnloadedProvidersQueue (closes after) while (!unloadedProvidersToClose.isEmpty) { val (providerId, provider, opRequest) = unloadedProvidersToClose.poll() - if (processThisPartition(providerId)) { - submitMaintenanceWorkForProvider( - providerId, provider, storeConf, MaintenanceTaskType.FromUnloadedProvidersQueue) - } else { + val submitted = opRequest match { + case MaintenanceOpRequest.All => + // All ops should run before the provider can be closed. We serialize them by + // submitting one op now as FromTaskThread; when it completes it enqueues the + // remaining op, which runs as FromUnloadedProvidersQueue and closes the provider. + // TODO: trigger the next scheduler cycle immediately after the first op completes + // so the remaining op is picked up without waiting for the next periodic tick. + // Pick whichever partition set is available with short-circuit evaluation. + tryClaimAndSubmit( + providerId, provider, storeConf, + MaintenanceOpType.Snapshot, MaintenanceTaskType.FromTaskThread) || + tryClaimAndSubmit( + providerId, provider, storeConf, + MaintenanceOpType.Cleanup, MaintenanceTaskType.FromTaskThread) + case MaintenanceOpRequest.Snapshot => + tryClaimAndSubmit( + providerId, provider, storeConf, + MaintenanceOpType.Snapshot, MaintenanceTaskType.FromUnloadedProvidersQueue) + case MaintenanceOpRequest.Cleanup => + tryClaimAndSubmit( + providerId, provider, storeConf, + MaintenanceOpType.Cleanup, MaintenanceTaskType.FromUnloadedProvidersQueue) + } + + // If the partition set is occupied, buffer for requeue and retry on the next cycle. + if (!submitted) { providersToRequeue += ((providerId, provider, opRequest)) } } @@ -1697,117 +1749,175 @@ object StateStore extends Logging { providersToRequeue.foreach(unloadedProvidersToClose.offer) + // Phase 2: Submit separate snapshot and cleanup tasks for loaded providers. loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) => - if (processThisPartition(id)) { - submitMaintenanceWorkForProvider( - id, provider, storeConf, MaintenanceTaskType.FromLoadedProviders) - } else { - logInfo(log"Not processing partition ${MDC(LogKeys.PARTITION_ID, id)} " + - log"for maintenance because it is currently " + - log"being processed") - } + tryClaimAndSubmit( + id, provider, storeConf, MaintenanceOpType.Snapshot, MaintenanceTaskType.FromLoadedProviders) + tryClaimAndSubmit( + id, provider, storeConf, MaintenanceOpType.Cleanup, MaintenanceTaskType.FromLoadedProviders) } } + /** + * Attempts to claim a partition set slot and submit maintenance work for a provider. + * Returns true if the work was submitted, false if the partition set was occupied. + */ + private def tryClaimAndSubmit( + providerId: StateStoreProviderId, + provider: StateStoreProvider, + storeConf: StateStoreConf, + opType: MaintenanceOpType, + source: MaintenanceTaskType): Boolean = { + if (tryClaimPartition(providerId, opType)) { + submitMaintenanceWorkForProvider(providerId, provider, storeConf, source, opType) + true + } else { + logInfo(log"Not processing partition " + + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + + log"with source ${MDC(LogKeys.MAINTENANCE_TASK_TYPE, source)} " + + log"for ${MDC(LogKeys.OP_TYPE, opType)} " + + log"because partition set is occupied") + false + } + } + + /** + * Determines the MaintenanceOpRequest for the "other" operation given the current opType. + * Used when a task needs to queue the provider for the remaining operation before close. + */ + private[streaming] def otherMaintenanceOpRequest( + opType: MaintenanceOpType): MaintenanceOpRequest = opType match { + case MaintenanceOpType.Snapshot => MaintenanceOpRequest.Cleanup + case MaintenanceOpType.Cleanup => MaintenanceOpRequest.Snapshot + } + /** * Submits maintenance work for a provider to the maintenance thread pool. * * @param id The StateStore provider ID to perform maintenance on * @param provider The StateStore provider instance + * @param opType Which maintenance operation (snapshot or cleanup) to perform */ private def submitMaintenanceWorkForProvider( id: StateStoreProviderId, provider: StateStoreProvider, storeConf: StateStoreConf, - source: MaintenanceTaskType = FromLoadedProviders): Unit = { + source: MaintenanceTaskType, + opType: MaintenanceOpType): Unit = { maintenanceThreadPool.execute(() => { val startTime = System.currentTimeMillis() - // Determine if we can process this partition based on the source - val canProcessThisPartition = source match { - case FromTaskThread => - // Provider from task thread needs to wait for lock - // We potentially need to wait for ongoing maintenance to finish processing - // this partition - val timeoutMs = storeConf.stateStoreMaintenanceProcessingTimeout * 1000 - val ableToProcessNow = awaitProcessThisPartition(id, timeoutMs) - if (!ableToProcessNow) { - // Add to queue for later processing if we can't process now - // This will be resubmitted for maintenance later by the background maintenance task - unloadedProvidersToClose.add((id, provider, MaintenanceOpRequest.All)) - } - ableToProcessNow - - case FromUnloadedProvidersQueue => - // Provider from queue can be processed immediately - // (we've already removed it from loadedProviders) - true - + // Check that this provider is still valid before doing work. + // TODO: the loadedProviders check is racy (membership could change between the check and + // doing the work). This is resolved by the read/write lock added in a later PR, which + // holds the read lock through the check and the work. + val canProcess = source match { case FromLoadedProviders => - // Provider from loadedProviders can be processed immediately - // as it's in maintenancePartitions - true + // Checks that the ID is still in loadedProviders and that the instance matches the + // one we were given. The scheduler submits from a stale copy of loadedProviders, so + // the provider may have been removed and replaced by a new instance under the same key. + loadedProviders.synchronized { loadedProviders.get(id).contains(provider) } && + !provider.unloaded + case _ => + // FromTaskThread / FromUnloadedProvidersQueue: provider already removed from + // loadedProviders, reference passed directly. + !provider.unloaded } - if (canProcessThisPartition) { - val awaitingPartitionDuration = System.currentTimeMillis() - startTime - try { - provider.doMaintenance() - // Handle unloading based on source - source match { - case FromTaskThread | FromUnloadedProvidersQueue => - // Provider already removed from loadedProviders, just close it - removeFromLoadedProvidersAndClose(id, Some(provider)) + try { + if (canProcess) { + // Do the maintenance work for this op type. + opType match { + case MaintenanceOpType.Snapshot => provider.doSnapshotMaintenance() + case MaintenanceOpType.Cleanup => provider.doCleanupMaintenance() + } + // Handle post-work actions based on source. + source match { case FromLoadedProviders => // Check if provider should be unloaded if (!verifyIfStoreInstanceActive(id)) { - removeFromLoadedProvidersAndClose(id) + // Only remove if the map still holds the same provider instance we were given. + // Between verifyIfStoreInstanceActive and this remove, a concurrent get() may + // have loaded a new provider under the same key; removing by key alone would + // incorrectly remove the new provider. + val removed = loadedProviders.synchronized { + if (loadedProviders.get(id).contains(provider)) { + loadedProviders.remove(id) + } else { + None + } + } + // Queue for the other operation before close. + if (removed.isDefined) { + unloadedProvidersToClose.add((id, provider, otherMaintenanceOpRequest(opType))) + } } + + case FromTaskThread => + // Provider already removed from loadedProviders by the query thread. + // Queue for the other operation before close. + unloadedProvidersToClose.add((id, provider, otherMaintenanceOpRequest(opType))) + + case FromUnloadedProvidersQueue => + // This is the final operation before close. + closeProvider(id, provider) } - logInfo(log"Unloaded ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}") - } catch { - case NonFatal(e) => - logWarning(log"Error doing maintenance on provider:" + - log" ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}. " + - log"Could not unload state store provider", e) - // When we get a non-fatal exception, we just unload the provider. - // - // By not bubbling the exception to the maintenance task thread or the query execution - // thread, it's possible for a maintenance thread pool task to continue failing on - // the same partition. Additionally, if there is some global issue that will cause - // all maintenance thread pool tasks to fail, then bubbling the exception and - // stopping the pool is faster than waiting for all tasks to see the same exception. - // - // However, we assume that repeated failures on the same partition and global issues - // are rare. The benefit to unloading just the partition with an exception is that - // transient issues on a given provider do not affect any other providers; so, in - // most cases, this should be a more performant solution. - source match { - case FromTaskThread | FromUnloadedProvidersQueue => - removeFromLoadedProvidersAndClose(id, Some(provider)) - - case FromLoadedProviders => - removeFromLoadedProvidersAndClose(id) - } - } finally { - val duration = System.currentTimeMillis() - startTime - val logMsg = - log"Finished maintenance task for " + - log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}" + - log" in elapsed_time=${MDC(LogKeys.TIME_UNITS, duration)}" + - log" and awaiting_partition_time=" + - log"${MDC(LogKeys.TIME_UNITS, awaitingPartitionDuration)}\n" - if (duration > 5000) { - logInfo(logMsg) - } else { - logDebug(logMsg) - } - maintenanceThreadPoolLock.synchronized { - maintenancePartitions.remove(id) - maintenanceThreadPoolLock.notifyAll() + } else { + logInfo(log"Skipping maintenance for " + + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}, " + + log"provider was removed from loadedProviders") + } + } catch { + case NonFatal(e) => + logWarning(log"Error doing maintenance on provider:" + + log" ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}. " + + log"Could not unload state store provider", e) + // When we get a non-fatal exception, we just unload the provider. + // + // By not bubbling the exception to the maintenance task thread or the query execution + // thread, it's possible for a maintenance thread pool task to continue failing on + // the same partition. Additionally, if there is some global issue that will cause + // all maintenance thread pool tasks to fail, then bubbling the exception and + // stopping the pool is faster than waiting for all tasks to see the same exception. + // + // However, we assume that repeated failures on the same partition and global issues + // are rare. The benefit to unloading just the partition with an exception is that + // transient issues on a given provider do not affect any other providers; so, in + // most cases, this should be a more performant solution. + source match { + case FromTaskThread | FromUnloadedProvidersQueue => + closeProvider(id, provider) + + case FromLoadedProviders => + // Only remove if the map still holds the same provider instance. A concurrent + // get() may have loaded a new provider under the same key. + loadedProviders.synchronized { + if (loadedProviders.get(id).contains(provider)) { + loadedProviders.remove(id) + } + } + // Always close this provider instance regardless of whether we removed it from + // the map. Maintenance failed, so we must clean up this provider's resources. + closeProvider(id, provider) } + } finally { + val duration = System.currentTimeMillis() - startTime + val logMsg = + log"Finished maintenance task for " + + log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}" + + log" in elapsed_time=${MDC(LogKeys.TIME_UNITS, duration)}\n" + if (duration > 5000) { + logInfo(logMsg) + } else { + logDebug(logMsg) + } + opType match { + case MaintenanceOpType.Snapshot => + snapshotPartitionsLock.synchronized { snapshotPartitions.remove(id) } + case MaintenanceOpType.Cleanup => + cleanupPartitionsLock.synchronized { cleanupPartitions.remove(id) } } } }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreInstanceMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreInstanceMetricSuite.scala index 58d951500c8c5..2015b7eaed1ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreInstanceMetricSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreInstanceMetricSuite.scala @@ -30,24 +30,28 @@ import org.apache.spark.tags.ExtendedSQLTest // maintenance for partitions 0 and 1 (these are arbitrary choices). This is used to test // snapshot upload lag can be observed through StreamingQueryProgress metrics. class RocksDBSkipMaintenanceOnCertainPartitionsProvider extends RocksDBStateStoreProvider { - override def doMaintenance(): Unit = { - if (stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1) { - return - } - super.doMaintenance() - } + private def shouldSkip: Boolean = + stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1 + + override def doSnapshotMaintenance(): Unit = + if (!shouldSkip) super.doSnapshotMaintenance() + + override def doCleanupMaintenance(): Unit = + if (!shouldSkip) super.doCleanupMaintenance() } // HDFSBackedSkipMaintenanceOnCertainPartitionsProvider is a test-only provider that skips running // maintenance for partitions 0 and 1 (these are arbitrary choices). This is used to test // snapshot upload lag can be observed through StreamingQueryProgress metrics. class HDFSBackedSkipMaintenanceOnCertainPartitionsProvider extends HDFSBackedStateStoreProvider { - override def doMaintenance(): Unit = { - if (stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1) { - return - } - super.doMaintenance() - } + private def shouldSkip: Boolean = + stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1 + + override def doSnapshotMaintenance(): Unit = + if (!shouldSkip) super.doSnapshotMaintenance() + + override def doCleanupMaintenance(): Unit = + if (!shouldSkip) super.doCleanupMaintenance() } @ExtendedSQLTest diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 14464fee4576d..c0b13ace4eefb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -30,6 +30,7 @@ import scala.util.Random import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ +import org.apache.logging.log4j.Level import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods import org.scalatest.{BeforeAndAfter, PrivateMethodTester} @@ -55,18 +56,32 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** - * A test StateStoreProvider implementation that controls maintenance execution - * timing using a CountDownLatch to simulate concurrent maintenance scenarios. - * - * This provider is used to test the scenario where a task thread attempts to - * unload a provider via maintenance while it's already being processed by a - * maintenance thread. This tests the awaitProcessThisPartition functionality - * that ensures proper synchronization in StateStore's maintenance thread pool. + * A test StateStoreProvider that counts how many times each split maintenance operation + * (snapshot, cleanup) runs and records the thread that closed it. Used to verify that the + * decoupled maintenance scheduler submits snapshot and cleanup as independent operations. + */ +/** + * A test StateStoreProvider whose split maintenance operations (snapshot and cleanup) block + * on per-instance latches, so tests can observe and control the decoupled maintenance + * lifecycle. Records the thread that ran each op and the thread that closed the provider. */ -class SignalingStateStoreProvider extends StateStoreProvider with Logging { - import SignalingStateStoreProvider._ +class BlockingMaintenanceProvider extends StateStoreProvider with Logging { private var id: StateStoreId = null + // Per-instance state. No shared static fields, so stale scheduler cycles from a previous + // test use the old instance's latches (already counted down) and finish immediately. No + // cross-test interference. + @volatile var snapshotThreadName: String = "" + @volatile var cleanupThreadName: String = "" + @volatile var closeThreadName: String = "" + @volatile var snapshotShouldThrow: Boolean = false + @volatile var cleanupShouldThrow: Boolean = false + + val snapshotEnteredLatch = new CountDownLatch(1) + val cleanupEnteredLatch = new CountDownLatch(1) + val snapshotContinueSignal = new CountDownLatch(1) + val cleanupContinueSignal = new CountDownLatch(1) + override def init( stateStoreId: StateStoreId, keySchema: StructType, @@ -82,70 +97,39 @@ class SignalingStateStoreProvider extends StateStoreProvider with Logging { override def stateStoreId: StateStoreId = id - /** - * Records which thread called close() to verify that only maintenance threads close providers - */ override def close(): Unit = { closeThreadName = Thread.currentThread.getName } - /** - * This test implementation doesn't need to provide an actual store - */ + // Returns null because tests using this provider do not need a real store. They only + // exercise the maintenance scheduler and close paths. override def getStore( version: Long, uniqueId: Option[String], forceSnapshotOnCommit: Boolean = false, loadEmpty: Boolean = false): StateStore = null - /** - * Simulates a maintenance operation that blocks until a signal is received. - * This allows testing the scenario where a provider is already under maintenance - * when a task thread tries to trigger another maintenance operation on it. - */ - override def doMaintenance(): Unit = { - maintenanceStarted = true - logInfo(s"Maintenance started on thread: ${Thread.currentThread().getName}") - - // Block until the test signals to continue - continueSignal.await() - - logInfo(s"Maintenance continuing after signal on thread: ${Thread.currentThread().getName}") + /** Signals entry, then blocks until the test releases the continue latch. */ + override def doSnapshotMaintenance(): Unit = { + snapshotThreadName = Thread.currentThread.getName + logInfo(s"Snapshot maintenance entered on ${Thread.currentThread.getName}") + snapshotEnteredLatch.countDown() + snapshotContinueSignal.await() + logInfo(s"Snapshot maintenance continuing on ${Thread.currentThread.getName}") + if (snapshotShouldThrow) { + throw new RuntimeException("snapshot error") + } } -} - -/** - * Companion object that tracks state and provides synchronization primitives - * for testing concurrent maintenance scenarios - */ -object SignalingStateStoreProvider extends Logging { - // For tracking state across threads - var maintenanceStarted: Boolean = false - var taskSubmittedMaintenance: Boolean = false - var closeThreadName: String = "" - - // Added for queue testing - var providerWasQueued: Boolean = false - - // For coordination between threads - var continueSignal = new CountDownLatch(1) - val maintenanceStartedLatch = new CountDownLatch(1) - val taskAttemptCompletedLatch = new CountDownLatch(1) - - /** - * Resets all test state between test runs - */ - def reset(): Unit = { - maintenanceStarted = false - taskSubmittedMaintenance = false - closeThreadName = "" - // Reset the latch to ensure maintenance will block again - try { - continueSignal = new CountDownLatch(1) - } catch { - case e: Exception => - logError(s"Error resetting latch: ${e.getMessage}") + /** Same handshake as doSnapshotMaintenance but for cleanup. */ + override def doCleanupMaintenance(): Unit = { + cleanupThreadName = Thread.currentThread.getName + logInfo(s"Cleanup maintenance entered on ${Thread.currentThread.getName}") + cleanupEnteredLatch.countDown() + cleanupContinueSignal.await() + logInfo(s"Cleanup maintenance continuing on ${Thread.currentThread.getName}") + if (cleanupShouldThrow) { + throw new RuntimeException("cleanup error") } } } @@ -211,11 +195,20 @@ class MaintenanceErrorOnCertainPartitionsProvider extends HDFSBackedStateStorePr storeConfs, hadoopConf, useMultipleValuesPerKey) } - override def doMaintenance(): Unit = { + private def maybeThrow(): Unit = { if (id.partitionId == 0 || id.partitionId == 1) { throw new RuntimeException("Intentional maintenance failure") } - super.doMaintenance() + } + + override def doSnapshotMaintenance(): Unit = { + maybeThrow() + super.doSnapshotMaintenance() + } + + override def doCleanupMaintenance(): Unit = { + maybeThrow() + super.doCleanupMaintenance() } } @@ -270,17 +263,24 @@ private object FakeStateStoreProviderWithMaintenanceError { class MaintenanceCountingStateStoreProvider extends HDFSBackedStateStoreProvider { import MaintenanceCountingStateStoreProvider._ - override def doMaintenance(): Unit = { - maintenanceCallCount.incrementAndGet() - super.doMaintenance() + override def doSnapshotMaintenance(): Unit = { + snapshotMaintenanceCallCount.incrementAndGet() + super.doSnapshotMaintenance() + } + + override def doCleanupMaintenance(): Unit = { + cleanupMaintenanceCallCount.incrementAndGet() + super.doCleanupMaintenance() } } private object MaintenanceCountingStateStoreProvider { - val maintenanceCallCount = new java.util.concurrent.atomic.AtomicInteger(0) + val snapshotMaintenanceCallCount = new java.util.concurrent.atomic.AtomicInteger(0) + val cleanupMaintenanceCallCount = new java.util.concurrent.atomic.AtomicInteger(0) def reset(): Unit = { - maintenanceCallCount.set(0) + snapshotMaintenanceCallCount.set(0) + cleanupMaintenanceCallCount.set(0) } } @@ -302,258 +302,476 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] require(!StateStore.isMaintenanceRunning) } - test("SPARK-51596: submitMaintenanceWorkForProvider from task thread adds" + - " to queue when timeout occurs") { - // Reset tracking variables for a clean test - SignalingStateStoreProvider.reset() - - val sqlConf = getDefaultSQLConf( - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, - SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get - ) + private def getUnloadQueue(): ConcurrentLinkedQueue[ + (StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)] = { + val f = PrivateMethod[ConcurrentLinkedQueue[ + (StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)]]( + Symbol("unloadedProvidersToClose")) + StateStore invokePrivate f() + } - // Critical: Set a very short timeout to ensure awaitProcessThisPartition fails quickly - sqlConf.setConf(SQLConf.STATE_STORE_MAINTENANCE_PROCESSING_TIMEOUT, 1L) // 1 second + private def getSnapshotPartitions(): mutable.HashSet[StateStoreProviderId] = { + val f = PrivateMethod[mutable.HashSet[StateStoreProviderId]](Symbol("snapshotPartitions")) + StateStore invokePrivate f() + } - // Maintenance interval large enough that we control timing manually - sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 30000L) - sqlConf.setConf(SQLConf.NUM_STATE_STORE_MAINTENANCE_THREADS, 4) + private def getCleanupPartitions(): mutable.HashSet[StateStoreProviderId] = { + val f = PrivateMethod[mutable.HashSet[StateStoreProviderId]](Symbol("cleanupPartitions")) + StateStore invokePrivate f() + } - // Use our test provider - sqlConf.setConf( - SQLConf.STATE_STORE_PROVIDER_CLASS, - classOf[SignalingStateStoreProvider].getName - ) + private def getLoadedProviders(): mutable.HashMap[StateStoreProviderId, StateStoreProvider] = { + val f = PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]]( + Symbol("loadedProviders")) + StateStore invokePrivate f() + } - val conf = new SparkConf().setMaster("local").setAppName("test") + private def getBlockingProvider(id: StateStoreProviderId): BlockingMaintenanceProvider = { + val loaded = getLoadedProviders() + loaded.synchronized { loaded.get(id).get }.asInstanceOf[BlockingMaintenanceProvider] + } - withSpark(SparkContext.getOrCreate(conf)) { sc => - withCoordinatorRef(sc) { _ => - val rootLocation = s"${Utils.createTempDir().getAbsolutePath}/spark-51596-timeout-queue" - val providerId = StateStoreProviderId(StateStoreId(rootLocation, 0, 0), UUID.randomUUID) + private def maintenanceStoreConf( + providerClass: Class[_], + interval: Long = 100L, + numThreads: Int = 4): StateStoreConf = { + val sqlConf = getDefaultSQLConf( + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get) + sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, interval) + sqlConf.setConf(SQLConf.NUM_STATE_STORE_MAINTENANCE_THREADS, numThreads) + sqlConf.setConf(SQLConf.STATE_STORE_PROVIDER_CLASS, providerClass.getName) + new StateStoreConf(sqlConf) + } - // Load the provider to start the maintenance system - StateStore.get( - providerId, - keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 0, None, None, useColumnFamilies = false, - new StateStoreConf(sqlConf), new Configuration() - ) + private def loadNullProvider( + dir: String, + storeConf: StateStoreConf, + partition: Int = 0): StateStoreProviderId = { + val storeId = StateStoreProviderId(StateStoreId(dir, 0, partition), UUID.randomUUID) + StateStore.get( + storeId, null, null, NoPrefixKeyStateEncoderSpec(null), + 0, None, None, useColumnFamilies = false, storeConf, new Configuration()) + storeId + } - // Access the queue via reflection for verification - val queueField = PrivateMethod[ConcurrentLinkedQueue[ - (StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)]]( - Symbol("unloadedProvidersToClose")) - val queue = StateStore invokePrivate queueField() + test("SPARK-51596: task thread unload lifecycle from queue to close") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + // TODO: set to a longer interval once triggerNow is added so the test verifies close + // happens via triggerNow, not the periodic tick. + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider], interval = 5000L) + val id1 = loadNullProvider("lifecycle", storeConf) + val bp = getBlockingProvider(id1) + + val queue = getUnloadQueue() + assert(StateStore.isLoaded(id1)) assert(queue.isEmpty, "Queue should start empty") - // Manually trigger maintenance which will block - val maintenanceMethod = PrivateMethod[Unit](Symbol("doMaintenance")) - StateStore invokePrivate maintenanceMethod() - - // Wait for maintenance to start - eventually(timeout(5.seconds)) { - assert(SignalingStateStoreProvider.maintenanceStarted) - assert(StateStore.isLoaded(providerId)) + // Make stale and load another provider to trigger task thread queueing. Use a + // non-blocking provider for id2 since we don't need to observe its maintenance. + coordinatorRef.reportActiveInstance(id1, "otherhost", "otherexec", Seq.empty) + val storeConf2 = maintenanceStoreConf(classOf[FakeStateStoreProviderTracksCloseThread]) + val id2 = loadNullProvider("lifecycle", storeConf2, partition = 1) + + assert(!StateStore.isLoaded(id1), "Provider1 should be removed") + assert(StateStore.isLoaded(id2), "Provider2 should still be loaded") + + // Verify task thread queued with All. + assert(!queue.isEmpty) + val (qId, _, opReq) = queue.peek() + assert(qId == id1) + assert(opReq == MaintenanceOpRequest.All, s"Expected All, got $opReq") + + // Step 2: Scheduler submits first op as FromTaskThread. Snapshot enters latch. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS) && + bp.snapshotEnteredLatch.getCount == 0, "snapshot should have started") + + // Step 3: Release snapshot. Post-work queues remaining op (Cleanup). + bp.snapshotContinueSignal.countDown() + + // Step 4: Scheduler picks up Cleanup as FromUnloadedProvidersQueue. cleanupEnteredLatch + // being counted down proves Cleanup was queued and submitted. + assert(bp.cleanupEnteredLatch.await(10, TimeUnit.SECONDS) && + bp.cleanupEnteredLatch.getCount == 0, "cleanup should have started") + + // Snapshot's post-work queued Cleanup and scheduler drained it. Cleanup is now running, + // queue is empty. + assert(queue.isEmpty, "queue should be drained while cleanup runs") + + // Step 5: Release cleanup. FromUnloadedProvidersQueue calls closeProvider. + bp.cleanupContinueSignal.countDown() + + // Verify provider was closed on the maintenance pool. + eventually(timeout(10.seconds)) { + assert(bp.closeThreadName.contains("state-store-maintenance-thread"), + "close should happen on maintenance thread, but was on: " + bp.closeThreadName) + assert(!StateStore.isLoaded(id1), "provider should be removed from loadedProviders") + assert(queue.isEmpty, "Queue should be drained") } + } + } + } - // Now get access to the provider to simulate a task thread - val loadedProvidersField = PrivateMethod[ - mutable.HashMap[StateStoreProviderId, StateStoreProvider]]( - Symbol("loadedProviders")) - val loadedProviders = StateStore invokePrivate loadedProvidersField() - val provider = loadedProviders.synchronized { loadedProviders.get(providerId).get } - val maintenancePartitionsField = PrivateMethod[ - mutable.HashSet[StateStoreProviderId]]( - Symbol("maintenancePartitions")) - val maintenancePartitions = StateStore invokePrivate maintenancePartitionsField() - - // Create a task thread that will attempt to submit maintenance - val taskThread = new Thread(() => { - try { - // Call submitMaintenanceWorkForProvider directly since that's what we're testing - val submitMaintenanceMethod = PrivateMethod[Unit]( - Symbol("submitMaintenanceWorkForProvider")) - StateStore invokePrivate submitMaintenanceMethod( - providerId, provider, new StateStoreConf(sqlConf), - MaintenanceTaskType.FromTaskThread) - - SignalingStateStoreProvider.taskSubmittedMaintenance = true - SignalingStateStoreProvider.taskAttemptCompletedLatch.countDown() - } catch { - case e: Exception => - logError(s"Error in task thread: ${e.getMessage}", e) - } - }) + test("tryClaimPartition returns true first call, false second, true for different opType") { + val id = StateStoreProviderId(StateStoreId("dir", 0, 0), UUID.randomUUID) + val id2 = StateStoreProviderId(StateStoreId("dir", 0, 1), UUID.randomUUID) - // Start the task thread - it should timeout and add provider to queue - taskThread.start() + try { + // First claim for snapshot succeeds + assert(StateStore.tryClaimPartition(id, MaintenanceOpType.Snapshot)) + // Second claim for same id + opType fails + assert(!StateStore.tryClaimPartition(id, MaintenanceOpType.Snapshot)) + // Claim for same id but different opType succeeds + assert(StateStore.tryClaimPartition(id, MaintenanceOpType.Cleanup)) + // That one is also occupied now + assert(!StateStore.tryClaimPartition(id, MaintenanceOpType.Cleanup)) + + // Different id can still claim both + assert(StateStore.tryClaimPartition(id2, MaintenanceOpType.Snapshot)) + assert(StateStore.tryClaimPartition(id2, MaintenanceOpType.Cleanup)) + } finally { + getSnapshotPartitions().clear() + getCleanupPartitions().clear() + } + } - // Wait for task attempt to complete - assert(SignalingStateStoreProvider - .taskAttemptCompletedLatch.await(10, TimeUnit.SECONDS), - "Task thread didn't complete") + test("otherMaintenanceOpRequest maps correctly") { + assert(StateStore.otherMaintenanceOpRequest(MaintenanceOpType.Snapshot) + === MaintenanceOpRequest.Cleanup) + assert(StateStore.otherMaintenanceOpRequest(MaintenanceOpType.Cleanup) + === MaintenanceOpRequest.Snapshot) + } - // Critical verification: After timeout, the provider should be in the queue - eventually(timeout(5.seconds)) { - assert(queue.size() == 1, "Provider should be queued after timeout") - } - // TODO: Assert opRequest value once decoupled maintenance is enabled. - val (queuedId, _, _) = queue.peek() - assert(queuedId == providerId, "Queued provider has wrong ID") + test("closeProvider sets unloaded even if close() throws") { + val storeId = StateStoreProviderId(StateStoreId("closeTest", 0, 0), UUID.randomUUID) + val callOrder = new mutable.ArrayBuffer[String]() + val provider = new FakeStateStoreProviderTracksCloseThread { + override def close(): Unit = { + callOrder += "close" + throw new RuntimeException("close failed") + } + override def setUnloaded(): Unit = { + callOrder += "setUnloaded" + super.setUnloaded() + } + } + provider.init( + storeId.storeId, null, null, NoPrefixKeyStateEncoderSpec(null), + useColumnFamilies = false, null, null) - // Now allow the first maintenance to complete - SignalingStateStoreProvider.continueSignal.countDown() + assert(!provider.unloaded) + intercept[RuntimeException] { + StateStore.closeProvider(storeId, provider) + } + assert(provider.unloaded, "setUnloaded should run even if close() throws") + assert(callOrder === Seq("close", "setUnloaded"), + "setUnloaded should run after close() even if it throws") + } - eventually(timeout(5.seconds)) { - assert(maintenancePartitions.isEmpty, - "Maintenance partitions should be removed from") + test("concurrent snapshot and cleanup on same provider both succeed") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) + val storeId = loadNullProvider("concurrentDir", storeConf) + val bp = getBlockingProvider(storeId) + + // Wait for both snapshot and cleanup to enter. + assert(bp.snapshotEnteredLatch.await(30, TimeUnit.SECONDS), "snapshot should have started") + assert(bp.cleanupEnteredLatch.await(30, TimeUnit.SECONDS), "cleanup should have started") + + // Both run on maintenance pool threads, on different threads. + val prefix = "state-store-maintenance-thread" + assert(bp.snapshotThreadName.startsWith(prefix)) + assert(bp.cleanupThreadName.startsWith(prefix)) + assert(bp.snapshotThreadName != bp.cleanupThreadName, + "snapshot and cleanup should run on different threads") + + // Partition sets should be claimed while both are running. + assert(!StateStore.tryClaimPartition(storeId, MaintenanceOpType.Snapshot), + "snapshot partition set should be occupied") + assert(!StateStore.tryClaimPartition(storeId, MaintenanceOpType.Cleanup), + "cleanup partition set should be occupied") + + // Release both to finish. + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + + // Verify all ops completed by checking partition sets are released. Read the sets via + // reflection (tryClaimPartition has side effects that break eventually retries). + eventually(timeout(10.seconds)) { + assert(!getSnapshotPartitions().contains(storeId), + "snapshot partition set should be released") + assert(!getCleanupPartitions().contains(storeId), + "cleanup partition set should be released") } - // Manually trigger another maintenance to process the queue - StateStore invokePrivate maintenanceMethod() + } + } + } - // Verify the queue eventually gets processed - eventually(timeout(5.seconds)) { - assert(queue.isEmpty, "Queue should be emptied after maintenance") + test("partition set released when maintenance throws") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) + val storeId = loadNullProvider("errorDir", storeConf) + val bp = getBlockingProvider(storeId) + bp.snapshotShouldThrow = true + + // Wait for snapshot to enter. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) + // Partition set is claimed while running. + assert(!StateStore.tryClaimPartition(storeId, MaintenanceOpType.Snapshot), + "snapshot set should be occupied") + + // Release snapshot. It will throw. Release cleanup to avoid blocking the pool. + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + + // isLoaded returning false proves the exception was thrown (only the error handler + // unloads), and the finally block must release the partition set. + eventually(timeout(10.seconds)) { + assert(!StateStore.isLoaded(storeId), "provider should be unloaded after throw") + assert(bp.closeThreadName.nonEmpty, "provider should be closed after throw") + assert(!getSnapshotPartitions().contains(storeId), + "snapshot set should be released by finally block after throw") } } } } - test("SPARK-51596: queued maintenance tasks get processed when lock is available") { - // Reset tracking variables for a clean test - SignalingStateStoreProvider.reset() - - val sqlConf = getDefaultSQLConf( - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, - SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get - ) - // Use a maintenance interval large enough that we control timing explicitly - sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 30000L) - // Set our special provider class that lets us control maintenance timing - sqlConf.setConf( - SQLConf.STATE_STORE_PROVIDER_CLASS, - classOf[SignalingStateStoreProvider].getName - ) - - val conf = new SparkConf().setMaster("local").setAppName("test") + private def testRequeue(opRequest: MaintenanceOpRequest): Unit = { + val logAppender = new LogAppender("requeue-log", maxEvents = 100) + logAppender.setThreshold(Level.INFO) + withLogAppender(logAppender, level = Some(Level.INFO)) { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) + val storeId = loadNullProvider("requeueDir", storeConf) + val bp = getBlockingProvider(storeId) + + // Wait for both ops to enter, occupying both partition sets. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) + assert(bp.cleanupEnteredLatch.await(10, TimeUnit.SECONDS)) + + // Add an entry to the queue. The scheduler tries to drain it but the partition set is + // occupied (held by the blocked task above), so it should be requeued. + val queue = getUnloadQueue() + queue.add((storeId, bp, opRequest)) + + eventually(timeout(10.seconds)) { + assert(logAppender.loggingEvents.exists( + _.getMessage.getFormattedMessage.contains("Had to requeue")), + s"scheduler should have logged requeue for $opRequest") + } - withSpark(SparkContext.getOrCreate(conf)) { sc => - withCoordinatorRef(sc) { coordinatorRef => - val rootLocation = s"${Utils.createTempDir().getAbsolutePath}/spark-51596-queue" + // Queue should still have the entry. + assert(queue.size() == 1, s"$opRequest entry should have been requeued") + val (requeuedId, _, requeuedOp) = queue.peek() + assert(requeuedId == storeId) + assert(requeuedOp == opRequest) - // Create two providers that we'll use for the test - val provider1Id = - StateStoreProviderId(StateStoreId(rootLocation, 0, 0), UUID.randomUUID) - val provider2Id = - StateStoreProviderId(StateStoreId(rootLocation, 0, 1), UUID.randomUUID) + // Clean up: release latches. + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + queue.clear() + } + } + } + } - // Get the first provider to load it - StateStore.get( - provider1Id, - keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 0, None, None, useColumnFamilies = false, - new StateStoreConf(sqlConf), new Configuration() - ) + test("Snapshot entry requeues when snapshot partition set is occupied") { + testRequeue(MaintenanceOpRequest.Snapshot) + } - // Manually trigger maintenance for provider1, which will block in doMaintenance() - val maintenanceMethod = PrivateMethod[Unit](Symbol("doMaintenance")) - StateStore invokePrivate maintenanceMethod() + test("Cleanup entry requeues when cleanup partition set is occupied") { + testRequeue(MaintenanceOpRequest.Cleanup) + } - // Wait for maintenance to start before continuing - eventually(timeout(5.seconds)) { - assert(SignalingStateStoreProvider.maintenanceStarted) - assert(StateStore.isLoaded(provider1Id)) - } + test("All entry requeues when both partition sets are occupied") { + testRequeue(MaintenanceOpRequest.All) + } - // Now make the first provider "stale" by reporting it active on another executor - coordinatorRef.reportActiveInstance(provider1Id, "otherhost", "otherexec", Seq.empty) + test("When MaintenanceOpRequest is All, cleanup is submitted if snapshot partition set " + + "is occupied") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + // Long interval so we can set up before the first cycle fires (5s initial delay). + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider], interval = 5000L) + val id = loadNullProvider("shortCircuit", storeConf) + val bp = getBlockingProvider(id) - // Get provider2 which will cause a maintenance task for provider1 to be queued - // (since provider1 is already under maintenance and can't be processed immediately) - StateStore.get( - provider2Id, - keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 0, None, None, useColumnFamilies = false, - new StateStoreConf(sqlConf), new Configuration() - ) + try { + // Claim snapshot partition set. + assert(StateStore.tryClaimPartition(id, MaintenanceOpType.Snapshot)) - // Mark that task submitted maintenance - SignalingStateStoreProvider.taskSubmittedMaintenance = true + // Remove from loadedProviders so the scheduler doesn't submit cleanup for this + // provider by iterating through it. Only the queue entry should submit cleanup. + getLoadedProviders().synchronized { getLoadedProviders().remove(id) } - // Unblock the first maintenance operation - SignalingStateStoreProvider.continueSignal.countDown() + // Add All entry. The scheduler's first cycle (at 5s) drains it, tries snapshot + // (occupied), falls through to cleanup. Cleanup blocks on bp's latch, keeping the set. + getUnloadQueue().add((id, bp, MaintenanceOpRequest.All)) - // Verify that provider1 is eventually unloaded by the maintenance thread - // after the first maintenance completes and the queued maintenance runs - eventually(timeout(5.seconds)) { - // Provider1 should be unloaded - assert(!StateStore.isLoaded(provider1Id)) - // Provider2 should still be loaded - assert(StateStore.isLoaded(provider2Id)) - // Close should have been called on a maintenance thread - assert(SignalingStateStoreProvider.closeThreadName.contains("maintenance")) + eventually(timeout(10.seconds)) { + assert(getCleanupPartitions().contains(id), "cleanup should be claimed") + assert(getUnloadQueue().isEmpty, "queue should be drained") + } + } finally { + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + getSnapshotPartitions().remove(id) } - - // Get the partitionsForMaintenance field to check the queue is empty - val partitionsField = PrivateMethod[ - ConcurrentLinkedQueue[(StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)]]( - Symbol("unloadedProvidersToClose")) - val queue = StateStore invokePrivate partitionsField() - assert(queue.isEmpty, "Maintenance queue should be empty after processing queued tasks") } } } - test("SPARK-51596: unloading only occurs on maintenance thread but occurs promptly") { - // Reset closeThreadNames - FakeStateStoreProviderTracksCloseThread.closeThreadNames = Nil + test("canProcess is false and maintenance is skipped when provider has already been unloaded") { + val logAppender = new LogAppender("canProcess-log", maxEvents = 100) + logAppender.setThreshold(Level.INFO) + withLogAppender(logAppender, level = Some(Level.INFO)) { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + val storeConf = + maintenanceStoreConf(classOf[BlockingMaintenanceProvider], interval = 5000L) + val id = loadNullProvider("canProcess", storeConf) + val bp = getBlockingProvider(id) + + // Mark unloaded before the first maintenance cycle fires (5s initial delay). + // canProcess checks !provider.unloaded and will return false, skipping maintenance. + bp.setUnloaded() + + // Wait for the "Skipping maintenance" log proving canProcess was false. + eventually(timeout(10.seconds)) { + assert(logAppender.loggingEvents.exists( + _.getMessage.getFormattedMessage.contains("Skipping maintenance")), + "should log skipping maintenance for unloaded provider") + } + assert(bp.snapshotEnteredLatch.getCount == 1, "snapshot should not have entered") + assert(bp.cleanupEnteredLatch.getCount == 1, "cleanup should not have entered") + } + } + } + } - val sqlConf = getDefaultSQLConf( - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, - SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get - ) - // Make maintenance interval very large (30s) so that task thread runs before maintenance. - sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 30000L) - // Use the `FakeStateStoreProviderTracksCloseThread` to run the test - sqlConf.setConf( - SQLConf.STATE_STORE_PROVIDER_CLASS, - classOf[FakeStateStoreProviderTracksCloseThread].getName - ) + test("canProcess is false and maintenance is skipped when provider instance differs") { + val logAppender = new LogAppender("canProcess-stale-log", maxEvents = 100) + logAppender.setThreshold(Level.INFO) + withLogAppender(logAppender, level = Some(Level.INFO)) { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + // numThreads=1: snapshot gets the only pool thread. Cleanup's task waits in the pool's + // work queue until the thread is free. + val storeConf = + maintenanceStoreConf(classOf[BlockingMaintenanceProvider], numThreads = 1) + val id = loadNullProvider("canProcessStale", storeConf) + val bp = getBlockingProvider(id) + + // Wait for snapshot to enter (occupies the only pool thread). Cleanup's task waits in + // the pool's work queue. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) + + // Replace A with a different instance while cleanup is waiting. When cleanup starts, + // canProcess sees contains(A) is false. + val replacement = new FakeStateStoreProviderTracksCloseThread + replacement.init(id.storeId, null, null, NoPrefixKeyStateEncoderSpec(null), + useColumnFamilies = false, null, null) + val loaded = getLoadedProviders() + loaded.synchronized { loaded.put(id, replacement) } + + // Release snapshot. Thread frees, cleanup starts, canProcess fails (instance differs), + // maintenance skipped. + bp.snapshotContinueSignal.countDown() + + eventually(timeout(10.seconds)) { + assert(logAppender.loggingEvents.exists( + _.getMessage.getFormattedMessage.contains("Skipping maintenance")), + "should log skipping maintenance for stale instance") + } + assert(bp.cleanupEnteredLatch.getCount == 1, "cleanup should not have entered") + } + } + } + } + test("FromLoadedProviders unload: reloaded provider is not removed nor queued") { val conf = new SparkConf().setMaster("local").setAppName("test") - withSpark(SparkContext.getOrCreate(conf)) { sc => withCoordinatorRef(sc) { coordinatorRef => - val rootLocation = s"${Utils.createTempDir().getAbsolutePath}/spark-51596" - val providerId = - StateStoreProviderId(StateStoreId(rootLocation, 0, 0), UUID.randomUUID) - val providerId2 = - StateStoreProviderId(StateStoreId(rootLocation, 0, 1), UUID.randomUUID) - - // Create provider to start the maintenance task + pool - StateStore.get( - providerId, - keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 0, None, None, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration() - ) - - // Report instance active on another executor - coordinatorRef.reportActiveInstance(providerId, "otherhost", "otherexec", Seq.empty) + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) + val id = loadNullProvider("staleInstance", storeConf) + val bp = getBlockingProvider(id) + + // Mark as needing to be closed. + coordinatorRef.reportActiveInstance(id, "otherhost", "otherexec", Seq.empty) + + // Wait for snapshot to enter. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) + + // Stop only the scheduler (not the pools) so A's threads keep running but no new cycles + // fire after we replace. + val taskField = PrivateMethod[StateStore.MaintenanceTask](Symbol("maintenanceTask")) + (StateStore invokePrivate taskField()).stop() + + // Replace provider A with a different instance while A is blocked. When A finishes, + // loadedProviders.get(id).contains(A) is false (replacement is there), removal skipped. + val replacement = new FakeStateStoreProviderTracksCloseThread + replacement.init(id.storeId, null, null, + NoPrefixKeyStateEncoderSpec(null), useColumnFamilies = false, null, null) + val loaded = getLoadedProviders() + loaded.synchronized { loaded.put(id, replacement) } + + // Release A. + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + + // Wait for A's partition sets to be released. + eventually(timeout(10.seconds)) { + assert(!getSnapshotPartitions().contains(id), "snapshot partition should be released") + } - // Load another provider to trigger task unload - StateStore.get( - providerId2, - keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 0, None, None, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration() - ) + // A should NOT have removed the replacement from loadedProviders. + assert(StateStore.isLoaded(id), "replacement provider should still be loaded") + // A should NOT have queued anything (instance differs, skip). + assert(getUnloadQueue().isEmpty, "queue should be empty (stale instance skipped removal)") + } + } + } - // Wait for close to occur. Timeout is less than maintenance interval, - // so should only close by task triggering. - eventually(timeout(5.seconds)) { - assert(FakeStateStoreProviderTracksCloseThread.closeThreadNames.size == 1) - FakeStateStoreProviderTracksCloseThread.closeThreadNames.foreach { name => - assert(name.contains("state-store-maintenance-thread"))} + test("FromLoadedProviders unload: with concurrent ops, only one removes and queues") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) + val id = loadNullProvider("concurrentUnload", storeConf) + val bp = getBlockingProvider(id) + + // Make provider stale so both ops detect inactive in source handling. + coordinatorRef.reportActiveInstance(id, "otherhost", "otherexec", Seq.empty) + + // Wait for both ops to enter. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) + assert(bp.cleanupEnteredLatch.await(10, TimeUnit.SECONDS)) + + // Stop the scheduler so no new cycles drain the queue before we can check its size. + val taskField = PrivateMethod[StateStore.MaintenanceTask](Symbol("maintenanceTask")) + (StateStore invokePrivate taskField()).stop() + + // Release both. Both finish, both see !verifyIfStoreInstanceActive. Only one should + // remove from loadedProviders and queue. + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + + val queue = getUnloadQueue() + eventually(timeout(10.seconds)) { + assert(!StateStore.isLoaded(id), "provider should be removed") + assert(queue.size() == 1, "only one op should queue, the other should no-op") } } } @@ -1108,27 +1326,30 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(StateStore.isLoaded(storeProviderId1), "Store is not loaded") } - // Record the current maintenance call count before deactivation - val maintenanceCountBeforeDeactivate = - MaintenanceCountingStateStoreProvider.maintenanceCallCount.get() + // Record the current maintenance call counts before deactivation + val snapshotCountBefore = + MaintenanceCountingStateStoreProvider.snapshotMaintenanceCallCount.get() + val cleanupCountBefore = + MaintenanceCountingStateStoreProvider.cleanupMaintenanceCallCount.get() - // Deactivate the store instance - this should trigger maintenance before unload + // Deactivate the store instance - this should trigger maintenance before close. In the + // decoupled design, the provider is removed from loadedProviders before close completes + // (removal and close are separate events), so we wait for all conditions together. coordinatorRef.deactivateInstances(storeProviderId1.queryRunId) - // Wait for the store to be unloaded eventually(timeout(timeoutDuration)) { assert(!StateStore.isLoaded(storeProviderId1), "Store was not unloaded") + val snapshotCountAfter = + MaintenanceCountingStateStoreProvider.snapshotMaintenanceCallCount.get() + val cleanupCountAfter = + MaintenanceCountingStateStoreProvider.cleanupMaintenanceCallCount.get() + assert(snapshotCountAfter > snapshotCountBefore, + s"Snapshot maintenance should run before close. " + + s"Before: $snapshotCountBefore, After: $snapshotCountAfter") + assert(cleanupCountAfter > cleanupCountBefore, + s"Cleanup maintenance should run before close. " + + s"Before: $cleanupCountBefore, After: $cleanupCountAfter") } - - // Get the maintenance count after unload - val maintenanceCountAfterUnload = - MaintenanceCountingStateStoreProvider.maintenanceCallCount.get() - - // Ensure that maintenance was called at least one more time during unload - assert(maintenanceCountAfterUnload > maintenanceCountBeforeDeactivate, - s"Maintenance should be called before unload. " + - s"Before: $maintenanceCountBeforeDeactivate, " + - s"After: $maintenanceCountAfterUnload") } } } From 98b6566edbaff5b1c50ebf1b5855373edb1bd861 Mon Sep 17 00:00:00 2001 From: liviazhu Date: Thu, 25 Jun 2026 22:59:08 +0000 Subject: [PATCH 3/5] [SS] Add RW lock, dual maintenance pools, and scheduler trigger for state store maintenance Builds on the decoupled maintenance scheduler: - Add a per-provider fair ReentrantReadWriteLock (maintenanceLock). Maintenance ops hold the read lock; close holds the write lock, so close waits for in-flight maintenance and a maintenance op skips (tryLock with zero timeout) when a close is in progress. - Split the single maintenance thread pool into separate snapshot (high-priority) and cleanup (low-priority) pools, sized via snapshotToCleanupThreadRatio (getPoolSizes), so one operation type cannot starve the other. numStateStoreMaintenanceThreads is now a minimum of 2. - Add MaintenanceTask.triggerNow (at-most-one pending, processUnloadedOnly) and have the query thread trigger an immediate scheduler cycle when it queues a provider for unload, instead of waiting for the next periodic tick. Move the decoupled-maintenance tests out of StateStoreSuite into a new StateStoreDecoupledMaintenanceSuite and add tests for the RW lock, dual pools, and triggerNow. Ported from databricks-eng/runtime#201749 (SC-225785). Co-authored-by: Isaac --- .../apache/spark/sql/internal/SQLConf.scala | 29 +- .../streaming/state/StateStore.scala | 409 ++++-- .../streaming/state/StateStoreConf.scala | 9 +- .../StateStoreDecoupledMaintenanceSuite.scala | 1248 +++++++++++++++++ .../streaming/state/StateStoreSuite.scala | 594 -------- 5 files changed, 1560 insertions(+), 729 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreDecoupledMaintenanceSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ff2dd2dbd4833..a352842b65478 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2900,13 +2900,27 @@ object SQLConf { val NUM_STATE_STORE_MAINTENANCE_THREADS = buildConf("spark.sql.streaming.stateStore.numStateStoreMaintenanceThreads") .internal() - .doc("Number of threads in the thread pool that perform clean up and snapshotting tasks " + - "for stateful streaming queries. The default value is the number of cores * 0.25 " + - "so that this thread pool doesn't take too many resources " + - "away from the query and affect performance.") + .doc("Total number of threads split between the snapshot and cleanup " + + "maintenance pools for stateful streaming queries. Each pool needs at least " + + "1 thread, so the minimum is 2. The default value is the number of " + + "cores * 0.25 so that the pools don't take too many resources away from the " + + "query and affect performance. Use snapshotToCleanupThreadRatio to " + + "configure the split between snapshot and cleanup pools.") .intConf - .checkValue(_ > 0, "Must be greater than 0") - .createWithDefault(Math.max(Runtime.getRuntime.availableProcessors() / 4, 1)) + .checkValue(_ > 1, "Must be greater than 1") + .createWithDefault(Math.max(Runtime.getRuntime.availableProcessors() / 4, 2)) + + val STATE_STORE_MAINTENANCE_SNAPSHOT_THREAD_RATIO = + buildConf("spark.sql.streaming.stateStore.snapshotToCleanupThreadRatio") + .internal() + .doc("Ratio of total maintenance threads allocated to the snapshot " + + "pool. The remainder goes to the cleanup pool. The snapshot " + + "count is rounded to the nearest integer and clamped so each " + + "pool gets at least 1 thread and the total is never exceeded.") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) + .doubleConf + .checkValue(v => v > 0 && v < 1, "Must be between 0 and 1 (exclusive)") + .createWithDefault(0.5) val STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT = buildConf("spark.sql.streaming.stateStore.maintenanceShutdownTimeout") @@ -7788,6 +7802,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def numStateStoreMaintenanceThreads: Int = getConf(NUM_STATE_STORE_MAINTENANCE_THREADS) + def snapshotToCleanupThreadRatio: Double = + getConf(STATE_STORE_MAINTENANCE_SNAPSHOT_THREAD_RATIO) + def numStateStoreInstanceMetricsToReport: Int = getConf(STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 85de1f85a33d6..d3182da00d62d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming.state import java.io.Closeable import java.util.UUID import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledFuture, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.locks.ReentrantReadWriteLock import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -832,6 +834,21 @@ trait StateStoreProvider { // being read on a maintenance thread. @volatile var unloaded: Boolean = false + /** + * Read-write lock for coordinating maintenance and close operations. + * - Read lock: held during snapshot/cleanup work (allows concurrent maintenance ops) + * - Write lock: held during close (waits for all maintenance to finish) + * This prevents close from racing with in-flight maintenance on the same provider. + * + * Lock ordering: maintenanceLock must be acquired before + * loadedProviders.synchronized to avoid ABBA deadlock. + * + * Passing fair=true to ensure fairness across reads and writes, + * so the write lock (close) is not starved by continuous read lock + * acquisitions (maintenance ops). + */ + val maintenanceLock: ReentrantReadWriteLock = new ReentrantReadWriteLock(true) + /** * Initialize the provide with more contextual information from the SQL operator. * This method will be called first after creating an instance of the StateStoreProvider by @@ -1321,30 +1338,68 @@ object StateStore extends Logging { * StateStoreProvider is also unloaded. Any exception that happens in the MaintenanceTask * is indeed exceptional and thus we let it propagate. */ - class MaintenanceTask(periodMs: Long, task: => Unit) { + class MaintenanceTask(periodMs: Long, task: Boolean => Unit) { private val executor = ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task") + private def runTask(processUnloadedOnly: Boolean = false): Unit = { + try { + task(processUnloadedOnly) + } catch { + case NonFatal(e) => + logWarning(s"Error running maintenance task, " + + s"processUnloadedOnly=$processUnloadedOnly", e) + throw e + } + } + private val runnable = new Runnable { - override def run(): Unit = { + override def run(): Unit = runTask() + } + + private val future: ScheduledFuture[_] = executor.scheduleAtFixedRate( + runnable, periodMs, periodMs, TimeUnit.MILLISECONDS) + + private val triggerPending = new AtomicBoolean(false) + + /** + * Submit a maintenance cycle to the scheduler executor. If the scheduler is + * idle, it runs immediately. If a cycle is already running, this queues behind + * it. If a triggered run is already queued, this is a no-op. The AtomicBoolean + * ensures at most one triggered run is pending at a time. The flag resets before + * execution so a new trigger can be queued while this one is running. + * + * @param processUnloadedOnly when true (default), only drains the unload + * queue without iterating loadedProviders. This avoids submitting + * unnecessary maintenance work for all providers. + */ + def triggerNow(processUnloadedOnly: Boolean = true): Unit = { + if (triggerPending.compareAndSet(false, true)) { try { - task + executor.execute(() => { + triggerPending.set(false) + runTask(processUnloadedOnly) + }) } catch { - case NonFatal(e) => - logWarning("Error running maintenance thread", e) - throw e + // Executor already shut down by stop(). Reset the flag for completeness. + case _: java.util.concurrent.RejectedExecutionException => + logWarning("triggerNow called after scheduler maintenance task stopped") + triggerPending.set(false) } } } - private val future: ScheduledFuture[_] = executor.scheduleAtFixedRate( - runnable, periodMs, periodMs, TimeUnit.MILLISECONDS) - def stop(): Unit = { future.cancel(false) executor.shutdown() } + /** Stops the scheduler and waits for any in-flight cycle to finish. */ + def stopAndAwait(): Unit = { + stop() + executor.awaitTermination(10, TimeUnit.SECONDS) + } + def isRunning: Boolean = !future.isDone } @@ -1355,18 +1410,22 @@ object StateStore extends Logging { class MaintenanceThreadPool( numThreads: Int, shutdownTimeout: Long, - forceShutdownTimeout: Long) { - private val threadPool = ThreadUtils.newDaemonFixedThreadPool( - numThreads, "state-store-maintenance-thread") + forceShutdownTimeout: Long, + name: String) { + private val threadPool = ThreadUtils.newDaemonFixedThreadPool(numThreads, name) def execute(runnable: Runnable): Unit = { threadPool.execute(runnable) } - def stop(): Unit = { - logInfo("Shutting down MaintenanceThreadPool") + /** Initiate shutdown without waiting. Call awaitStop() to wait. */ + def shutdown(): Unit = { + logInfo(log"Shutting down MaintenanceThreadPool") threadPool.shutdown() // Disable new tasks from being submitted + } + /** Wait for threads to finish after shutdown() was called. */ + def awaitStop(): Unit = { // Wait a while for existing tasks to terminate if (!threadPool.awaitTermination(shutdownTimeout, TimeUnit.SECONDS)) { logWarning( @@ -1381,13 +1440,21 @@ object StateStore extends Logging { } } } + + def stop(): Unit = { + shutdown() + awaitStop() + } } @GuardedBy("loadedProviders") private var maintenanceTask: MaintenanceTask = null @GuardedBy("loadedProviders") - private var maintenanceThreadPool: MaintenanceThreadPool = null + private var highPriorityThreadPool: MaintenanceThreadPool = null + + @GuardedBy("loadedProviders") + private var lowPriorityThreadPool: MaintenanceThreadPool = null @GuardedBy("loadedProviders") private var _coordRef: StateStoreCoordinatorRef = null @@ -1542,14 +1609,22 @@ object StateStore extends Logging { // Queue the provider for maintenance and close. The maintenance scheduler will // drain the queue and submit tasks. remove() returning non-null ensures only one // queuer for this provider instance. - // TODO: trigger a scheduler cycle immediately so queued providers are processed - // without waiting for the next periodic tick. logInfo(log"Queuing provider from task thread for maintenance and close " + log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}." + taskContextIdLogLine + log"Removed provider from loadedProviders") unloadedProvidersToClose.add((id, provider, MaintenanceOpRequest.All)) }) }) + + // Submit a scheduler cycle so queued providers are processed without waiting for the + // next periodic tick, minimizing the time stale providers wait to be closed. Without + // this, we would wait up to 2 maintenance cycles for both operations to finish and the + // provider to be closed from the time it is queued. At most one triggered cycle can be + // pending at a time. + if (providerStatus.providerIdsToUnload.nonEmpty && maintenanceTask != null) { + maintenanceTask.triggerNow() + } + providerStatus.shouldForceSnapshotUpload } else { false @@ -1570,12 +1645,17 @@ object StateStore extends Logging { } /** - * Close a provider and release its resources. + * Close a provider and release its resources. No-op if already unloaded. * WARNING: CAN ONLY BE CALLED FROM MAINTENANCE THREAD! */ def closeProvider( storeProviderId: StateStoreProviderId, provider: StateStoreProvider): Unit = { + if (provider.unloaded) { + logInfo(log"Skipping close for ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, storeProviderId)}" + + log" because provider is already unloaded") + return + } logInfo(log"Closing ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, storeProviderId)}") try { provider.close() @@ -1623,24 +1703,37 @@ object StateStore extends Logging { * it can work-around a deadlock condition where a maintenance task is waiting for the lock * */ private[streaming] def stopMaintenanceTaskWithoutLock(): Unit = { - if (maintenanceThreadPool != null) { - maintenanceThreadPool.stop() - maintenanceThreadPool = null - snapshotPartitionsLock.synchronized { snapshotPartitions.clear() } - cleanupPartitionsLock.synchronized { cleanupPartitions.clear() } - } + // Stop the scheduler first so no new work is submitted to the pools. if (maintenanceTask != null) { - maintenanceTask.stop() + maintenanceTask.stopAndAwait() maintenanceTask = null } + // Shut down both pools concurrently, then await both, so we don't double the blocking time. + if (highPriorityThreadPool != null) highPriorityThreadPool.shutdown() + if (lowPriorityThreadPool != null) lowPriorityThreadPool.shutdown() + if (highPriorityThreadPool != null) { + highPriorityThreadPool.awaitStop() + highPriorityThreadPool = null + } + if (lowPriorityThreadPool != null) { + lowPriorityThreadPool.awaitStop() + lowPriorityThreadPool = null + } + snapshotPartitionsLock.synchronized { snapshotPartitions.clear() } + cleanupPartitionsLock.synchronized { cleanupPartitions.clear() } } /** Unload and stop all state store providers */ - def stop(): Unit = loadedProviders.synchronized { - loadedProviders.foreach { case (id, provider) => closeProvider(id, provider) } - loadedProviders.clear() - _coordRef = null - stopMaintenanceTask() + def stop(): Unit = { + // Stop scheduler and pools outside loadedProviders lock. Pool threads acquire + // maintenanceLock then loadedProviders.synchronized, so holding loadedProviders while + // awaiting termination would deadlock. + stopMaintenanceTaskWithoutLock() + loadedProviders.synchronized { + loadedProviders.foreach { case (id, provider) => closeProvider(id, provider) } + loadedProviders.clear() + _coordRef = null + } // Drain after stopping the pool to catch anything queued during shutdown. while (!unloadedProvidersToClose.isEmpty) { val (id, provider, _) = unloadedProvidersToClose.poll() @@ -1649,21 +1742,41 @@ object StateStore extends Logging { logInfo("StateStore stopped") } + /** + * Determines the number of threads for the snapshot and cleanup pools using the configured + * ratio. Snapshot gets the rounded value, clamped to [1, total - 1]. Cleanup gets the + * remainder. Each pool gets at least 1 thread and the total is never exceeded. + * @return (snapshotThreads, cleanupThreads) + */ + private[streaming] def getPoolSizes(storeConf: StateStoreConf): (Int, Int) = { + val total = storeConf.numStateStoreMaintenanceThreads + val ratio = storeConf.snapshotToCleanupThreadRatio + val snapshotBeforeClamp = math.round(total * ratio).toInt + // Clamp to [1, total - 1] so each pool gets at least 1 thread and total is never exceeded. + val snapshot = math.max(1, math.min(total - 1, snapshotBeforeClamp)) + val cleanup = total - snapshot + (snapshot, cleanup) + } + /** Start the periodic maintenance task if not already started and if Spark active */ private def startMaintenanceIfNeeded(storeConf: StateStoreConf): Unit = { - val numMaintenanceThreads = storeConf.numStateStoreMaintenanceThreads val maintenanceShutdownTimeout = storeConf.stateStoreMaintenanceShutdownTimeout val maintenanceForceShutdownTimeout = storeConf.stateStoreMaintenanceForceShutdownTimeout loadedProviders.synchronized { if (SparkEnv.get != null && !isMaintenanceRunning && !storeConf.unloadOnCommit) { maintenanceTask = new MaintenanceTask( storeConf.maintenanceInterval, - task = { doMaintenance(storeConf) } + task = { processUnloadedOnly => doMaintenance(storeConf, processUnloadedOnly) } ) - // TODO: split into separate snapshot and cleanup pools so one operation type - // cannot starve the other when pool threads are saturated. - maintenanceThreadPool = new MaintenanceThreadPool(numMaintenanceThreads, - maintenanceShutdownTimeout, maintenanceForceShutdownTimeout) + // Separate pools for snapshot and cleanup to prevent one operation type from starving + // the other when pool threads are saturated. + val (snapshotThreads, cleanupThreads) = getPoolSizes(storeConf) + highPriorityThreadPool = new MaintenanceThreadPool(snapshotThreads, + maintenanceShutdownTimeout, maintenanceForceShutdownTimeout, + "state-store-maintenance-high-priority") + lowPriorityThreadPool = new MaintenanceThreadPool(cleanupThreads, + maintenanceShutdownTimeout, maintenanceForceShutdownTimeout, + "state-store-maintenance-low-priority") logInfo("State Store maintenance task started") } } @@ -1693,7 +1806,9 @@ object StateStore extends Logging { * Execute background maintenance task in all the loaded store providers if they are still * the active instances according to the coordinator. */ - private def doMaintenance(storeConf: StateStoreConf): Unit = { + private def doMaintenance( + storeConf: StateStoreConf, + processUnloadedOnly: Boolean = false): Unit = { logDebug("Doing maintenance") if (SparkEnv.get == null) { throw new IllegalStateException("SparkEnv not active, cannot do maintenance on StateStores") @@ -1717,8 +1832,6 @@ object StateStore extends Logging { // All ops should run before the provider can be closed. We serialize them by // submitting one op now as FromTaskThread; when it completes it enqueues the // remaining op, which runs as FromUnloadedProvidersQueue and closes the provider. - // TODO: trigger the next scheduler cycle immediately after the first op completes - // so the remaining op is picked up without waiting for the next periodic tick. // Pick whichever partition set is available with short-circuit evaluation. tryClaimAndSubmit( providerId, provider, storeConf, @@ -1750,13 +1863,19 @@ object StateStore extends Logging { providersToRequeue.foreach(unloadedProvidersToClose.offer) // Phase 2: Submit separate snapshot and cleanup tasks for loaded providers. - loadedProviders.synchronized { - loadedProviders.toSeq - }.foreach { case (id, provider) => - tryClaimAndSubmit( - id, provider, storeConf, MaintenanceOpType.Snapshot, MaintenanceTaskType.FromLoadedProviders) - tryClaimAndSubmit( - id, provider, storeConf, MaintenanceOpType.Cleanup, MaintenanceTaskType.FromLoadedProviders) + // Skipped when processUnloadedOnly is true to avoid submitting unnecessary work for all + // providers. + if (!processUnloadedOnly) { + loadedProviders.synchronized { + loadedProviders.toSeq + }.foreach { case (id, provider) => + tryClaimAndSubmit( + id, provider, storeConf, + MaintenanceOpType.Snapshot, MaintenanceTaskType.FromLoadedProviders) + tryClaimAndSubmit( + id, provider, storeConf, + MaintenanceOpType.Cleanup, MaintenanceTaskType.FromLoadedProviders) + } } } @@ -1772,6 +1891,7 @@ object StateStore extends Logging { source: MaintenanceTaskType): Boolean = { if (tryClaimPartition(providerId, opType)) { submitMaintenanceWorkForProvider(providerId, provider, storeConf, source, opType) + logDebug(s"Submitted $providerId with source $source for $opType") true } else { logInfo(log"Not processing partition " + @@ -1806,101 +1926,134 @@ object StateStore extends Logging { storeConf: StateStoreConf, source: MaintenanceTaskType, opType: MaintenanceOpType): Unit = { - maintenanceThreadPool.execute(() => { + val pool = opType match { + case MaintenanceOpType.Snapshot => highPriorityThreadPool + case MaintenanceOpType.Cleanup => lowPriorityThreadPool + } + pool.execute(() => { + logDebug(s"Starting $opType maintenance for $id, source=$source") val startTime = System.currentTimeMillis() - // Check that this provider is still valid before doing work. - // TODO: the loadedProviders check is racy (membership could change between the check and - // doing the work). This is resolved by the read/write lock added in a later PR, which - // holds the read lock through the check and the work. - val canProcess = source match { - case FromLoadedProviders => - // Checks that the ID is still in loadedProviders and that the instance matches the - // one we were given. The scheduler submits from a stale copy of loadedProviders, so - // the provider may have been removed and replaced by a new instance under the same key. - loadedProviders.synchronized { loadedProviders.get(id).contains(provider) } && - !provider.unloaded - case _ => - // FromTaskThread / FromUnloadedProvidersQueue: provider already removed from - // loadedProviders, reference passed directly. - !provider.unloaded - } - try { - if (canProcess) { - // Do the maintenance work for this op type. - opType match { - case MaintenanceOpType.Snapshot => provider.doSnapshotMaintenance() - case MaintenanceOpType.Cleanup => provider.doCleanupMaintenance() - } + // We use a var instead of early return because `return` inside a closure (pool.execute) + // throws NonLocalReturnControl in Scala. + var canProcess = false + // If we can't acquire the lock, the write lock is held, which means another thread is + // closing this provider. The entire maintenance task is a no-op in that case, so we skip + // and free the pool thread rather than blocking. The zero timeout honors fair ordering so + // readers do not starve a queued writer. + val lockAcquired = provider.maintenanceLock.readLock().tryLock(0, TimeUnit.SECONDS) + try { + if (lockAcquired) { + canProcess = source match { + case FromLoadedProviders => + // Checks that the ID is still in loadedProviders and that the instance matches + // the one we were given. The scheduler submits from a stale copy of + // loadedProviders, so the provider may have been removed and replaced by a new + // instance under the same key. + loadedProviders.synchronized { loadedProviders.get(id).contains(provider) } && + !provider.unloaded + case _ => + // FromTaskThread / FromUnloadedProvidersQueue: provider already removed, + // reference passed directly, no containsKey needed. + !provider.unloaded + } + if (canProcess) { + // Do the maintenance work for this op type. + opType match { + case MaintenanceOpType.Snapshot => provider.doSnapshotMaintenance() + case MaintenanceOpType.Cleanup => provider.doCleanupMaintenance() + } - // Handle post-work actions based on source. - source match { - case FromLoadedProviders => - // Check if provider should be unloaded - if (!verifyIfStoreInstanceActive(id)) { - // Only remove if the map still holds the same provider instance we were given. - // Between verifyIfStoreInstanceActive and this remove, a concurrent get() may - // have loaded a new provider under the same key; removing by key alone would - // incorrectly remove the new provider. - val removed = loadedProviders.synchronized { - if (loadedProviders.get(id).contains(provider)) { - loadedProviders.remove(id) - } else { - None + // Dispatch based on source. FromLoadedProviders and FromTaskThread run inside the + // read lock so no close can interleave between work and enqueue. + // FromUnloadedProvidersQueue releases the read lock before acquiring the write lock. + source match { + case FromLoadedProviders => + // Check if provider should be unloaded + if (!verifyIfStoreInstanceActive(id)) { + // Only remove if the map still holds the same provider instance we were + // given. Between verifyIfStoreInstanceActive and this remove, a concurrent + // get() may have loaded a new provider under the same key. Removing by key + // alone would incorrectly remove the new provider. + val removed = loadedProviders.synchronized { + if (loadedProviders.get(id).contains(provider)) { + loadedProviders.remove(id) + } else { + None + } + } + if (removed.isDefined) { + unloadedProvidersToClose.add( + (id, provider, otherMaintenanceOpRequest(opType))) + logInfo(log"${MDC(LogKeys.MAINTENANCE_TASK_TYPE, source)}: " + + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)} verified inactive, " + + log"queued for close with " + + log"${MDC(LogKeys.OP_TYPE, otherMaintenanceOpRequest(opType))}") + } else { + logInfo(log"${MDC(LogKeys.MAINTENANCE_TASK_TYPE, source)}: " + + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)} verified inactive " + + log"but provider instance differs, skipping removal") + } } - } - // Queue for the other operation before close. - if (removed.isDefined) { + + case FromTaskThread => + // Provider already removed from loadedProviders by the query thread. Queue for + // the other operation before close. unloadedProvidersToClose.add((id, provider, otherMaintenanceOpRequest(opType))) - } + logInfo(log"${MDC(LogKeys.MAINTENANCE_TASK_TYPE, source)}: queued " + + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)} " + + log"for close with ${MDC(LogKeys.OP_TYPE, otherMaintenanceOpRequest(opType))}") + if (maintenanceTask != null) maintenanceTask.triggerNow() + + case FromUnloadedProvidersQueue => + // Release read lock, then acquire write lock to wait for any in-flight + // maintenance to finish. + provider.maintenanceLock.readLock().unlock() + provider.maintenanceLock.writeLock().lock() + try { + closeProvider(id, provider) + } finally { + // Downgrade: reacquire read lock while holding write lock, then release write + // lock. The outer finally unconditionally releases the read lock. + provider.maintenanceLock.readLock().lock() + provider.maintenanceLock.writeLock().unlock() + } } - - case FromTaskThread => - // Provider already removed from loadedProviders by the query thread. - // Queue for the other operation before close. - unloadedProvidersToClose.add((id, provider, otherMaintenanceOpRequest(opType))) - - case FromUnloadedProvidersQueue => - // This is the final operation before close. - closeProvider(id, provider) + } else { + logInfo(log"Skipping maintenance for " + + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}, " + + log"provider was removed from loadedProviders or already unloaded") + } + } else { + logDebug(s"Skipping $opType maintenance for $id, could not acquire read lock") } - } else { - logInfo(log"Skipping maintenance for " + - log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}, " + - log"provider was removed from loadedProviders") + } finally { + if (lockAcquired) provider.maintenanceLock.readLock().unlock() } } catch { case NonFatal(e) => logWarning(log"Error doing maintenance on provider:" + log" ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}. " + - log"Could not unload state store provider", e) - // When we get a non-fatal exception, we just unload the provider. - // - // By not bubbling the exception to the maintenance task thread or the query execution - // thread, it's possible for a maintenance thread pool task to continue failing on - // the same partition. Additionally, if there is some global issue that will cause - // all maintenance thread pool tasks to fail, then bubbling the exception and - // stopping the pool is faster than waiting for all tasks to see the same exception. - // - // However, we assume that repeated failures on the same partition and global issues - // are rare. The benefit to unloading just the partition with an exception is that - // transient issues on a given provider do not affect any other providers; so, in - // most cases, this should be a more performant solution. - source match { - case FromTaskThread | FromUnloadedProvidersQueue => - closeProvider(id, provider) - - case FromLoadedProviders => - // Only remove if the map still holds the same provider instance. A concurrent - // get() may have loaded a new provider under the same key. - loadedProviders.synchronized { - if (loadedProviders.get(id).contains(provider)) { - loadedProviders.remove(id) - } + log"Closing provider due to error", e) + if (source == FromLoadedProviders) { + // Only remove if the map still holds the same provider instance. A concurrent get() + // may have loaded a new provider under the same key. + loadedProviders.synchronized { + if (loadedProviders.get(id).contains(provider)) { + loadedProviders.remove(id) } - // Always close this provider instance regardless of whether we removed it from - // the map. Maintenance failed, so we must clean up this provider's resources. - closeProvider(id, provider) + } + } + // Acquire write lock before close to wait for any concurrent maintenance on the other + // pool thread to finish. + provider.maintenanceLock.writeLock().lock() + try { + // Always close this provider instance regardless of whether we removed it from the + // map. Maintenance failed, so we must clean up this provider's resources. We cannot + // rely on the queue to close the provider because maintenance may error again. + closeProvider(id, provider) + } finally { + provider.maintenanceLock.writeLock().unlock() } } finally { val duration = System.currentTimeMillis() - startTime diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index e19ac06732fa1..b63ab7596b004 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -31,10 +31,17 @@ class StateStoreConf( def this() = this(new SQLConf) /** - * Size of MaintenanceThreadPool to perform maintenance tasks for StateStore + * Total number of maintenance threads. Split between the snapshot and cleanup thread + * pools. Each pool needs at least 1 thread, so the minimum is 2. */ val numStateStoreMaintenanceThreads: Int = sqlConf.numStateStoreMaintenanceThreads + /** + * Ratio of threads for the snapshot pool. The remainder goes to cleanup. Each pool gets at + * least 1 thread and the total is never exceeded. + */ + val snapshotToCleanupThreadRatio: Double = sqlConf.snapshotToCleanupThreadRatio + /** * Timeout for state store maintenance operations to complete on shutdown */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreDecoupledMaintenanceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreDecoupledMaintenanceSuite.scala new file mode 100644 index 0000000000000..ace25da58945f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreDecoupledMaintenanceSuite.scala @@ -0,0 +1,1248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.UUID +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.mutable + +import org.apache.hadoop.conf.Configuration +import org.apache.logging.log4j.Level +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.LocalSparkContext._ +import org.apache.spark.internal.Logging +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.STATE_STORE_PROVIDER_CLASS +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ +import org.apache.spark.tags.ExtendedSQLTest + +/** + * A fake StateStoreProvider that gives tests deterministic control over + * snapshot and cleanup maintenance timing using a latch-based handshake. + * + * Each operation follows a two-phase pattern: + * 1. The op counts down its enteredLatch, telling the test "I'm running." + * 2. The op blocks on its continueSignal until the test counts it down. + * + * The scheduler runs normally and submits tasks to the pools, but those + * tasks block inside the provider's maintenance methods. This holds the + * pool threads mid-execution, keeping downstream logic (source handling, + * queue routing, close) from running until the test releases the latch. + */ +class BlockingMaintenanceProvider extends StateStoreProvider + with Logging { + private var id: StateStoreId = null + + // Per-instance state. No shared static fields, so stale scheduler + // cycles from a previous test use the old instance's latches (already + // counted down) and finish immediately. No cross-test interference. + @volatile var snapshotThreadName: String = "" + @volatile var cleanupThreadName: String = "" + @volatile var closeThreadName: String = "" + @volatile var snapshotShouldThrow: Boolean = false + @volatile var cleanupShouldThrow: Boolean = false + @volatile var closeShouldBlock: Boolean = false + + val snapshotEnteredLatch = new CountDownLatch(1) + val cleanupEnteredLatch = new CountDownLatch(1) + val snapshotContinueSignal = new CountDownLatch(1) + val cleanupContinueSignal = new CountDownLatch(1) + val closeEnteredLatch = new CountDownLatch(1) + val closeContinueSignal = new CountDownLatch(1) + + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + useColumnFamilies: Boolean, + storeConfs: StateStoreConf, + hadoopConf: Configuration, + useMultipleValuesPerKey: Boolean = false, + stateSchemaProvider: Option[StateSchemaProvider] = None + ): Unit = { + id = stateStoreId + } + + override def stateStoreId: StateStoreId = id + + override def close(): Unit = { + closeThreadName = Thread.currentThread.getName + if (closeShouldBlock) { + closeEnteredLatch.countDown() + closeContinueSignal.await() + } + } + + /** Returns null because tests using this provider do not need a real + * store. They only exercise the maintenance scheduler and close paths. */ + override def getStore( + version: Long, + uniqueId: Option[String], + forceSnapshotOnCommit: Boolean = false, + loadEmpty: Boolean = false): StateStore = null + + /** Signals entry, then blocks until the test releases the continue latch. */ + override def doSnapshotMaintenance(): Unit = { + snapshotThreadName = Thread.currentThread.getName + logInfo(s"Snapshot maintenance entered on" + + s" ${Thread.currentThread.getName}") + snapshotEnteredLatch.countDown() + snapshotContinueSignal.await() + logInfo(s"Snapshot maintenance continuing on" + + s" ${Thread.currentThread.getName}") + if (snapshotShouldThrow) { + throw new RuntimeException("snapshot error") + } + } + + /** Same handshake as doSnapshotMaintenance but for cleanup. */ + override def doCleanupMaintenance(): Unit = { + cleanupThreadName = Thread.currentThread.getName + logInfo(s"Cleanup maintenance entered on" + + s" ${Thread.currentThread.getName}") + cleanupEnteredLatch.countDown() + cleanupContinueSignal.await() + logInfo(s"Cleanup maintenance continuing on" + + s" ${Thread.currentThread.getName}") + if (cleanupShouldThrow) { + throw new RuntimeException("cleanup error") + } + } +} + +class FakeStateStoreProviderTracksCloseThread extends StateStoreProvider { + import FakeStateStoreProviderTracksCloseThread._ + private var id: StateStoreId = null + + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + useColumnFamilies: Boolean, + storeConfs: StateStoreConf, + hadoopConf: Configuration, + useMultipleValuesPerKey: Boolean = false, + stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = { + id = stateStoreId + } + + override def stateStoreId: StateStoreId = id + + override def close(): Unit = { + closeThreadNames = Thread.currentThread.getName :: closeThreadNames + } + + override def getStore( + version: Long, + uniqueId: Option[String], + forceSnapshotOnCommit: Boolean = false, + loadEmpty: Boolean = false): StateStore = null +} + +private object FakeStateStoreProviderTracksCloseThread { + var closeThreadNames: List[String] = Nil +} + +@ExtendedSQLTest +abstract class StateStoreDecoupledMaintenanceSuiteBase[ + ProviderClass <: StateStoreProvider] + extends SparkFunSuite + with BeforeAndAfter + with PrivateMethodTester { + + import StateStoreCoordinatorSuite._ + + before { + StateStore.stop() + require(!StateStore.isMaintenanceRunning) + } + + after { + StateStore.stop() + require(!StateStore.isMaintenanceRunning) + } + + private def getDefaultSQLConf( + minDeltasForSnapshot: Int, + numOfVersToRetainInMemory: Int): SQLConf = { + val sqlConf = new SQLConf() + sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, + minDeltasForSnapshot) + sqlConf.setConf(SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY, + numOfVersToRetainInMemory) + sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) + sqlConf + } + + private def getUnloadQueue() = { + val f = PrivateMethod[ConcurrentLinkedQueue[(StateStoreProviderId, + StateStoreProvider, MaintenanceOpRequest)]]( + Symbol("unloadedProvidersToClose")) + StateStore invokePrivate f() + } + + private def getSnapshotPartitions() = { + val f = PrivateMethod[mutable.HashSet[StateStoreProviderId]]( + Symbol("snapshotPartitions")) + StateStore invokePrivate f() + } + + private def getCleanupPartitions() = { + val f = PrivateMethod[mutable.HashSet[StateStoreProviderId]]( + Symbol("cleanupPartitions")) + StateStore invokePrivate f() + } + + private def getLoadedProviders() = { + val f = PrivateMethod[ + mutable.HashMap[StateStoreProviderId, StateStoreProvider]]( + Symbol("loadedProviders")) + StateStore invokePrivate f() + } + + private def getMaintenanceTask() = { + val f = PrivateMethod[StateStore.MaintenanceTask]( + Symbol("maintenanceTask")) + StateStore invokePrivate f() + } + + private def getBlockingProvider( + id: StateStoreProviderId): BlockingMaintenanceProvider = { + val loaded = getLoadedProviders() + loaded.synchronized { loaded.get(id).get } + .asInstanceOf[BlockingMaintenanceProvider] + } + + private def maintenanceStoreConf( + providerClass: Class[_], + interval: Long = 100L, + numThreads: Int = 4): StateStoreConf = { + val sqlConf = getDefaultSQLConf( + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get) + sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, interval) + sqlConf.setConf(SQLConf.NUM_STATE_STORE_MAINTENANCE_THREADS, numThreads) + sqlConf.setConf(STATE_STORE_PROVIDER_CLASS, providerClass.getName) + new StateStoreConf(sqlConf) + } + + private def loadNullProvider( + dir: String, + storeConf: StateStoreConf, + partition: Int = 0): StateStoreProviderId = { + val storeId = StateStoreProviderId( + StateStoreId(dir, 0, partition), UUID.randomUUID) + StateStore.get(storeId, null, null, NoPrefixKeyStateEncoderSpec(null), 0, + stateStoreCkptId = None, stateSchemaBroadcast = None, + useColumnFamilies = false, storeConf, new Configuration()) + storeId + } + + test("SPARK-51596: task thread unload lifecycle " + + "from queue to close") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + // Long interval so close can only happen via triggerNow, not the + // periodic scheduler tick. Without triggerNow the test fails at the + // snapshot latch because the first cycle never fires within the 10s + // timeout. + val storeConf = maintenanceStoreConf( + classOf[BlockingMaintenanceProvider], interval = 30000L) + val id1 = loadNullProvider("lifecycle", storeConf) + val bp = getBlockingProvider(id1) + + val queue = getUnloadQueue() + assert(StateStore.isLoaded(id1)) + assert(queue.isEmpty, "Queue should start empty") + + // Make stale and load another provider to trigger task thread queueing. + // Use a non-blocking provider for id2 since we don't need to + // observe its maintenance. + coordinatorRef.reportActiveInstance(id1, "otherhost", "otherexec", Seq.empty) + val storeConf2 = maintenanceStoreConf( + classOf[FakeStateStoreProviderTracksCloseThread]) + val id2 = loadNullProvider("lifecycle", storeConf2, partition = 1) + + assert(!StateStore.isLoaded(id1), "Provider1 should be removed") + assert(StateStore.isLoaded(id2), "Provider2 should still be loaded") + + // The task thread queued id1 with All. We can't peek the queue + // here because triggerNow fires immediately after queueing, draining + // it before we can inspect. snapshotEnteredLatch proves the entry + // was consumed and snapshot was submitted. + + // Step 2: Scheduler (via triggerNow) submits first op as + // FromTaskThread. Snapshot enters latch. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS) && + bp.snapshotEnteredLatch.getCount == 0, "snapshot should have started") + + // Step 3: Release snapshot. Post-work queues remaining op (Cleanup) + // via otherMaintenanceOpRequest. + bp.snapshotContinueSignal.countDown() + + // Ideally we would peek the queue here to verify the entry is Cleanup + // (via otherMaintenanceOpRequest). But triggerNow drains it before we + // can peek. Instead, cleanupEnteredLatch being counted down proves + // Cleanup was queued and submitted. If the entry were Snapshot, + // doSnapshotMaintenance would have been called instead. + + // Step 4: Scheduler (via triggerNow) picks up Cleanup as + // FromUnloadedProvidersQueue. Cleanup enters. + assert(bp.cleanupEnteredLatch.await(10, TimeUnit.SECONDS) && + bp.cleanupEnteredLatch.getCount == 0, "cleanup should have started") + + // Verify intermediate queue state: snapshot's post-work queued Cleanup + // and scheduler drained it. Cleanup is now running, queue is empty. + assert(queue.isEmpty, "queue should be drained while cleanup runs") + + // Step 5: Release cleanup. FromUnloadedProvidersQueue calls closeProvider. + bp.cleanupContinueSignal.countDown() + + // Verify provider was closed on cleanup pool. + eventually(timeout(10.seconds)) { + assert(bp.closeThreadName.contains( + "state-store-maintenance-low-priority"), + "close should happen on cleanup pool thread, but was on: " + + bp.closeThreadName) + assert(!StateStore.isLoaded(id1), + "provider should be removed from loadedProviders") + assert(queue.isEmpty, "Queue should be drained") + } + } + } + } + + test("tryClaimPartition returns true first call, false second, " + + "true for different opType") { + val id = StateStoreProviderId( + StateStoreId("dir", 0, 0), UUID.randomUUID) + val id2 = StateStoreProviderId( + StateStoreId("dir", 0, 1), UUID.randomUUID) + + try { + // First claim for snapshot succeeds + assert(StateStore.tryClaimPartition(id, MaintenanceOpType.Snapshot)) + // Second claim for same id + opType fails + assert(!StateStore.tryClaimPartition(id, MaintenanceOpType.Snapshot)) + // Claim for same id but different opType succeeds + assert(StateStore.tryClaimPartition(id, MaintenanceOpType.Cleanup)) + // That one is also occupied now + assert(!StateStore.tryClaimPartition(id, MaintenanceOpType.Cleanup)) + + // Different id can still claim both + assert(StateStore.tryClaimPartition(id2, MaintenanceOpType.Snapshot)) + assert(StateStore.tryClaimPartition(id2, MaintenanceOpType.Cleanup)) + } finally { + getSnapshotPartitions().clear() + getCleanupPartitions().clear() + } + } + + test("otherMaintenanceOpRequest maps correctly") { + assert(StateStore.otherMaintenanceOpRequest(MaintenanceOpType.Snapshot) + === MaintenanceOpRequest.Cleanup) + assert(StateStore.otherMaintenanceOpRequest(MaintenanceOpType.Cleanup) + === MaintenanceOpRequest.Snapshot) + } + + test("closeProvider sets unloaded even if close() throws") { + val storeId = StateStoreProviderId(StateStoreId("closeTest", 0, 0), UUID.randomUUID) + val callOrder = new mutable.ArrayBuffer[String]() + val provider = new FakeStateStoreProviderTracksCloseThread { + override def close(): Unit = { + callOrder += "close" + throw new RuntimeException("close failed") + } + override def setUnloaded(): Unit = { + callOrder += "setUnloaded" + super.setUnloaded() + } + } + provider.init( + storeId.storeId, null, null, NoPrefixKeyStateEncoderSpec(null), + useColumnFamilies = false, null, null) + + assert(!provider.unloaded) + intercept[RuntimeException] { + StateStore.closeProvider(storeId, provider) + } + assert(provider.unloaded, + "setUnloaded should run even if close() throws") + assert(callOrder === Seq("close", "setUnloaded"), + "setUnloaded should run after close() even if it throws") + } + + test("concurrent snapshot and cleanup on same provider " + + "both succeed") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) + val storeId = loadNullProvider("concurrentDir", storeConf) + val bp = getBlockingProvider(storeId) + + // Wait for both snapshot and cleanup to enter + assert(bp.snapshotEnteredLatch + .await(30, TimeUnit.SECONDS), "snapshot should have started") + assert(bp.cleanupEnteredLatch + .await(30, TimeUnit.SECONDS), "cleanup should have started") + + // Scheduler is no longer needed. Stop and wait so no new + // cycles interfere with assertions below. + getMaintenanceTask().stopAndAwait() + + // Both are running on their respective pool threads. + assert(bp.snapshotThreadName + .startsWith("state-store-maintenance-high-priority")) + assert(bp.cleanupThreadName + .startsWith("state-store-maintenance-low-priority")) + + // Partition sets should be claimed while both are running. + assert(!StateStore.tryClaimPartition(storeId, MaintenanceOpType.Snapshot), + "snapshot partition set should be occupied") + assert(!StateStore.tryClaimPartition(storeId, MaintenanceOpType.Cleanup), + "cleanup partition set should be occupied") + + // Read lock should be held while maintenance ops are running. + assert(bp.maintenanceLock.getReadLockCount == 2, + "both pool threads should hold the read lock") + + // Release both to finish + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + + // Verify all ops completed by checking partition sets and read + // lock are released. Use reflection to read the sets without + // claiming (tryClaimPartition has side effects that break + // eventually retries). + eventually(timeout(10.seconds)) { + assert(!getSnapshotPartitions().contains(storeId), + "snapshot partition set should be released") + assert(!getCleanupPartitions().contains(storeId), + "cleanup partition set should be released") + assert(bp.maintenanceLock.getReadLockCount == 0, + "read lock should be released after maintenance completes") + } + } + } + } + + test("partition sets and locks released when maintenance throws, " + + "write lock blocks until read lock is freed") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) + val storeId = loadNullProvider("errorDir", storeConf) + val bp = getBlockingProvider(storeId) + bp.snapshotShouldThrow = true + bp.closeShouldBlock = true + + // Wait for all ops to enter. + assert(bp.snapshotEnteredLatch + .await(10, TimeUnit.SECONDS), "snapshot should start") + assert(bp.cleanupEnteredLatch + .await(10, TimeUnit.SECONDS), "cleanup should start") + + // Scheduler is no longer needed. Stop and wait so no new + // cycles interfere with assertions below. + getMaintenanceTask().stopAndAwait() + + // Partition set is claimed while running. + assert(!StateStore.tryClaimPartition(storeId, MaintenanceOpType.Snapshot), + "snapshot set should be occupied") + + assert(bp.maintenanceLock.getReadLockCount == 2, + "both pool threads should hold the read lock") + + // Release snapshot. It will throw. The error handler tries to + // acquire the write lock, but cleanup still holds a read lock. + bp.snapshotContinueSignal.countDown() + + // Wait for the error handler to be blocked on the write lock. + eventually(timeout(5.seconds)) { + assert(bp.maintenanceLock.getQueueLength > 0, + "error handler should be waiting for write lock") + } + // Close has not been entered because the write lock is blocked. + assert(bp.closeEnteredLatch.getCount == 1, + "close should not be entered while write lock is blocked") + + // Release cleanup. Read lock freed. Write lock unblocks. Close called. + bp.cleanupContinueSignal.countDown() + + assert(bp.closeEnteredLatch + .await(10, TimeUnit.SECONDS), "close should be called") + assert(bp.maintenanceLock.isWriteLocked, + "write lock should be held during close") + + // Release close to let error handler finish. + bp.closeContinueSignal.countDown() + + // Wait for the error and finally block to complete. + eventually(timeout(10.seconds)) { + assert(!StateStore.isLoaded(storeId), + "provider should be unloaded after throw") + assert(bp.closeThreadName.nonEmpty, + "provider should be closed after throw") + assert(!getSnapshotPartitions().contains(storeId), + "snapshot set should be released by finally block after throw") + assert(bp.maintenanceLock.getReadLockCount == 0, + "read lock should be released after error handling") + assert(!bp.maintenanceLock.isWriteLocked, + "write lock should be released after error handling") + } + } + } + } + + private def testRequeue(opRequest: MaintenanceOpRequest): Unit = { + val logAppender = new LogAppender("requeue-log", maxEvents = 100) + logAppender.setThreshold(Level.INFO) + withLogAppender(logAppender, level = Some(Level.INFO)) { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) + val storeId = loadNullProvider("requeueDir", storeConf) + val bp = getBlockingProvider(storeId) + + // Wait for both ops to enter, occupying both partition sets. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) + assert(bp.cleanupEnteredLatch.await(10, TimeUnit.SECONDS)) + + // Add an entry to the queue. The scheduler will try to drain + // it but the partition set is occupied (held by the blocked + // task above), so it should be requeued. + val queue = getUnloadQueue() + queue.add((storeId, bp, opRequest)) + + // Wait for the scheduler to attempt draining and verify + // the requeue log. Queue checks are inside eventually to + // handle the momentary gap between the poll removing the + // entry and offer putting the entry back. + // Queue should still have the entry. + eventually(timeout(10.seconds)) { + assert(logAppender.loggingEvents.exists( + _.getMessage.getFormattedMessage.contains("Had to requeue")), + s"scheduler should have logged requeue for $opRequest") + val peeked = queue.peek() + assert(peeked != null, + s"$opRequest entry should have been requeued") + val (requeuedId, _, requeuedOp) = peeked + assert(requeuedId == storeId) + assert(requeuedOp == opRequest) + } + + // Clean up: release latches + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + queue.clear() + } + } + } + } + + test("Snapshot entry requeues when snapshot partition set is occupied") { + testRequeue(MaintenanceOpRequest.Snapshot) + } + + test("Cleanup entry requeues when cleanup partition set is occupied") { + testRequeue(MaintenanceOpRequest.Cleanup) + } + + test("All entry requeues when both partition sets are occupied") { + testRequeue(MaintenanceOpRequest.All) + } + + test("When MaintenanceOpRequest is All, cleanup is submitted " + + "if snapshot partition set is occupied") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + // Long interval so we can set up before the first + // cycle fires (5s initial delay). + val storeConf = maintenanceStoreConf( + classOf[BlockingMaintenanceProvider], interval = 5000L) + val id = loadNullProvider("shortCircuit", storeConf) + val bp = getBlockingProvider(id) + + try { + // Claim snapshot partition set. + assert(StateStore.tryClaimPartition(id, MaintenanceOpType.Snapshot)) + + // Remove from loadedProviders so scheduler doesn't submit + // cleanup for this provider by iterating through it. + // Only the queue entry should submit cleanup. + getLoadedProviders().synchronized { getLoadedProviders().remove(id) } + + // Add All entry. Scheduler's first cycle (at 5s) drains it, + // tries snapshot (occupied), falls through to cleanup. + // Cleanup blocks on bp's latch, keeping the partition set. + getUnloadQueue().add((id, bp, MaintenanceOpRequest.All)) + + eventually(timeout(10.seconds)) { + assert(getCleanupPartitions().contains(id), + "cleanup should be claimed") + assert(getUnloadQueue().isEmpty, "queue should be drained") + } + } finally { + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + getSnapshotPartitions().remove(id) + } + } + } + } + + test("canProcess is false and maintenance is skipped " + + "when provider has already been unloaded") { + val logAppender = new LogAppender("canProcess-log", maxEvents = 100) + logAppender.setThreshold(Level.INFO) + withLogAppender(logAppender, level = Some(Level.INFO)) { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + val storeConf = maintenanceStoreConf( + classOf[BlockingMaintenanceProvider], interval = 5000L) + val id = loadNullProvider("canProcess", storeConf) + val bp = getBlockingProvider(id) + + // Mark unloaded before the first maintenance cycle fires (5s + // initial delay). canProcess checks !provider.unloaded and will + // return false, skipping maintenance entirely. + bp.setUnloaded() + + // Wait for the "Skipping maintenance" log proving canProcess + // was false and maintenance was skipped. + eventually(timeout(10.seconds)) { + assert(logAppender.loggingEvents.exists( + _.getMessage.getFormattedMessage.contains("Skipping maintenance")), + "should log skipping maintenance for unloaded provider") + } + assert(bp.snapshotEnteredLatch.getCount == 1, + "snapshot should not have entered") + assert(bp.cleanupEnteredLatch.getCount == 1, + "cleanup should not have entered") + } + } + } + } + + test("canProcess is false and maintenance is skipped " + + "when provider instance differs") { + val logAppender = new LogAppender("canProcess-stale-log", maxEvents = 100) + logAppender.setThreshold(Level.INFO) + withLogAppender(logAppender, level = Some(Level.INFO)) { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + // Long interval so we can block the cleanup pool before the + // first cycle fires. numThreads=2 gives each pool 1 thread. + val storeConf = maintenanceStoreConf( + classOf[BlockingMaintenanceProvider], + interval = 5000L, numThreads = 2) + val id = loadNullProvider("canProcessStale", storeConf) + val bp = getBlockingProvider(id) + + // Block the cleanup pool's thread with a dummy task before the + // first cycle fires (5s away). + val cleanupPoolField = PrivateMethod[StateStore.MaintenanceThreadPool]( + Symbol("lowPriorityThreadPool")) + val cleanupPool = StateStore invokePrivate cleanupPoolField() + val blockLatch = new CountDownLatch(1) + cleanupPool.execute(() => blockLatch.await()) + + // Wait for the first cycle. Snapshot runs freely. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) + + // Stop the scheduler so the next cycle doesn't run the + // replacement (which lacks thread locals). + getMaintenanceTask().stopAndAwait() + + bp.snapshotContinueSignal.countDown() + + // Replace A with a different instance while cleanup is waiting. + val replacement = new FakeStateStoreProviderTracksCloseThread + replacement.init(id.storeId, null, null, + NoPrefixKeyStateEncoderSpec(null), + useColumnFamilies = false, null, null) + val loaded = getLoadedProviders() + loaded.synchronized { loaded.put(id, replacement) } + + // Release the dummy. Cleanup starts, canProcess sees + // contains(A) is false (replacement is there), skips. + blockLatch.countDown() + + eventually(timeout(10.seconds)) { + assert(logAppender.loggingEvents.exists( + _.getMessage.getFormattedMessage.contains("Skipping maintenance")), + "should log skipping maintenance for stale instance") + } + assert(bp.cleanupEnteredLatch.getCount == 1, + "cleanup should not have entered") + } + } + } + } + + test("FromLoadedProviders unload: reloaded provider is not removed nor queued") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) + val id = loadNullProvider("staleInstance", storeConf) + val bp = getBlockingProvider(id) + + // Mark as needing to be closed. + coordinatorRef.reportActiveInstance(id, "otherhost", "otherexec", Seq.empty) + + // Wait for snapshot to enter. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) + + // Stop only the scheduler (not the pools) so A's threads keep + // running but no new cycles fire after we replace. + getMaintenanceTask().stopAndAwait() + + // Replace provider A with a different instance while A is blocked. + // The scheduler is stopped so no maintenance runs on the + // replacement. When A finishes, loadedProviders.get(id).contains(A) + // is false (replacement is there), removal is skipped. + val replacement = new FakeStateStoreProviderTracksCloseThread + replacement.init(id.storeId, null, null, + NoPrefixKeyStateEncoderSpec(null), useColumnFamilies = false, null, null) + val loaded = getLoadedProviders() + loaded.synchronized { loaded.put(id, replacement) } + + // Release A. + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + + // Wait for A's partition sets to be released. + eventually(timeout(10.seconds)) { + assert(!getSnapshotPartitions().contains(id), + "snapshot partition should be released") + } + + // A should NOT have removed the replacement from loadedProviders. + assert(StateStore.isLoaded(id), + "replacement provider should still be loaded") + // A should NOT have queued anything (instance differs, skip). + assert(getUnloadQueue().isEmpty, + "queue should be empty (stale instance skipped removal)") + } + } + } + + test("FromLoadedProviders unload: with concurrent ops, " + + "only one removes and queues") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) + val id = loadNullProvider("concurrentUnload", storeConf) + val bp = getBlockingProvider(id) + + // Make provider stale so both ops detect inactive in source handling. + coordinatorRef.reportActiveInstance(id, "otherhost", "otherexec", Seq.empty) + + // Wait for both ops to enter. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) + assert(bp.cleanupEnteredLatch.await(10, TimeUnit.SECONDS)) + + // Stop the scheduler so no new cycles drain the queue before + // we can check its size. + getMaintenanceTask().stopAndAwait() + + // Release both. Both finish, both see !verifyIfStoreInstanceActive. + // Only one should remove from loadedProviders and queue. + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + + val queue = getUnloadQueue() + eventually(timeout(10.seconds)) { + assert(!StateStore.isLoaded(id), "provider should be removed") + assert(queue.size() == 1, + "only one op should queue, the other should no-op") + } + } + } + } + + test("stale provider from loadedProviders is closed properly " + + "through the full queue routing lifecycle") { + // Load a provider, make it stale via the coordinator, then verify: + // 1. Snapshot and cleanup run on separate pools with partition sets claimed + // 2. Both pool threads hold the read lock (readLockCount == 2) + // 3. Snapshot detects inactive, queues Cleanup via otherMaintenanceOpRequest + // 4. Cleanup runs as FromUnloadedProvidersQueue and closes the provider + // 5. During close: write lock held, no read locks + // 6. After close: all locks released, provider removed, queue drained + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) + val id = loadNullProvider("decoupled", storeConf) + val bp = getBlockingProvider(id) + bp.closeShouldBlock = true + + // Make provider stale so source handling queues it + coordinatorRef.reportActiveInstance(id, "otherhost", "otherexec", Seq.empty) + + // Wait for both tasks to block + assert(bp.snapshotEnteredLatch + .await(5, TimeUnit.SECONDS), "Snapshot task did not start") + assert(bp.cleanupEnteredLatch + .await(5, TimeUnit.SECONDS), "Cleanup task did not start") + + // Snapshot and cleanup should run on separate pools + assert(bp.snapshotThreadName + .contains("state-store-maintenance-high-priority"), + s"Snapshot should run on snapshot pool, was: " + + s"${bp.snapshotThreadName}") + assert(bp.cleanupThreadName + .contains("state-store-maintenance-low-priority"), + s"Cleanup should run on cleanup pool, was: " + + s"${bp.cleanupThreadName}") + + // Both partition sets should be claimed + val snap = getSnapshotPartitions() + val clean = getCleanupPartitions() + assert(snap.contains(id), "Snapshot partition should be claimed") + assert(clean.contains(id), "Cleanup partition should be claimed") + + // Both pool threads hold the read lock. + assert(bp.maintenanceLock.getReadLockCount == 2, + "both pool threads should hold the read lock") + + val queue = getUnloadQueue() + + // Both blocked, queue should be empty, no close + assert(queue.isEmpty, "Queue should be empty while tasks are blocked") + assert(bp.closeThreadName.isEmpty, + "No close while tasks are blocked") + + // Release snapshot. It will detect inactive, remove from + // loadedProviders, and queue the other op. Cleanup stays blocked, + // holding its partition set, so the scheduler cannot process the + // queue entry. + bp.snapshotContinueSignal.countDown() + + // Wait for snapshot source handling to complete. Check inside + // eventually to handle the momentary gap between the poll + // removing the entry and offer putting the entry back. + // Verify the queue entry. Snapshot completed first, so it queued + // otherMaintenanceOpRequest(Snapshot) = Cleanup. + eventually(timeout(5.seconds)) { + val peeked = queue.peek() + assert(peeked != null, "Queue should have one entry") + val (queuedId, _, opRequest) = peeked + assert(queuedId == id) + assert(opRequest == MaintenanceOpRequest.Cleanup, + s"Expected Cleanup, got $opRequest") + } + + // No close should have happened yet + assert(bp.closeThreadName.isEmpty, + "No close before final op runs") + + // Release cleanup. FromUnloadedProvidersQueue will release the + // read lock, acquire the write lock, and call close(). + bp.cleanupContinueSignal.countDown() + + // close() blocks on the latch. Write lock held, no read locks. + assert(bp.closeEnteredLatch + .await(10, TimeUnit.SECONDS), "close should be called") + assert(bp.maintenanceLock.isWriteLocked, + "write lock should be held during close") + assert(bp.maintenanceLock.getReadLockCount == 0, + "no read locks should be held during close") + + // Release close to let the downgrade and cleanup finish. + bp.closeContinueSignal.countDown() + + // Everything released. + eventually(timeout(10.seconds)) { + assert(bp.closeThreadName.nonEmpty, + "Provider should be closed") + assert(!StateStore.isLoaded(id), + "Provider should be removed from loadedProviders") + assert(queue.isEmpty, "Queue should be drained") + assert(snap.isEmpty && clean.isEmpty, + "Both partition sets should be released") + assert(bp.maintenanceLock.getReadLockCount == 0, + "read lock should be released after close") + assert(!bp.maintenanceLock.isWriteLocked, + "write lock should be released after close") + } + } + } + } + + test("scheduler maintenance triggerNow: at-most-one pending, no-op after stop") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + val storeConf = maintenanceStoreConf( + classOf[BlockingMaintenanceProvider], interval = 60000L) + val id = loadNullProvider("triggerNow", storeConf) + val bp = getBlockingProvider(id) + + // Reflection to extract maintenanceTask, its triggerPending flag, + // the underlying ScheduledThreadPoolExecutor, and loadedProviders. + val task = getMaintenanceTask() + val loaded = getLoadedProviders() + val pendingField = task.getClass.getDeclaredField("triggerPending") + val executorField = task.getClass.getDeclaredField("executor") + pendingField.setAccessible(true) + executorField.setAccessible(true) + val triggerPending = pendingField.get(task).asInstanceOf[AtomicBoolean] + val executor = executorField.get(task) + .asInstanceOf[java.util.concurrent.ScheduledThreadPoolExecutor] + + assert(!triggerPending.get(), "triggerPending should start as false") + + // Hold loadedProviders lock so doMaintenance blocks on Phase 2 + // (loadedProviders.synchronized). This keeps the scheduler executor + // busy, allowing us to test the at-most-one guard. + val lockHeld = new CountDownLatch(1) + val lockRelease = new CountDownLatch(1) + val lockThread = new Thread(() => { + loaded.synchronized { lockHeld.countDown(); lockRelease.await() } + }) + lockThread.start() + lockHeld.await() + + // triggerNow submits to scheduler executor. + // We wait until the triggered task starts executing, which will block on + // the loadedProviders lock. + // processUnloadedOnly=false so the triggered cycle iterates + // loadedProviders and blocks on the lock. + task.triggerNow(processUnloadedOnly = false) + eventually(timeout(5.seconds)) { + assert(!triggerPending.get(), "triggerPending reset when task started") + // Queue has 1 entry: the periodic future. The triggered task is + // currently executing (blocked on the loadedProviders lock). + assert(executor.getQueue.size() == 1, + s"expected 1 (periodic future), got ${executor.getQueue.size()}") + } + + // First call queues one pending run + task.triggerNow() + assert(triggerPending.get(), "one pending run queued") + // Queue has exactly 2: the pending triggered run + the periodic future + assert(executor.getQueue.size() == 2, + s"expected 2 queued tasks, got ${executor.getQueue.size()}") + + // Subsequent calls are no-ops (at-most-one pending). + // Queue size should not increase. + val queueSizeBefore = executor.getQueue.size() + task.triggerNow() + task.triggerNow() + assert(triggerPending.get(), "still one pending") + assert(executor.getQueue.size() == queueSizeBefore, + "queue size should not increase from no-op triggerNow calls") + + // Release the lock so the blocked cycle can proceed. After it + // finishes, the queued pending cycle runs automatically. + lockRelease.countDown() + + // Wait for the first cycle's snapshot to enter + assert(bp.snapshotEnteredLatch + .await(10, TimeUnit.SECONDS), "triggerNow should fire a cycle") + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + + // Verify the queued pending task also ran: after both cycles + // complete, the executor queue should have only the periodic future + // left, and triggerPending should be false. + eventually(timeout(10.seconds)) { + assert(!triggerPending.get(), "pending task should have run") + assert(executor.getQueue.size() == 1, + "only the periodic future should remain in the queue") + } + + // After stop, triggerNow catches RejectedExecutionException and logs. + val logAppender = new LogAppender("triggerNow-warn", maxEvents = 100) + logAppender.setThreshold(Level.WARN) + withLogAppender(logAppender, level = Some(Level.WARN)) { + // Use WithoutLock to avoid deadlock from stopMaintenanceTask + // holding loadedProviders lock while awaiting pool termination, but + // pool threads needing to acquire the same lock. + StateStore.stopMaintenanceTaskWithoutLock() + task.triggerNow() + assert(!triggerPending.get(), "reset after rejection") + assert(logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage + .contains("triggerNow called after scheduler maintenance task stopped")), + "should log warning on rejected execution") + } + } + } + } + + test("scheduler maintenance triggerNow: only unloaded queue is processed") { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + // Long interval so periodic cycles don't interfere. + val storeConf = maintenanceStoreConf( + classOf[BlockingMaintenanceProvider], interval = 60000L) + val id1 = loadNullProvider("triggerOnly", storeConf) + val bp1 = getBlockingProvider(id1) + + // Make id1 stale. Loading id2 calls reportActiveInstance which + // detects id1 as stale, queues it, and calls triggerNow. + coordinatorRef.reportActiveInstance( + id1, "otherhost", "otherexec", Seq.empty) + val id2 = loadNullProvider("triggerOnly", storeConf, partition = 1) + val bp2 = getBlockingProvider(id2) + + // triggerNow fires with processUnloadedOnly=true. + // id1 should enter (proves triggerNow processed the queue). + assert(bp1.snapshotEnteredLatch.await(10, TimeUnit.SECONDS), + "id1 should be processed from queue by triggerNow") + + // Give the triggered cycle time to finish. If phase 2 ran, + // id2 would have been submitted to the pool by now. + Thread.sleep(2000) + + // id2 should NOT have entered. The scheduler should NOT have + // iterated loadedProviders and submitted maintenance tasks. + assert(bp2.snapshotEnteredLatch.getCount == 1, + "id2 snapshot should not have been submitted by triggerNow") + assert(bp2.cleanupEnteredLatch.getCount == 1, + "id2 cleanup should not have been submitted by triggerNow") + + bp1.snapshotContinueSignal.countDown() + bp1.cleanupContinueSignal.countDown() + bp2.snapshotContinueSignal.countDown() + bp2.cleanupContinueSignal.countDown() + } + } + } + + private def poolSizeConf( + total: Int, + ratio: Double = 0.5): StateStoreConf = { + val sqlConf = getDefaultSQLConf( + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get) + sqlConf.setConf(SQLConf.NUM_STATE_STORE_MAINTENANCE_THREADS, total) + sqlConf.setConf( + SQLConf.STATE_STORE_MAINTENANCE_SNAPSHOT_THREAD_RATIO, ratio) + new StateStoreConf(sqlConf) + } + + test("getPoolSizes: ratio based split") { + // Default ratio 0.5: even split. + assert(StateStore.getPoolSizes(poolSizeConf(2)) === (1, 1)) + assert(StateStore.getPoolSizes(poolSizeConf(3)) === (2, 1)) + assert(StateStore.getPoolSizes(poolSizeConf(4)) === (2, 2)) + assert(StateStore.getPoolSizes(poolSizeConf(5)) === (3, 2)) + assert(StateStore.getPoolSizes(poolSizeConf(8)) === (4, 4)) + assert(StateStore.getPoolSizes(poolSizeConf(100)) === (50, 50)) + + // Custom ratios. + assert(StateStore.getPoolSizes(poolSizeConf(8, ratio = 0.75)) === (6, 2)) + assert(StateStore.getPoolSizes(poolSizeConf(8, ratio = 0.25)) === (2, 6)) + assert(StateStore.getPoolSizes(poolSizeConf(10, ratio = 0.8)) === (8, 2)) + assert(StateStore.getPoolSizes(poolSizeConf(10, ratio = 0.1)) === (1, 9)) + + // Fractional rounding (math.round rounds 0.5 up). + assert(StateStore.getPoolSizes(poolSizeConf(10, ratio = 1.0/3)) === (3, 7)) + assert(StateStore.getPoolSizes(poolSizeConf(11, ratio = 0.5)) === (6, 5)) + assert(StateStore.getPoolSizes(poolSizeConf(7, ratio = 0.3)) === (2, 5)) + + // Each pool gets at least 1 thread. Total is never exceeded. + assert(StateStore.getPoolSizes(poolSizeConf(2, ratio = 0.99)) === (1, 1)) + assert(StateStore.getPoolSizes(poolSizeConf(2, ratio = 0.01)) === (1, 1)) + assert(StateStore.getPoolSizes(poolSizeConf(4, ratio = 0.99)) === (3, 1)) + assert(StateStore.getPoolSizes(poolSizeConf(4, ratio = 0.01)) === (1, 3)) + assert(StateStore.getPoolSizes(poolSizeConf(10, ratio = 0.99)) === (9, 1)) + assert(StateStore.getPoolSizes(poolSizeConf(10, ratio = 0.01)) === (1, 9)) + + Seq(2, 3, 6, 7, 12, 15, 20, 50).foreach { total => + Seq(0.05, 0.15, 0.33, 0.5, 0.67, 0.85, 0.95).foreach { ratio => + val (s, c) = StateStore.getPoolSizes(poolSizeConf(total, ratio)) + assert(s + c == total, s"total=$total ratio=$ratio: $s + $c != $total") + assert(s >= 1, s"total=$total ratio=$ratio: snapshot=$s < 1") + assert(c >= 1, s"total=$total ratio=$ratio: cleanup=$c < 1") + } + } + } + + /** + * Proves that a full pool does not starve the other. Provider A fills + * the specified pool. Provider B on a different partition can still + * run the other op type. + * @param blockSnapshot if true, fill snapshot pool; if false, fill cleanup pool + */ + private def testPoolIsolation(blockSnapshot: Boolean): Unit = { + val conf = new SparkConf().setMaster("local").setAppName("test") + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { _ => + // numThreads=2: each pool gets exactly 1 thread. + val storeConf = maintenanceStoreConf( + classOf[BlockingMaintenanceProvider], numThreads = 2) + val id1 = loadNullProvider("poolIsolation", storeConf) + val bp1 = getBlockingProvider(id1) + + // Wait for both of A's ops to enter. + assert(bp1.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) + assert(bp1.cleanupEnteredLatch.await(10, TimeUnit.SECONDS)) + + // Release the op we DON'T want to fill, keeping the other + // pool's only thread occupied. + if (blockSnapshot) bp1.cleanupContinueSignal.countDown() + else bp1.snapshotContinueSignal.countDown() + + // Load B on a different partition. + val id2 = loadNullProvider("poolIsolation", storeConf, partition = 1) + val bp2 = getBlockingProvider(id2) + + if (blockSnapshot) { + assert(bp2.cleanupEnteredLatch.await(10, TimeUnit.SECONDS), + "B's cleanup should run despite snapshot pool being full") + assert(bp2.snapshotEnteredLatch.getCount == 1, + "B's snapshot should not have entered") + } else { + assert(bp2.snapshotEnteredLatch.await(10, TimeUnit.SECONDS), + "B's snapshot should run despite cleanup pool being full") + assert(bp2.cleanupEnteredLatch.getCount == 1, + "B's cleanup should not have entered") + } + + // Release everything. + bp1.snapshotContinueSignal.countDown() + bp1.cleanupContinueSignal.countDown() + bp2.snapshotContinueSignal.countDown() + bp2.cleanupContinueSignal.countDown() + } + } + } + + test("full snapshot pool does not prevent cleanup from running") { + testPoolIsolation(blockSnapshot = true) + } + + test("full cleanup pool does not prevent snapshot from running") { + testPoolIsolation(blockSnapshot = false) + } + + test("maintenance op skips when write lock is held during close") { + val conf = new SparkConf().setMaster("local").setAppName("test") + val logAppender = new LogAppender("tryLock-skip", maxEvents = 100) + logAppender.setThreshold(Level.DEBUG) + val loggerName = StateStore.getClass.getName.stripSuffix("$") + withLogAppender(logAppender, + loggerNames = Seq(loggerName), level = Some(Level.DEBUG)) { + withSpark(SparkContext.getOrCreate(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + val storeConf = maintenanceStoreConf( + classOf[BlockingMaintenanceProvider]) + val id = loadNullProvider("writeLockBlock", storeConf) + val bp = getBlockingProvider(id) + bp.closeShouldBlock = true + + // Make stale so the close path is triggered. + coordinatorRef.reportActiveInstance( + id, "otherhost", "otherexec", Seq.empty) + + // Wait for both ops to enter, release both. One detects + // inactive, queues remaining op. That op runs as + // FromUnloadedProvidersQueue, acquires write lock, calls close, + // blocks on closeShouldBlock. + assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) + assert(bp.cleanupEnteredLatch.await(10, TimeUnit.SECONDS)) + bp.snapshotContinueSignal.countDown() + bp.cleanupContinueSignal.countDown() + + // Wait for close to hold the write lock. + assert(bp.closeEnteredLatch.await(10, TimeUnit.SECONDS)) + assert(bp.maintenanceLock.isWriteLocked, + "write lock should be held during close") + + // Add provider back to the unload queue. Use All because we + // don't know which op is doing close (and holding that partition + // set). All tries both and submits whichever is free. The + // scheduler's next cycle drains the queue and submits to pool. + getUnloadQueue().add((id, bp, MaintenanceOpRequest.All)) + + // The pool thread calls tryLock(0, SECONDS) which returns + // false because the write lock is held. The op is skipped + // without blocking. The debug log proves the pool thread ran. + eventually(timeout(10.seconds)) { + assert(getUnloadQueue().isEmpty, + "scheduler should have consumed the queue entry") + assert(logAppender.loggingEvents.exists( + _.getMessage.getFormattedMessage.contains( + "could not acquire read lock")), + "pool thread should have logged tryLock failure") + assert(bp.maintenanceLock.getQueueLength == 0, + "no thread should be blocked on the lock") + assert(bp.maintenanceLock.getReadLockCount == 0, + "read lock should not be acquired when write lock is held") + } + assert(bp.maintenanceLock.isWriteLocked, + "write lock should still be held during close") + + // Release close. + bp.closeContinueSignal.countDown() + + eventually(timeout(10.seconds)) { + assert(!bp.maintenanceLock.isWriteLocked, + "write lock should be released after close") + } + } + } + } + } +} + +class StateStoreDecoupledMaintenanceSuite + extends StateStoreDecoupledMaintenanceSuiteBase[HDFSBackedStateStoreProvider] + with SharedSparkSession { + override def beforeEach(): Unit = {} + override def afterEach(): Unit = {} +} + +class StateStoreDecoupledMaintenanceSuiteWithRowChecksum + extends StateStoreDecoupledMaintenanceSuite + with EnableStateStoreRowChecksum + +class RocksDBDecoupledMaintenanceSuite + extends StateStoreDecoupledMaintenanceSuiteBase[RocksDBStateStoreProvider] + with AlsoTestWithEncodingTypes + with AlsoTestWithRocksDBFeatures + with SharedSparkSession { + override def afterEach(): Unit = {} +} + +class RocksDBDecoupledMaintenanceSuiteWithRowChecksum + extends RocksDBDecoupledMaintenanceSuite + with EnableStateStoreRowChecksum diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index c0b13ace4eefb..90b1098dab838 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -21,7 +21,6 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, IOException, import java.net.URI import java.util import java.util.UUID -import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable @@ -30,7 +29,6 @@ import scala.util.Random import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ -import org.apache.logging.log4j.Level import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods import org.scalatest.{BeforeAndAfter, PrivateMethodTester} @@ -39,7 +37,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.LocalSparkContext._ -import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly @@ -55,121 +52,6 @@ import org.apache.spark.tags.ExtendedSQLTest import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -/** - * A test StateStoreProvider that counts how many times each split maintenance operation - * (snapshot, cleanup) runs and records the thread that closed it. Used to verify that the - * decoupled maintenance scheduler submits snapshot and cleanup as independent operations. - */ -/** - * A test StateStoreProvider whose split maintenance operations (snapshot and cleanup) block - * on per-instance latches, so tests can observe and control the decoupled maintenance - * lifecycle. Records the thread that ran each op and the thread that closed the provider. - */ -class BlockingMaintenanceProvider extends StateStoreProvider with Logging { - private var id: StateStoreId = null - - // Per-instance state. No shared static fields, so stale scheduler cycles from a previous - // test use the old instance's latches (already counted down) and finish immediately. No - // cross-test interference. - @volatile var snapshotThreadName: String = "" - @volatile var cleanupThreadName: String = "" - @volatile var closeThreadName: String = "" - @volatile var snapshotShouldThrow: Boolean = false - @volatile var cleanupShouldThrow: Boolean = false - - val snapshotEnteredLatch = new CountDownLatch(1) - val cleanupEnteredLatch = new CountDownLatch(1) - val snapshotContinueSignal = new CountDownLatch(1) - val cleanupContinueSignal = new CountDownLatch(1) - - override def init( - stateStoreId: StateStoreId, - keySchema: StructType, - valueSchema: StructType, - keyStateEncoderSpec: KeyStateEncoderSpec, - useColumnFamilies: Boolean, - storeConfs: StateStoreConf, - hadoopConf: Configuration, - useMultipleValuesPerKey: Boolean = false, - stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = { - id = stateStoreId - } - - override def stateStoreId: StateStoreId = id - - override def close(): Unit = { - closeThreadName = Thread.currentThread.getName - } - - // Returns null because tests using this provider do not need a real store. They only - // exercise the maintenance scheduler and close paths. - override def getStore( - version: Long, - uniqueId: Option[String], - forceSnapshotOnCommit: Boolean = false, - loadEmpty: Boolean = false): StateStore = null - - /** Signals entry, then blocks until the test releases the continue latch. */ - override def doSnapshotMaintenance(): Unit = { - snapshotThreadName = Thread.currentThread.getName - logInfo(s"Snapshot maintenance entered on ${Thread.currentThread.getName}") - snapshotEnteredLatch.countDown() - snapshotContinueSignal.await() - logInfo(s"Snapshot maintenance continuing on ${Thread.currentThread.getName}") - if (snapshotShouldThrow) { - throw new RuntimeException("snapshot error") - } - } - - /** Same handshake as doSnapshotMaintenance but for cleanup. */ - override def doCleanupMaintenance(): Unit = { - cleanupThreadName = Thread.currentThread.getName - logInfo(s"Cleanup maintenance entered on ${Thread.currentThread.getName}") - cleanupEnteredLatch.countDown() - cleanupContinueSignal.await() - logInfo(s"Cleanup maintenance continuing on ${Thread.currentThread.getName}") - if (cleanupShouldThrow) { - throw new RuntimeException("cleanup error") - } - } -} - -class FakeStateStoreProviderTracksCloseThread extends StateStoreProvider { - import FakeStateStoreProviderTracksCloseThread._ - private var id: StateStoreId = null - - override def init( - stateStoreId: StateStoreId, - keySchema: StructType, - valueSchema: StructType, - keyStateEncoderSpec: KeyStateEncoderSpec, - useColumnFamilies: Boolean, - storeConfs: StateStoreConf, - hadoopConf: Configuration, - useMultipleValuesPerKey: Boolean = false, - stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = { - id = stateStoreId - } - - override def stateStoreId: StateStoreId = id - - override def close(): Unit = { - closeThreadNames = Thread.currentThread.getName :: closeThreadNames - } - - override def getStore( - version: Long, - uniqueId: Option[String], - forceSnapshotOnCommit: Boolean = false, - loadEmpty: Boolean = false): StateStore = null - - override def doMaintenance(): Unit = {} -} - -private object FakeStateStoreProviderTracksCloseThread { - var closeThreadNames: List[String] = Nil -} - // MaintenanceErrorOnCertainPartitionsProvider is a test-only provider that throws an // exception during maintenance for partitions 0 and 1 (these are arbitrary choices). It is // used to test that an exception in a single provider's maintenance does not affect other @@ -302,482 +184,6 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] require(!StateStore.isMaintenanceRunning) } - private def getUnloadQueue(): ConcurrentLinkedQueue[ - (StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)] = { - val f = PrivateMethod[ConcurrentLinkedQueue[ - (StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)]]( - Symbol("unloadedProvidersToClose")) - StateStore invokePrivate f() - } - - private def getSnapshotPartitions(): mutable.HashSet[StateStoreProviderId] = { - val f = PrivateMethod[mutable.HashSet[StateStoreProviderId]](Symbol("snapshotPartitions")) - StateStore invokePrivate f() - } - - private def getCleanupPartitions(): mutable.HashSet[StateStoreProviderId] = { - val f = PrivateMethod[mutable.HashSet[StateStoreProviderId]](Symbol("cleanupPartitions")) - StateStore invokePrivate f() - } - - private def getLoadedProviders(): mutable.HashMap[StateStoreProviderId, StateStoreProvider] = { - val f = PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]]( - Symbol("loadedProviders")) - StateStore invokePrivate f() - } - - private def getBlockingProvider(id: StateStoreProviderId): BlockingMaintenanceProvider = { - val loaded = getLoadedProviders() - loaded.synchronized { loaded.get(id).get }.asInstanceOf[BlockingMaintenanceProvider] - } - - private def maintenanceStoreConf( - providerClass: Class[_], - interval: Long = 100L, - numThreads: Int = 4): StateStoreConf = { - val sqlConf = getDefaultSQLConf( - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, - SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get) - sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, interval) - sqlConf.setConf(SQLConf.NUM_STATE_STORE_MAINTENANCE_THREADS, numThreads) - sqlConf.setConf(SQLConf.STATE_STORE_PROVIDER_CLASS, providerClass.getName) - new StateStoreConf(sqlConf) - } - - private def loadNullProvider( - dir: String, - storeConf: StateStoreConf, - partition: Int = 0): StateStoreProviderId = { - val storeId = StateStoreProviderId(StateStoreId(dir, 0, partition), UUID.randomUUID) - StateStore.get( - storeId, null, null, NoPrefixKeyStateEncoderSpec(null), - 0, None, None, useColumnFamilies = false, storeConf, new Configuration()) - storeId - } - - test("SPARK-51596: task thread unload lifecycle from queue to close") { - val conf = new SparkConf().setMaster("local").setAppName("test") - withSpark(SparkContext.getOrCreate(conf)) { sc => - withCoordinatorRef(sc) { coordinatorRef => - // TODO: set to a longer interval once triggerNow is added so the test verifies close - // happens via triggerNow, not the periodic tick. - val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider], interval = 5000L) - val id1 = loadNullProvider("lifecycle", storeConf) - val bp = getBlockingProvider(id1) - - val queue = getUnloadQueue() - assert(StateStore.isLoaded(id1)) - assert(queue.isEmpty, "Queue should start empty") - - // Make stale and load another provider to trigger task thread queueing. Use a - // non-blocking provider for id2 since we don't need to observe its maintenance. - coordinatorRef.reportActiveInstance(id1, "otherhost", "otherexec", Seq.empty) - val storeConf2 = maintenanceStoreConf(classOf[FakeStateStoreProviderTracksCloseThread]) - val id2 = loadNullProvider("lifecycle", storeConf2, partition = 1) - - assert(!StateStore.isLoaded(id1), "Provider1 should be removed") - assert(StateStore.isLoaded(id2), "Provider2 should still be loaded") - - // Verify task thread queued with All. - assert(!queue.isEmpty) - val (qId, _, opReq) = queue.peek() - assert(qId == id1) - assert(opReq == MaintenanceOpRequest.All, s"Expected All, got $opReq") - - // Step 2: Scheduler submits first op as FromTaskThread. Snapshot enters latch. - assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS) && - bp.snapshotEnteredLatch.getCount == 0, "snapshot should have started") - - // Step 3: Release snapshot. Post-work queues remaining op (Cleanup). - bp.snapshotContinueSignal.countDown() - - // Step 4: Scheduler picks up Cleanup as FromUnloadedProvidersQueue. cleanupEnteredLatch - // being counted down proves Cleanup was queued and submitted. - assert(bp.cleanupEnteredLatch.await(10, TimeUnit.SECONDS) && - bp.cleanupEnteredLatch.getCount == 0, "cleanup should have started") - - // Snapshot's post-work queued Cleanup and scheduler drained it. Cleanup is now running, - // queue is empty. - assert(queue.isEmpty, "queue should be drained while cleanup runs") - - // Step 5: Release cleanup. FromUnloadedProvidersQueue calls closeProvider. - bp.cleanupContinueSignal.countDown() - - // Verify provider was closed on the maintenance pool. - eventually(timeout(10.seconds)) { - assert(bp.closeThreadName.contains("state-store-maintenance-thread"), - "close should happen on maintenance thread, but was on: " + bp.closeThreadName) - assert(!StateStore.isLoaded(id1), "provider should be removed from loadedProviders") - assert(queue.isEmpty, "Queue should be drained") - } - } - } - } - - test("tryClaimPartition returns true first call, false second, true for different opType") { - val id = StateStoreProviderId(StateStoreId("dir", 0, 0), UUID.randomUUID) - val id2 = StateStoreProviderId(StateStoreId("dir", 0, 1), UUID.randomUUID) - - try { - // First claim for snapshot succeeds - assert(StateStore.tryClaimPartition(id, MaintenanceOpType.Snapshot)) - // Second claim for same id + opType fails - assert(!StateStore.tryClaimPartition(id, MaintenanceOpType.Snapshot)) - // Claim for same id but different opType succeeds - assert(StateStore.tryClaimPartition(id, MaintenanceOpType.Cleanup)) - // That one is also occupied now - assert(!StateStore.tryClaimPartition(id, MaintenanceOpType.Cleanup)) - - // Different id can still claim both - assert(StateStore.tryClaimPartition(id2, MaintenanceOpType.Snapshot)) - assert(StateStore.tryClaimPartition(id2, MaintenanceOpType.Cleanup)) - } finally { - getSnapshotPartitions().clear() - getCleanupPartitions().clear() - } - } - - test("otherMaintenanceOpRequest maps correctly") { - assert(StateStore.otherMaintenanceOpRequest(MaintenanceOpType.Snapshot) - === MaintenanceOpRequest.Cleanup) - assert(StateStore.otherMaintenanceOpRequest(MaintenanceOpType.Cleanup) - === MaintenanceOpRequest.Snapshot) - } - - test("closeProvider sets unloaded even if close() throws") { - val storeId = StateStoreProviderId(StateStoreId("closeTest", 0, 0), UUID.randomUUID) - val callOrder = new mutable.ArrayBuffer[String]() - val provider = new FakeStateStoreProviderTracksCloseThread { - override def close(): Unit = { - callOrder += "close" - throw new RuntimeException("close failed") - } - override def setUnloaded(): Unit = { - callOrder += "setUnloaded" - super.setUnloaded() - } - } - provider.init( - storeId.storeId, null, null, NoPrefixKeyStateEncoderSpec(null), - useColumnFamilies = false, null, null) - - assert(!provider.unloaded) - intercept[RuntimeException] { - StateStore.closeProvider(storeId, provider) - } - assert(provider.unloaded, "setUnloaded should run even if close() throws") - assert(callOrder === Seq("close", "setUnloaded"), - "setUnloaded should run after close() even if it throws") - } - - test("concurrent snapshot and cleanup on same provider both succeed") { - val conf = new SparkConf().setMaster("local").setAppName("test") - withSpark(SparkContext.getOrCreate(conf)) { sc => - withCoordinatorRef(sc) { _ => - val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) - val storeId = loadNullProvider("concurrentDir", storeConf) - val bp = getBlockingProvider(storeId) - - // Wait for both snapshot and cleanup to enter. - assert(bp.snapshotEnteredLatch.await(30, TimeUnit.SECONDS), "snapshot should have started") - assert(bp.cleanupEnteredLatch.await(30, TimeUnit.SECONDS), "cleanup should have started") - - // Both run on maintenance pool threads, on different threads. - val prefix = "state-store-maintenance-thread" - assert(bp.snapshotThreadName.startsWith(prefix)) - assert(bp.cleanupThreadName.startsWith(prefix)) - assert(bp.snapshotThreadName != bp.cleanupThreadName, - "snapshot and cleanup should run on different threads") - - // Partition sets should be claimed while both are running. - assert(!StateStore.tryClaimPartition(storeId, MaintenanceOpType.Snapshot), - "snapshot partition set should be occupied") - assert(!StateStore.tryClaimPartition(storeId, MaintenanceOpType.Cleanup), - "cleanup partition set should be occupied") - - // Release both to finish. - bp.snapshotContinueSignal.countDown() - bp.cleanupContinueSignal.countDown() - - // Verify all ops completed by checking partition sets are released. Read the sets via - // reflection (tryClaimPartition has side effects that break eventually retries). - eventually(timeout(10.seconds)) { - assert(!getSnapshotPartitions().contains(storeId), - "snapshot partition set should be released") - assert(!getCleanupPartitions().contains(storeId), - "cleanup partition set should be released") - } - } - } - } - - test("partition set released when maintenance throws") { - val conf = new SparkConf().setMaster("local").setAppName("test") - withSpark(SparkContext.getOrCreate(conf)) { sc => - withCoordinatorRef(sc) { _ => - val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) - val storeId = loadNullProvider("errorDir", storeConf) - val bp = getBlockingProvider(storeId) - bp.snapshotShouldThrow = true - - // Wait for snapshot to enter. - assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) - // Partition set is claimed while running. - assert(!StateStore.tryClaimPartition(storeId, MaintenanceOpType.Snapshot), - "snapshot set should be occupied") - - // Release snapshot. It will throw. Release cleanup to avoid blocking the pool. - bp.snapshotContinueSignal.countDown() - bp.cleanupContinueSignal.countDown() - - // isLoaded returning false proves the exception was thrown (only the error handler - // unloads), and the finally block must release the partition set. - eventually(timeout(10.seconds)) { - assert(!StateStore.isLoaded(storeId), "provider should be unloaded after throw") - assert(bp.closeThreadName.nonEmpty, "provider should be closed after throw") - assert(!getSnapshotPartitions().contains(storeId), - "snapshot set should be released by finally block after throw") - } - } - } - } - - private def testRequeue(opRequest: MaintenanceOpRequest): Unit = { - val logAppender = new LogAppender("requeue-log", maxEvents = 100) - logAppender.setThreshold(Level.INFO) - withLogAppender(logAppender, level = Some(Level.INFO)) { - val conf = new SparkConf().setMaster("local").setAppName("test") - withSpark(SparkContext.getOrCreate(conf)) { sc => - withCoordinatorRef(sc) { _ => - val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) - val storeId = loadNullProvider("requeueDir", storeConf) - val bp = getBlockingProvider(storeId) - - // Wait for both ops to enter, occupying both partition sets. - assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) - assert(bp.cleanupEnteredLatch.await(10, TimeUnit.SECONDS)) - - // Add an entry to the queue. The scheduler tries to drain it but the partition set is - // occupied (held by the blocked task above), so it should be requeued. - val queue = getUnloadQueue() - queue.add((storeId, bp, opRequest)) - - eventually(timeout(10.seconds)) { - assert(logAppender.loggingEvents.exists( - _.getMessage.getFormattedMessage.contains("Had to requeue")), - s"scheduler should have logged requeue for $opRequest") - } - - // Queue should still have the entry. - assert(queue.size() == 1, s"$opRequest entry should have been requeued") - val (requeuedId, _, requeuedOp) = queue.peek() - assert(requeuedId == storeId) - assert(requeuedOp == opRequest) - - // Clean up: release latches. - bp.snapshotContinueSignal.countDown() - bp.cleanupContinueSignal.countDown() - queue.clear() - } - } - } - } - - test("Snapshot entry requeues when snapshot partition set is occupied") { - testRequeue(MaintenanceOpRequest.Snapshot) - } - - test("Cleanup entry requeues when cleanup partition set is occupied") { - testRequeue(MaintenanceOpRequest.Cleanup) - } - - test("All entry requeues when both partition sets are occupied") { - testRequeue(MaintenanceOpRequest.All) - } - - test("When MaintenanceOpRequest is All, cleanup is submitted if snapshot partition set " + - "is occupied") { - val conf = new SparkConf().setMaster("local").setAppName("test") - withSpark(SparkContext.getOrCreate(conf)) { sc => - withCoordinatorRef(sc) { _ => - // Long interval so we can set up before the first cycle fires (5s initial delay). - val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider], interval = 5000L) - val id = loadNullProvider("shortCircuit", storeConf) - val bp = getBlockingProvider(id) - - try { - // Claim snapshot partition set. - assert(StateStore.tryClaimPartition(id, MaintenanceOpType.Snapshot)) - - // Remove from loadedProviders so the scheduler doesn't submit cleanup for this - // provider by iterating through it. Only the queue entry should submit cleanup. - getLoadedProviders().synchronized { getLoadedProviders().remove(id) } - - // Add All entry. The scheduler's first cycle (at 5s) drains it, tries snapshot - // (occupied), falls through to cleanup. Cleanup blocks on bp's latch, keeping the set. - getUnloadQueue().add((id, bp, MaintenanceOpRequest.All)) - - eventually(timeout(10.seconds)) { - assert(getCleanupPartitions().contains(id), "cleanup should be claimed") - assert(getUnloadQueue().isEmpty, "queue should be drained") - } - } finally { - bp.snapshotContinueSignal.countDown() - bp.cleanupContinueSignal.countDown() - getSnapshotPartitions().remove(id) - } - } - } - } - - test("canProcess is false and maintenance is skipped when provider has already been unloaded") { - val logAppender = new LogAppender("canProcess-log", maxEvents = 100) - logAppender.setThreshold(Level.INFO) - withLogAppender(logAppender, level = Some(Level.INFO)) { - val conf = new SparkConf().setMaster("local").setAppName("test") - withSpark(SparkContext.getOrCreate(conf)) { sc => - withCoordinatorRef(sc) { _ => - val storeConf = - maintenanceStoreConf(classOf[BlockingMaintenanceProvider], interval = 5000L) - val id = loadNullProvider("canProcess", storeConf) - val bp = getBlockingProvider(id) - - // Mark unloaded before the first maintenance cycle fires (5s initial delay). - // canProcess checks !provider.unloaded and will return false, skipping maintenance. - bp.setUnloaded() - - // Wait for the "Skipping maintenance" log proving canProcess was false. - eventually(timeout(10.seconds)) { - assert(logAppender.loggingEvents.exists( - _.getMessage.getFormattedMessage.contains("Skipping maintenance")), - "should log skipping maintenance for unloaded provider") - } - assert(bp.snapshotEnteredLatch.getCount == 1, "snapshot should not have entered") - assert(bp.cleanupEnteredLatch.getCount == 1, "cleanup should not have entered") - } - } - } - } - - test("canProcess is false and maintenance is skipped when provider instance differs") { - val logAppender = new LogAppender("canProcess-stale-log", maxEvents = 100) - logAppender.setThreshold(Level.INFO) - withLogAppender(logAppender, level = Some(Level.INFO)) { - val conf = new SparkConf().setMaster("local").setAppName("test") - withSpark(SparkContext.getOrCreate(conf)) { sc => - withCoordinatorRef(sc) { _ => - // numThreads=1: snapshot gets the only pool thread. Cleanup's task waits in the pool's - // work queue until the thread is free. - val storeConf = - maintenanceStoreConf(classOf[BlockingMaintenanceProvider], numThreads = 1) - val id = loadNullProvider("canProcessStale", storeConf) - val bp = getBlockingProvider(id) - - // Wait for snapshot to enter (occupies the only pool thread). Cleanup's task waits in - // the pool's work queue. - assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) - - // Replace A with a different instance while cleanup is waiting. When cleanup starts, - // canProcess sees contains(A) is false. - val replacement = new FakeStateStoreProviderTracksCloseThread - replacement.init(id.storeId, null, null, NoPrefixKeyStateEncoderSpec(null), - useColumnFamilies = false, null, null) - val loaded = getLoadedProviders() - loaded.synchronized { loaded.put(id, replacement) } - - // Release snapshot. Thread frees, cleanup starts, canProcess fails (instance differs), - // maintenance skipped. - bp.snapshotContinueSignal.countDown() - - eventually(timeout(10.seconds)) { - assert(logAppender.loggingEvents.exists( - _.getMessage.getFormattedMessage.contains("Skipping maintenance")), - "should log skipping maintenance for stale instance") - } - assert(bp.cleanupEnteredLatch.getCount == 1, "cleanup should not have entered") - } - } - } - } - - test("FromLoadedProviders unload: reloaded provider is not removed nor queued") { - val conf = new SparkConf().setMaster("local").setAppName("test") - withSpark(SparkContext.getOrCreate(conf)) { sc => - withCoordinatorRef(sc) { coordinatorRef => - val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) - val id = loadNullProvider("staleInstance", storeConf) - val bp = getBlockingProvider(id) - - // Mark as needing to be closed. - coordinatorRef.reportActiveInstance(id, "otherhost", "otherexec", Seq.empty) - - // Wait for snapshot to enter. - assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) - - // Stop only the scheduler (not the pools) so A's threads keep running but no new cycles - // fire after we replace. - val taskField = PrivateMethod[StateStore.MaintenanceTask](Symbol("maintenanceTask")) - (StateStore invokePrivate taskField()).stop() - - // Replace provider A with a different instance while A is blocked. When A finishes, - // loadedProviders.get(id).contains(A) is false (replacement is there), removal skipped. - val replacement = new FakeStateStoreProviderTracksCloseThread - replacement.init(id.storeId, null, null, - NoPrefixKeyStateEncoderSpec(null), useColumnFamilies = false, null, null) - val loaded = getLoadedProviders() - loaded.synchronized { loaded.put(id, replacement) } - - // Release A. - bp.snapshotContinueSignal.countDown() - bp.cleanupContinueSignal.countDown() - - // Wait for A's partition sets to be released. - eventually(timeout(10.seconds)) { - assert(!getSnapshotPartitions().contains(id), "snapshot partition should be released") - } - - // A should NOT have removed the replacement from loadedProviders. - assert(StateStore.isLoaded(id), "replacement provider should still be loaded") - // A should NOT have queued anything (instance differs, skip). - assert(getUnloadQueue().isEmpty, "queue should be empty (stale instance skipped removal)") - } - } - } - - test("FromLoadedProviders unload: with concurrent ops, only one removes and queues") { - val conf = new SparkConf().setMaster("local").setAppName("test") - withSpark(SparkContext.getOrCreate(conf)) { sc => - withCoordinatorRef(sc) { coordinatorRef => - val storeConf = maintenanceStoreConf(classOf[BlockingMaintenanceProvider]) - val id = loadNullProvider("concurrentUnload", storeConf) - val bp = getBlockingProvider(id) - - // Make provider stale so both ops detect inactive in source handling. - coordinatorRef.reportActiveInstance(id, "otherhost", "otherexec", Seq.empty) - - // Wait for both ops to enter. - assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS)) - assert(bp.cleanupEnteredLatch.await(10, TimeUnit.SECONDS)) - - // Stop the scheduler so no new cycles drain the queue before we can check its size. - val taskField = PrivateMethod[StateStore.MaintenanceTask](Symbol("maintenanceTask")) - (StateStore invokePrivate taskField()).stop() - - // Release both. Both finish, both see !verifyIfStoreInstanceActive. Only one should - // remove from loadedProviders and queue. - bp.snapshotContinueSignal.countDown() - bp.cleanupContinueSignal.countDown() - - val queue = getUnloadQueue() - eventually(timeout(10.seconds)) { - assert(!StateStore.isLoaded(id), "provider should be removed") - assert(queue.size() == 1, "only one op should queue, the other should no-op") - } - } - } - } - - test("retaining only two latest versions when MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 2") { tryWithProviderResource( newStoreProvider(minDeltasForSnapshot = 10, numOfVersToRetainInMemory = 2)) { provider => From 383debf5138f95c90246ded4e15cc760218b6294 Mon Sep 17 00:00:00 2001 From: liviazhu Date: Thu, 25 Jun 2026 23:09:24 +0000 Subject: [PATCH 4/5] [SS] Remove FromTaskThread maintenance source and add nextOp to route post-op behavior The FromTaskThread maintenance source is no longer needed now that the query thread queues providers instead of submitting directly. Replace it with an explicit nextOp on the submitted task: - Remove MaintenanceTaskType.FromTaskThread. - Add nextOp: Option[MaintenanceOpRequest] to submitMaintenanceWorkForProvider/tryClaimAndSubmit (only meaningful for FromUnloadedProvidersQueue). For an All request, the first op is submitted as FromUnloadedProvidersQueue with nextOp = Some(otherOp); on completion the pool thread enqueues the remaining op with nextOp = None, which closes the provider after it runs. Ported from databricks-eng/runtime#209771 (SC-226849). Co-authored-by: Isaac --- .../streaming/state/StateStore.scala | 100 ++++++++++-------- .../StateStoreDecoupledMaintenanceSuite.scala | 2 +- 2 files changed, 57 insertions(+), 45 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d3182da00d62d..437ac25ac42fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -84,7 +84,6 @@ sealed trait MaintenanceTaskType object MaintenanceTaskType { case object FromUnloadedProvidersQueue extends MaintenanceTaskType - case object FromTaskThread extends MaintenanceTaskType case object FromLoadedProviders extends MaintenanceTaskType } @@ -1820,25 +1819,26 @@ object StateStore extends Logging { // Phase 1: Drain the unloadedProvidersToClose queue. These are providers that have been // removed from loadedProviders and need maintenance before close. opRequest determines - // which task to submit and which source type to use: - // All: pick one available op, submit as FromTaskThread (first of two ticks) - // Snapshot: submit snapshot as FromUnloadedProvidersQueue (closes after) - // Cleanup: submit cleanup as FromUnloadedProvidersQueue (closes after) + // which task to submit: + // All: pick one available op, nextOp enqueues the other after completion + // Snapshot: submit snapshot, nextOp = None so provider is closed after + // Cleanup: submit cleanup, nextOp = None so provider is closed after while (!unloadedProvidersToClose.isEmpty) { val (providerId, provider, opRequest) = unloadedProvidersToClose.poll() val submitted = opRequest match { case MaintenanceOpRequest.All => - // All ops should run before the provider can be closed. We serialize them by - // submitting one op now as FromTaskThread; when it completes it enqueues the - // remaining op, which runs as FromUnloadedProvidersQueue and closes the provider. + // All ops should have run recently before the provider can be closed. We serialize + // them by submitting one op now with nextOp pointing to the other. When the first + // completes, the pool thread enqueues the remaining op with nextOp = None, which + // closes the provider after finishing. // Pick whichever partition set is available with short-circuit evaluation. - tryClaimAndSubmit( - providerId, provider, storeConf, - MaintenanceOpType.Snapshot, MaintenanceTaskType.FromTaskThread) || - tryClaimAndSubmit( - providerId, provider, storeConf, - MaintenanceOpType.Cleanup, MaintenanceTaskType.FromTaskThread) + tryClaimAndSubmit(providerId, provider, storeConf, + MaintenanceOpType.Snapshot, MaintenanceTaskType.FromUnloadedProvidersQueue, + nextOp = Some(otherMaintenanceOpRequest(MaintenanceOpType.Snapshot))) || + tryClaimAndSubmit(providerId, provider, storeConf, + MaintenanceOpType.Cleanup, MaintenanceTaskType.FromUnloadedProvidersQueue, + nextOp = Some(otherMaintenanceOpRequest(MaintenanceOpType.Cleanup))) case MaintenanceOpRequest.Snapshot => tryClaimAndSubmit( providerId, provider, storeConf, @@ -1888,10 +1888,14 @@ object StateStore extends Logging { provider: StateStoreProvider, storeConf: StateStoreConf, opType: MaintenanceOpType, - source: MaintenanceTaskType): Boolean = { + source: MaintenanceTaskType, + // Only used when source is FromUnloadedProvidersQueue. + nextOp: Option[MaintenanceOpRequest] = None): Boolean = { if (tryClaimPartition(providerId, opType)) { - submitMaintenanceWorkForProvider(providerId, provider, storeConf, source, opType) - logDebug(s"Submitted $providerId with source $source for $opType") + submitMaintenanceWorkForProvider( + providerId, provider, storeConf, source, opType, nextOp) + logDebug(s"Submitted $providerId with source $source" + + s" for $opType, nextOp=$nextOp") true } else { logInfo(log"Not processing partition " + @@ -1918,14 +1922,20 @@ object StateStore extends Logging { * * @param id The StateStore provider ID to perform maintenance on * @param provider The StateStore provider instance + * @param storeConf The StateStore configuration + * @param source Where this request originated from * @param opType Which maintenance operation (snapshot or cleanup) to perform + * @param nextOp If set, the remaining op to enqueue after this one completes. Only used when + * source is FromUnloadedProvidersQueue. */ private def submitMaintenanceWorkForProvider( id: StateStoreProviderId, provider: StateStoreProvider, storeConf: StateStoreConf, source: MaintenanceTaskType, - opType: MaintenanceOpType): Unit = { + opType: MaintenanceOpType, + // Only used when source is FromUnloadedProvidersQueue. + nextOp: Option[MaintenanceOpRequest] = None): Unit = { val pool = opType match { case MaintenanceOpType.Snapshot => highPriorityThreadPool case MaintenanceOpType.Cleanup => lowPriorityThreadPool @@ -1953,7 +1963,7 @@ object StateStore extends Logging { loadedProviders.synchronized { loadedProviders.get(id).contains(provider) } && !provider.unloaded case _ => - // FromTaskThread / FromUnloadedProvidersQueue: provider already removed, + // FromUnloadedProvidersQueue: provider already removed, // reference passed directly, no containsKey needed. !provider.unloaded } @@ -1964,9 +1974,10 @@ object StateStore extends Logging { case MaintenanceOpType.Cleanup => provider.doCleanupMaintenance() } - // Dispatch based on source. FromLoadedProviders and FromTaskThread run inside the - // read lock so no close can interleave between work and enqueue. - // FromUnloadedProvidersQueue releases the read lock before acquiring the write lock. + // Dispatch based on source. FromLoadedProviders runs inside the read lock so no + // close can interleave between work and enqueue. FromUnloadedProvidersQueue uses + // nextOp to decide whether to enqueue the remaining op or release the read lock and + // acquire the write lock to close the provider. source match { case FromLoadedProviders => // Check if provider should be unloaded @@ -1996,28 +2007,29 @@ object StateStore extends Logging { } } - case FromTaskThread => - // Provider already removed from loadedProviders by the query thread. Queue for - // the other operation before close. - unloadedProvidersToClose.add((id, provider, otherMaintenanceOpRequest(opType))) - logInfo(log"${MDC(LogKeys.MAINTENANCE_TASK_TYPE, source)}: queued " + - log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)} " + - log"for close with ${MDC(LogKeys.OP_TYPE, otherMaintenanceOpRequest(opType))}") - if (maintenanceTask != null) maintenanceTask.triggerNow() - - case FromUnloadedProvidersQueue => - // Release read lock, then acquire write lock to wait for any in-flight - // maintenance to finish. - provider.maintenanceLock.readLock().unlock() - provider.maintenanceLock.writeLock().lock() - try { - closeProvider(id, provider) - } finally { - // Downgrade: reacquire read lock while holding write lock, then release write - // lock. The outer finally unconditionally releases the read lock. - provider.maintenanceLock.readLock().lock() - provider.maintenanceLock.writeLock().unlock() - } + case FromUnloadedProvidersQueue => nextOp match { + case Some(remainingOp) => + // Enqueue the remaining op. It will run with nextOp = None and close the + // provider after. + unloadedProvidersToClose.add((id, provider, remainingOp)) + logInfo(log"${MDC(LogKeys.MAINTENANCE_TASK_TYPE, source)}: queued " + + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)} for close with " + + log"${MDC(LogKeys.OP_TYPE, remainingOp)}") + if (maintenanceTask != null) maintenanceTask.triggerNow() + case None => + // Release read lock, then acquire write lock to wait for any in-flight + // maintenance to finish. + provider.maintenanceLock.readLock().unlock() + provider.maintenanceLock.writeLock().lock() + try { + closeProvider(id, provider) + } finally { + // Downgrade: reacquire read lock while holding write lock, then release + // write lock. The outer finally unconditionally releases the read lock. + provider.maintenanceLock.readLock().lock() + provider.maintenanceLock.writeLock().unlock() + } + } } } else { logInfo(log"Skipping maintenance for " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreDecoupledMaintenanceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreDecoupledMaintenanceSuite.scala index ace25da58945f..81c0aee601d2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreDecoupledMaintenanceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreDecoupledMaintenanceSuite.scala @@ -297,7 +297,7 @@ abstract class StateStoreDecoupledMaintenanceSuiteBase[ // was consumed and snapshot was submitted. // Step 2: Scheduler (via triggerNow) submits first op as - // FromTaskThread. Snapshot enters latch. + // FromUnloadedProvidersQueue. Snapshot enters latch. assert(bp.snapshotEnteredLatch.await(10, TimeUnit.SECONDS) && bp.snapshotEnteredLatch.getCount == 0, "snapshot should have started") From 7ee8f1f76dc58a5bf018219a5bb87c1529f8785b Mon Sep 17 00:00:00 2001 From: liviazhu Date: Thu, 25 Jun 2026 23:59:52 +0000 Subject: [PATCH 5/5] [SS] Align state store maintenance comments and docstrings with the source design Match the comment wording/wrapping and the StateStoreConf docstrings to the original design, and factor the duplicated otherMaintenanceOpRequest(opType) call in the FromLoadedProviders unload path into a local `remaining` val. No behavior change. Co-authored-by: Isaac --- .../streaming/state/StateStore.scala | 160 ++++++++++-------- .../streaming/state/StateStoreConf.scala | 11 +- 2 files changed, 90 insertions(+), 81 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 437ac25ac42fe..3a8bb0d86ff03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -1299,8 +1299,8 @@ object StateStore extends Logging { new ConcurrentLinkedQueue[(StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)] // These sets track which providers currently have maintenance tasks in-flight, - // one per operation type, to prevent concurrent same-type operations on the same - // provider. Each set has its own lock. + // one per operation type, to prevent concurrent same-type operations on the same provider. + // Each set has its own lock. private val snapshotPartitionsLock = new Object @GuardedBy("snapshotPartitionsLock") private val snapshotPartitions = new mutable.HashSet[StateStoreProviderId] @@ -1605,9 +1605,8 @@ object StateStore extends Logging { }.getOrElse(log"") providerStatus.providerIdsToUnload.foreach(id => { loadedProviders.remove(id).foreach( provider => { - // Queue the provider for maintenance and close. The maintenance scheduler will - // drain the queue and submit tasks. remove() returning non-null ensures only one - // queuer for this provider instance. + // Queue provider for maintenance + close. The scheduler will drain the queue + // and submit tasks. remove() returning non-null ensures only one queuer. logInfo(log"Queuing provider from task thread for maintenance and close " + log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}." + taskContextIdLogLine + log"Removed provider from loadedProviders") @@ -1615,11 +1614,12 @@ object StateStore extends Logging { }) }) - // Submit a scheduler cycle so queued providers are processed without waiting for the - // next periodic tick, minimizing the time stale providers wait to be closed. Without - // this, we would wait up to 2 maintenance cycles for both operations to finish and the - // provider to be closed from the time it is queued. At most one triggered cycle can be - // pending at a time. + // Submit a scheduler cycle so queued providers are processed without + // waiting for the next periodic tick, minimizing the time stale + // providers wait to be closed. Without this, we would wait up to + // 2 maintenance cycles for both operations to finish and the + // provider to be closed from the time it is queued. At most one + // triggered cycle can be pending at a time. if (providerStatus.providerIdsToUnload.nonEmpty && maintenanceTask != null) { maintenanceTask.triggerNow() } @@ -1707,7 +1707,8 @@ object StateStore extends Logging { maintenanceTask.stopAndAwait() maintenanceTask = null } - // Shut down both pools concurrently, then await both, so we don't double the blocking time. + // Shut down both pools concurrently, then await both, so we don't + // double the blocking time. if (highPriorityThreadPool != null) highPriorityThreadPool.shutdown() if (lowPriorityThreadPool != null) lowPriorityThreadPool.shutdown() if (highPriorityThreadPool != null) { @@ -1724,9 +1725,9 @@ object StateStore extends Logging { /** Unload and stop all state store providers */ def stop(): Unit = { - // Stop scheduler and pools outside loadedProviders lock. Pool threads acquire - // maintenanceLock then loadedProviders.synchronized, so holding loadedProviders while - // awaiting termination would deadlock. + // Stop scheduler and pools outside loadedProviders lock. Pool threads + // acquire maintenanceLock then loadedProviders.synchronized, so holding + // loadedProviders while awaiting termination would deadlock. stopMaintenanceTaskWithoutLock() loadedProviders.synchronized { loadedProviders.foreach { case (id, provider) => closeProvider(id, provider) } @@ -1742,16 +1743,18 @@ object StateStore extends Logging { } /** - * Determines the number of threads for the snapshot and cleanup pools using the configured - * ratio. Snapshot gets the rounded value, clamped to [1, total - 1]. Cleanup gets the - * remainder. Each pool gets at least 1 thread and the total is never exceeded. + * Determines the number of threads for the snapshot and cleanup pools + * using the configured ratio. Snapshot gets the rounded value, clamped + * to [1, total - 1]. Cleanup gets the remainder. Each pool gets at + * least 1 thread and the total is never exceeded. * @return (snapshotThreads, cleanupThreads) */ private[streaming] def getPoolSizes(storeConf: StateStoreConf): (Int, Int) = { val total = storeConf.numStateStoreMaintenanceThreads val ratio = storeConf.snapshotToCleanupThreadRatio val snapshotBeforeClamp = math.round(total * ratio).toInt - // Clamp to [1, total - 1] so each pool gets at least 1 thread and total is never exceeded. + // Clamp to [1, total - 1] so each pool gets at least 1 thread + // and total is never exceeded. val snapshot = math.max(1, math.min(total - 1, snapshotBeforeClamp)) val cleanup = total - snapshot (snapshot, cleanup) @@ -1767,8 +1770,8 @@ object StateStore extends Logging { storeConf.maintenanceInterval, task = { processUnloadedOnly => doMaintenance(storeConf, processUnloadedOnly) } ) - // Separate pools for snapshot and cleanup to prevent one operation type from starving - // the other when pool threads are saturated. + // Separate pools for snapshot and cleanup to prevent one operation type + // from starving the other when pool threads are saturated. val (snapshotThreads, cleanupThreads) = getPoolSizes(storeConf) highPriorityThreadPool = new MaintenanceThreadPool(snapshotThreads, maintenanceShutdownTimeout, maintenanceForceShutdownTimeout, @@ -1783,7 +1786,7 @@ object StateStore extends Logging { private def doMaintenance(): Unit = doMaintenance(StateStoreConf.empty) - /** Claim a single partition set slot for the given op type. Returns true if claimed. */ + /** Claim a single partition set slot. Returns true if claimed. */ private[streaming] def tryClaimPartition( id: StateStoreProviderId, opType: MaintenanceOpType): Boolean = { @@ -1817,22 +1820,24 @@ object StateStore extends Logging { val providersToRequeue = new ArrayBuffer[(StateStoreProviderId, StateStoreProvider, MaintenanceOpRequest)]() - // Phase 1: Drain the unloadedProvidersToClose queue. These are providers that have been - // removed from loadedProviders and need maintenance before close. opRequest determines - // which task to submit: - // All: pick one available op, nextOp enqueues the other after completion + // Phase 1: Drain unloadedProvidersToClose queue. + // These are providers removed from loadedProviders that need maintenance before close. + // opRequest determines which task to submit: + // All: pick one available op, nextOp enqueues the other after completion // Snapshot: submit snapshot, nextOp = None so provider is closed after - // Cleanup: submit cleanup, nextOp = None so provider is closed after + // Cleanup: submit cleanup, nextOp = None so provider is closed after while (!unloadedProvidersToClose.isEmpty) { val (providerId, provider, opRequest) = unloadedProvidersToClose.poll() val submitted = opRequest match { case MaintenanceOpRequest.All => - // All ops should have run recently before the provider can be closed. We serialize - // them by submitting one op now with nextOp pointing to the other. When the first - // completes, the pool thread enqueues the remaining op with nextOp = None, which - // closes the provider after finishing. - // Pick whichever partition set is available with short-circuit evaluation. + // All ops should have run recently before the provider can be + // closed. We serialize them by submitting one op now with + // nextOp pointing to the other. When the first completes, + // the pool thread enqueues the remaining op with nextOp = None, + // which closes the provider after finishing. + // Pick whichever partition set is available with short circuit + // evaluation. tryClaimAndSubmit(providerId, provider, storeConf, MaintenanceOpType.Snapshot, MaintenanceTaskType.FromUnloadedProvidersQueue, nextOp = Some(otherMaintenanceOpRequest(MaintenanceOpType.Snapshot))) || @@ -1849,7 +1854,8 @@ object StateStore extends Logging { MaintenanceOpType.Cleanup, MaintenanceTaskType.FromUnloadedProvidersQueue) } - // If the partition set is occupied, buffer for requeue and retry on the next cycle. + // If the partition set is occupied, buffer for requeue. These will be + // added back to the queue after draining and retried on the next cycle. if (!submitted) { providersToRequeue += ((providerId, provider, opRequest)) } @@ -1863,8 +1869,8 @@ object StateStore extends Logging { providersToRequeue.foreach(unloadedProvidersToClose.offer) // Phase 2: Submit separate snapshot and cleanup tasks for loaded providers. - // Skipped when processUnloadedOnly is true to avoid submitting unnecessary work for all - // providers. + // Skipped when processUnloadedOnly is true to avoid submitting + // unnecessary work for all providers. if (!processUnloadedOnly) { loadedProviders.synchronized { loadedProviders.toSeq @@ -1924,9 +1930,9 @@ object StateStore extends Logging { * @param provider The StateStore provider instance * @param storeConf The StateStore configuration * @param source Where this request originated from - * @param opType Which maintenance operation (snapshot or cleanup) to perform - * @param nextOp If set, the remaining op to enqueue after this one completes. Only used when - * source is FromUnloadedProvidersQueue. + * @param opType Which maintenance operation to perform + * @param nextOp If set, the remaining op to enqueue after this one + * completes. Only used when source is FromUnloadedProvidersQueue. */ private def submitMaintenanceWorkForProvider( id: StateStoreProviderId, @@ -1944,48 +1950,51 @@ object StateStore extends Logging { logDebug(s"Starting $opType maintenance for $id, source=$source") val startTime = System.currentTimeMillis() try { - // We use a var instead of early return because `return` inside a closure (pool.execute) - // throws NonLocalReturnControl in Scala. + // We use a var instead of early return because `return` inside + // a closure (pool.execute) throws NonLocalReturnControl in Scala. var canProcess = false - // If we can't acquire the lock, the write lock is held, which means another thread is - // closing this provider. The entire maintenance task is a no-op in that case, so we skip - // and free the pool thread rather than blocking. The zero timeout honors fair ordering so - // readers do not starve a queued writer. + // If we can't acquire the lock, the write lock is held, which + // means another thread is closing this provider. The entire + // maintenance task is a no-op in that case, so we skip and free + // the pool thread rather than blocking. The zero timeout honors + // fair ordering so readers do not starve a queued writer. val lockAcquired = provider.maintenanceLock.readLock().tryLock(0, TimeUnit.SECONDS) try { if (lockAcquired) { canProcess = source match { case FromLoadedProviders => - // Checks that the ID is still in loadedProviders and that the instance matches - // the one we were given. The scheduler submits from a stale copy of - // loadedProviders, so the provider may have been removed and replaced by a new - // instance under the same key. + // Checks that the ID is still in loadedProviders and that the instance + // matches the one we were given. The scheduler submits from a stale copy + // of loadedProviders, so the provider may have been removed and replaced + // by a new instance under the same key. loadedProviders.synchronized { loadedProviders.get(id).contains(provider) } && !provider.unloaded case _ => - // FromUnloadedProvidersQueue: provider already removed, - // reference passed directly, no containsKey needed. + // FromUnloadedProvidersQueue: provider already + // removed, reference passed directly, no containsKey needed !provider.unloaded } if (canProcess) { - // Do the maintenance work for this op type. + // Do maintenance work opType match { case MaintenanceOpType.Snapshot => provider.doSnapshotMaintenance() case MaintenanceOpType.Cleanup => provider.doCleanupMaintenance() } - // Dispatch based on source. FromLoadedProviders runs inside the read lock so no - // close can interleave between work and enqueue. FromUnloadedProvidersQueue uses - // nextOp to decide whether to enqueue the remaining op or release the read lock and - // acquire the write lock to close the provider. + // Dispatch based on source. FromLoadedProviders runs inside the + // read lock so no close can interleave between work and enqueue. + // FromUnloadedProvidersQueue uses nextOp to decide whether to + // enqueue the remaining op or release the read lock and acquire + // the write lock to close the provider. source match { case FromLoadedProviders => // Check if provider should be unloaded if (!verifyIfStoreInstanceActive(id)) { - // Only remove if the map still holds the same provider instance we were - // given. Between verifyIfStoreInstanceActive and this remove, a concurrent - // get() may have loaded a new provider under the same key. Removing by key - // alone would incorrectly remove the new provider. + // Only remove if the map still holds the same provider instance + // we were given. Between verifyIfStoreInstanceActive and this + // remove, a concurrent get() may have loaded a new provider + // under the same key. Removing by key alone would incorrectly + // remove the new provider. val removed = loadedProviders.synchronized { if (loadedProviders.get(id).contains(provider)) { loadedProviders.remove(id) @@ -1994,12 +2003,11 @@ object StateStore extends Logging { } } if (removed.isDefined) { - unloadedProvidersToClose.add( - (id, provider, otherMaintenanceOpRequest(opType))) + val remaining = otherMaintenanceOpRequest(opType) + unloadedProvidersToClose.add((id, provider, remaining)) logInfo(log"${MDC(LogKeys.MAINTENANCE_TASK_TYPE, source)}: " + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)} verified inactive, " + - log"queued for close with " + - log"${MDC(LogKeys.OP_TYPE, otherMaintenanceOpRequest(opType))}") + log"queued for close with ${MDC(LogKeys.OP_TYPE, remaining)}") } else { logInfo(log"${MDC(LogKeys.MAINTENANCE_TASK_TYPE, source)}: " + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)} verified inactive " + @@ -2009,23 +2017,24 @@ object StateStore extends Logging { case FromUnloadedProvidersQueue => nextOp match { case Some(remainingOp) => - // Enqueue the remaining op. It will run with nextOp = None and close the - // provider after. + // Enqueue the remaining op. It will run with + // nextOp = None and close the provider after. unloadedProvidersToClose.add((id, provider, remainingOp)) logInfo(log"${MDC(LogKeys.MAINTENANCE_TASK_TYPE, source)}: queued " + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)} for close with " + log"${MDC(LogKeys.OP_TYPE, remainingOp)}") if (maintenanceTask != null) maintenanceTask.triggerNow() case None => - // Release read lock, then acquire write lock to wait for any in-flight - // maintenance to finish. + // Release read lock, then acquire write lock to wait + // for any in-flight maintenance to finish. provider.maintenanceLock.readLock().unlock() provider.maintenanceLock.writeLock().lock() try { closeProvider(id, provider) } finally { - // Downgrade: reacquire read lock while holding write lock, then release - // write lock. The outer finally unconditionally releases the read lock. + // Downgrade: reacquire read lock while holding write + // lock, then release write lock. The outer finally + // unconditionally releases the read lock. provider.maintenanceLock.readLock().lock() provider.maintenanceLock.writeLock().unlock() } @@ -2048,21 +2057,22 @@ object StateStore extends Logging { log" ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}. " + log"Closing provider due to error", e) if (source == FromLoadedProviders) { - // Only remove if the map still holds the same provider instance. A concurrent get() - // may have loaded a new provider under the same key. + // Only remove if the map still holds the same provider instance. + // A concurrent get() may have loaded a new provider under the same key. loadedProviders.synchronized { if (loadedProviders.get(id).contains(provider)) { loadedProviders.remove(id) } } } - // Acquire write lock before close to wait for any concurrent maintenance on the other - // pool thread to finish. + // Acquire write lock before close to wait for any concurrent + // maintenance on the other pool thread to finish. provider.maintenanceLock.writeLock().lock() try { - // Always close this provider instance regardless of whether we removed it from the - // map. Maintenance failed, so we must clean up this provider's resources. We cannot - // rely on the queue to close the provider because maintenance may error again. + // Always close this provider instance regardless of whether we + // removed from the map. Maintenance failed, so we must clean up + // this provider's resources. We cannot rely on the queue to close + // the provider because maintenance may error again. closeProvider(id, provider) } finally { provider.maintenanceLock.writeLock().unlock() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index b63ab7596b004..5da62611fda98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -31,15 +31,14 @@ class StateStoreConf( def this() = this(new SQLConf) /** - * Total number of maintenance threads. Split between the snapshot and cleanup thread - * pools. Each pool needs at least 1 thread, so the minimum is 2. + * Total number of maintenance threads. Split evenly between the snapshot + * and cleanup thread pools. Each pool needs at least 1 thread, so the + * minimum is 2. */ val numStateStoreMaintenanceThreads: Int = sqlConf.numStateStoreMaintenanceThreads - /** - * Ratio of threads for the snapshot pool. The remainder goes to cleanup. Each pool gets at - * least 1 thread and the total is never exceeded. - */ + /** Ratio of threads for snapshot pool. Remainder goes to cleanup. + * Each pool gets at least 1 thread and the total is never exceeded. */ val snapshotToCleanupThreadRatio: Double = sqlConf.snapshotToCleanupThreadRatio /**