diff --git a/src/providers/mod.rs b/src/providers/mod.rs index f1315cd..686b488 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1082,6 +1082,15 @@ pub trait Provider: Any + Send + Sync { "0.0.0" } + /// Whether this provider has an atomic, provider-native batched worker fetch implementation. + /// + /// The runtime only requests more than one worker item at a time when this returns `true`. + /// This keeps external providers on legacy single-item fetch semantics until they explicitly + /// audit batching, session claiming, tag filtering, and attempt-count behavior. + fn supports_batched_work_item_fetch(&self) -> bool { + false + } + // ===== Core Atomic Orchestration Methods (REQUIRED) ===== // These three methods form the heart of reliable orchestration execution. // @@ -1739,6 +1748,50 @@ pub trait Provider: Any + Send + Sync { tag_filter: &TagFilter, ) -> Result, ProviderError>; + /// Fetch and peek-lock up to `max_items` worker queue items in one provider call. + /// + /// `max_new_sessions` is a transaction-scoped cap for how many previously-unowned + /// sessions may be claimed by this fetch. Items for already-owned sessions and + /// non-session items do not consume this quota. + /// + /// The default implementation preserves compatibility. It does not provide true + /// atomic batching and intentionally falls back to a single fetch when session + /// routing is enabled, so unaudited providers cannot claim multiple sessions in + /// separate transactions. + async fn fetch_work_items( + &self, + lock_timeout: Duration, + poll_timeout: Duration, + session: Option<&SessionFetchConfig>, + tag_filter: &TagFilter, + max_items: usize, + _max_new_sessions: usize, + ) -> Result, ProviderError> { + if max_items == 0 || matches!(tag_filter, TagFilter::None) { + return Ok(Vec::new()); + } + + if max_items > 1 && session.is_some() { + return Ok(self + .fetch_work_item(lock_timeout, poll_timeout, session, tag_filter) + .await? + .into_iter() + .collect()); + } + + let mut items = Vec::with_capacity(max_items.min(16)); + for _ in 0..max_items { + match self + .fetch_work_item(lock_timeout, poll_timeout, session, tag_filter) + .await? + { + Some(item) => items.push(item), + None => break, + } + } + Ok(items) + } + /// Acknowledge successful processing of a work item. /// /// # What This Does diff --git a/src/providers/sqlite.rs b/src/providers/sqlite.rs index d8fe523..a89dc6e 100644 --- a/src/providers/sqlite.rs +++ b/src/providers/sqlite.rs @@ -4,6 +4,7 @@ use sqlx::sqlite::{SqlitePool, SqlitePoolOptions}; use sqlx::{Row, Sqlite, Transaction}; +use std::collections::HashSet; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tracing::debug; @@ -718,6 +719,10 @@ impl Provider for SqliteProvider { env!("CARGO_PKG_VERSION") } + fn supports_batched_work_item_fetch(&self) -> bool { + true + } + async fn fetch_orchestration_item( &self, lock_timeout: Duration, @@ -1973,6 +1978,188 @@ impl Provider for SqliteProvider { Ok(Some((work_item, lock_token, attempt_count))) } + async fn fetch_work_items( + &self, + lock_timeout: Duration, + _poll_timeout: Duration, + session: Option<&SessionFetchConfig>, + tag_filter: &TagFilter, + max_items: usize, + max_new_sessions: usize, + ) -> Result, ProviderError> { + if max_items == 0 || matches!(tag_filter, TagFilter::None) { + return Ok(Vec::new()); + } + + let max_items = max_items.min(8); + let scan_limit = max_items.saturating_mul(4).max(max_items); + let mut tx = self + .pool + .begin() + .await + .map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))?; + + let now_ms = Self::now_millis(); + let locked_until = Self::timestamp_after(lock_timeout); + let batch_token = Self::generate_lock_token(); + + let tag_start_param = if session.is_some() { 3 } else { 2 }; + let tag_clause = Self::build_tag_clause(tag_filter, tag_start_param); + let tag_values = Self::collect_tag_values(tag_filter); + let limit_param = tag_start_param + tag_values.len(); + + let rows = if let Some(config) = session { + let sql = format!( + r#" + SELECT q.id, q.work_item, q.attempt_count, q.session_id, s.worker_id AS active_worker_id + FROM worker_queue q + LEFT JOIN sessions s ON s.session_id = q.session_id AND s.locked_until > ?1 + WHERE q.visible_at <= ?1 + AND (q.lock_token IS NULL OR q.locked_until <= ?1) + AND ( + q.session_id IS NULL + OR s.worker_id = ?2 + OR s.session_id IS NULL + ) + AND ({tag_clause}) + ORDER BY q.id + LIMIT ?{limit_param} + "#, + ); + let mut query = sqlx::query(&sql).bind(now_ms).bind(&config.owner_id); + for val in &tag_values { + query = query.bind(val.as_str()); + } + query + .bind(scan_limit as i64) + .fetch_all(&mut *tx) + .await + .map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))? + } else { + let sql = format!( + r#" + SELECT q.id, q.work_item, q.attempt_count, q.session_id, NULL AS active_worker_id + FROM worker_queue q + WHERE q.visible_at <= ?1 + AND (q.lock_token IS NULL OR q.locked_until <= ?1) + AND q.session_id IS NULL + AND ({tag_clause}) + ORDER BY q.id + LIMIT ?{limit_param} + "#, + ); + let mut query = sqlx::query(&sql).bind(now_ms); + for val in &tag_values { + query = query.bind(val.as_str()); + } + query + .bind(scan_limit as i64) + .fetch_all(&mut *tx) + .await + .map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))? + }; + + let mut out = Vec::with_capacity(max_items); + let mut new_sessions_claimed = 0usize; + let mut claimed_session_ids = HashSet::new(); + + for row in rows { + if out.len() >= max_items { + break; + } + + let id: i64 = row + .try_get("id") + .map_err(|e| ProviderError::permanent("fetch_work_items", format!("Failed to get id: {e}")))?; + let work_item_str: String = row + .try_get("work_item") + .map_err(|e| ProviderError::permanent("fetch_work_items", format!("Failed to get work_item: {e}")))?; + let current_attempt_count: i64 = row.try_get("attempt_count").map_err(|e| { + ProviderError::permanent("fetch_work_items", format!("Failed to get attempt_count: {e}")) + })?; + let session_id: Option = row + .try_get("session_id") + .map_err(|e| ProviderError::permanent("fetch_work_items", format!("Failed to get session_id: {e}")))?; + let active_worker_id: Option = row.try_get("active_worker_id").unwrap_or(None); + + let work_item: WorkItem = serde_json::from_str(&work_item_str) + .map_err(|e| ProviderError::permanent("fetch_work_items", format!("Deserialization error: {e}")))?; + + if let (Some(sid), Some(config)) = (&session_id, session) { + let already_owned = + active_worker_id.as_deref() == Some(config.owner_id.as_str()) || claimed_session_ids.contains(sid); + if !already_owned { + if new_sessions_claimed >= max_new_sessions { + continue; + } + new_sessions_claimed += 1; + } + + let session_locked_until = now_ms + config.lock_timeout.as_millis() as i64; + let upsert_result = sqlx::query( + r#" + INSERT INTO sessions (session_id, worker_id, locked_until, last_activity_at) + VALUES (?1, ?2, ?3, ?4) + ON CONFLICT (session_id) DO UPDATE + SET worker_id = ?2, + locked_until = ?3, + last_activity_at = ?4 + WHERE sessions.locked_until <= ?4 OR sessions.worker_id = ?2 + "#, + ) + .bind(sid) + .bind(&config.owner_id) + .bind(session_locked_until) + .bind(now_ms) + .execute(&mut *tx) + .await + .map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))?; + + if upsert_result.rows_affected() == 0 { + if !already_owned { + new_sessions_claimed = new_sessions_claimed.saturating_sub(1); + } + continue; + } + claimed_session_ids.insert(sid.clone()); + } + + let lock_token = format!("{batch_token}_{id}"); + let update_result = sqlx::query( + r#" + UPDATE worker_queue + SET lock_token = ?1, locked_until = ?2, attempt_count = attempt_count + 1 + WHERE id = ?3 + AND (lock_token IS NULL OR locked_until <= ?4) + "#, + ) + .bind(&lock_token) + .bind(locked_until) + .bind(id) + .bind(now_ms) + .execute(&mut *tx) + .await + .map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))?; + + if update_result.rows_affected() == 0 { + if let Some(sid) = &session_id { + if claimed_session_ids.remove(sid) { + new_sessions_claimed = new_sessions_claimed.saturating_sub(1); + } + } + continue; + } + + out.push((work_item, lock_token, (current_attempt_count + 1) as u32)); + } + + tx.commit() + .await + .map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))?; + + Ok(out) + } + async fn ack_work_item(&self, token: &str, completion: Option) -> Result<(), ProviderError> { let mut tx = self .pool diff --git a/src/runtime/dispatchers/worker.rs b/src/runtime/dispatchers/worker.rs index 157729d..7e6a7c7 100644 --- a/src/runtime/dispatchers/worker.rs +++ b/src/runtime/dispatchers/worker.rs @@ -32,6 +32,7 @@ use crate::providers::WorkItem; use std::sync::Arc; use std::sync::atomic::Ordering; use std::time::Duration; +use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore}; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::{error, warn}; @@ -170,6 +171,8 @@ impl Runtime { tokio::spawn(async move { let mut worker_handles = Vec::with_capacity(concurrency); let mut session_owner_ids: Vec = Vec::new(); + let activity_permits = Arc::new(Semaphore::new(self.options.worker_max_inflight.max(1))); + let activity_handles = Arc::new(Mutex::new(Vec::new())); // Tracks distinct active sessions across all worker slots in this // runtime. When distinct_count() reaches max_sessions_per_runtime, @@ -190,6 +193,8 @@ impl Runtime { let activities = Arc::clone(&activities); let shutdown = Arc::clone(&shutdown); let session_tracker_clone = Arc::clone(&session_tracker); + let activity_permits_clone = Arc::clone(&activity_permits); + let activity_handles_clone = Arc::clone(&activity_handles); let suffix = stable_node_id.as_deref().unwrap_or(&self.runtime_id); let worker_id = format!("work-{worker_idx}-{suffix}"); @@ -211,13 +216,15 @@ impl Runtime { let min_interval = rt.options.dispatcher_min_poll_interval; let start_time = std::time::Instant::now(); - let work_found = match process_next_work_item( + let work_found = match process_next_work_batch( &rt, &activities, &shutdown, &worker_id, &session_owner, &session_tracker_clone, + &activity_permits_clone, + &activity_handles_clone, ) .await { @@ -265,26 +272,37 @@ impl Runtime { for handle in worker_handles { let _ = handle.await; } + + let mut handles = activity_handles.lock().await; + let pending: Vec<_> = handles.drain(..).collect(); + drop(handles); + for handle in pending { + let _ = handle.await; + } }) } } -/// Process the next available work item from the queue. +/// Process the next available batch of worker items from the queue. /// /// Returns: -/// - `Ok(true)` if work was found and processed +/// - `Ok(true)` if work was found and scheduled /// - `Ok(false)` if no work was available /// - `Err(e)` if fetch failed (caller handles backoff) -async fn process_next_work_item( +async fn process_next_work_batch( rt: &Arc, activities: &Arc, shutdown: &Arc, worker_id: &str, session_worker_id: &str, session_tracker: &Arc, + activity_permits: &Arc, + activity_handles: &Arc>>>, ) -> Result { // Check session capacity: if at limit, only fetch non-session items - let at_session_capacity = session_tracker.distinct_count() >= rt.options.max_sessions_per_runtime; + let distinct_sessions = session_tracker.distinct_count(); + let remaining_session_capacity = rt.options.max_sessions_per_runtime.saturating_sub(distinct_sessions); + let at_session_capacity = remaining_session_capacity == 0; let session_config = if at_session_capacity { None @@ -295,20 +313,92 @@ async fn process_next_work_item( }) }; - let (item, token, attempt_count) = match rt + let session_batch_allowed = session_config.is_none() || rt.options.worker_node_id.is_some(); + let provider_batch_size = if rt.history_store.supports_batched_work_item_fetch() && session_batch_allowed { + rt.options.worker_fetch_batch_size.max(1) + } else { + 1 + }; + let fetch_limit = provider_batch_size.min(activity_permits.available_permits()).max(1); + let mut permits = acquire_activity_permits(activity_permits, fetch_limit); + if permits.is_empty() { + return Ok(false); + } + + activity_handles.lock().await.retain(|handle| !handle.is_finished()); + + let items = rt .history_store - .fetch_work_item( + .fetch_work_items( rt.options.worker_lock_timeout, rt.options.dispatcher_long_poll_timeout, session_config.as_ref(), &rt.options.worker_tag_filter, + permits.len(), + remaining_session_capacity, ) - .await? - { - Some(result) => result, - None => return Ok(false), - }; + .await?; + + if items.is_empty() { + return Ok(false); + } + + if session_config.is_some() && rt.options.worker_node_id.is_none() { + let permit = permits + .pop() + .expect("fetch_work_items returned an item without an acquired permit"); + process_fetched_work_item( + Arc::clone(rt), + Arc::clone(activities), + Arc::clone(shutdown), + worker_id.to_string(), + Arc::clone(session_tracker), + items.into_iter().next().expect("items was checked non-empty above"), + permit, + ) + .await; + return Ok(true); + } + + let found = items.len(); + for item in items { + let permit = permits + .pop() + .expect("fetch_work_items returned more items than acquired permits"); + let rt = Arc::clone(rt); + let activities = Arc::clone(activities); + let shutdown = Arc::clone(shutdown); + let session_tracker = Arc::clone(session_tracker); + let worker_id = worker_id.to_string(); + tokio::spawn(async move { + process_fetched_work_item(rt, activities, shutdown, worker_id, session_tracker, item, permit).await; + }); + } + + Ok(found > 0) +} + +fn acquire_activity_permits(activity_permits: &Arc, max_items: usize) -> Vec { + let mut permits = Vec::with_capacity(max_items); + for _ in 0..max_items { + match Arc::clone(activity_permits).try_acquire_owned() { + Ok(permit) => permits.push(permit), + Err(_) => break, + } + } + permits +} +async fn process_fetched_work_item( + rt: Arc, + activities: Arc, + shutdown: Arc, + worker_id: String, + session_tracker: Arc, + item: (WorkItem, String, u32), + _permit: OwnedSemaphorePermit, +) { + let (item, token, attempt_count) = item; let item_serialized = serde_json::to_string(&item).unwrap_or_default(); match item { @@ -325,7 +415,7 @@ async fn process_next_work_item( // The guard releases the slot on drop (when activity processing completes). // Multiple activities on the same session share one slot. let _session_guard = if let Some(ref sid) = session_id { - let guard = SessionGuard::new(session_tracker, sid); + let guard = SessionGuard::new(&session_tracker, sid); // Re-check capacity after acquiring. The pre-fetch check is a hint // to avoid unnecessary fetches, but two workers can race past it. // If we're now over capacity (another worker won the race for a @@ -342,7 +432,7 @@ async fn process_next_work_item( .history_store .abandon_work_item(&token, Some(Duration::from_millis(100)), true) .await; - return Ok(true); + return; } Some(guard) } else { @@ -358,7 +448,7 @@ async fn process_next_work_item( lock_token: token, attempt_count, item_serialized, - worker_id: worker_id.to_string(), + worker_id, session_id, tag, }; @@ -366,10 +456,10 @@ async fn process_next_work_item( // Cancellation is detected during lock renewal (lock stealing). if ctx.attempt_count > rt.options.max_attempts { // Handle poison messages - handle_poison_message(rt, &ctx).await; + handle_poison_message(&rt, &ctx).await; } else { // Execute activity with cancellation support - execute_activity(rt, activities, shutdown, ctx).await; + execute_activity(&rt, &activities, &shutdown, ctx).await; } } other => { @@ -377,8 +467,6 @@ async fn process_next_work_item( panic!("unexpected WorkItem in Worker dispatcher"); } } - - Ok(true) } /// Enforce minimum polling interval to prevent hot loops. diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index fef0b09..36430f2 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -182,6 +182,23 @@ pub struct RuntimeOptions { /// Default: 2 pub worker_concurrency: usize, + /// Maximum worker queue items requested in a single provider fetch. + /// + /// This is a fetch amortization knob, not a concurrency knob. The dispatcher + /// only requests as many items as it has available `worker_max_inflight` + /// permits, so increasing this value alone does not increase activity + /// concurrency. + /// Default: 1 + pub worker_fetch_batch_size: usize, + + /// Maximum concurrently executing activities in this runtime. + /// + /// Defaults to `worker_concurrency` to preserve legacy behavior. Increase + /// this separately from `worker_fetch_batch_size` when the desired change is + /// more activity parallelism rather than fewer fetch round trips. + /// Default: 2 + pub worker_max_inflight: usize, + /// Lock timeout for orchestrator queue items. /// When an orchestration message is dequeued, it's locked for this duration. /// Orchestration turns are typically fast (milliseconds), so a shorter timeout is appropriate. @@ -351,6 +368,8 @@ impl Default for RuntimeOptions { dispatcher_long_poll_timeout: Duration::from_secs(30), // 30 seconds orchestration_concurrency: 2, worker_concurrency: 2, + worker_fetch_batch_size: 1, + worker_max_inflight: 2, orchestrator_lock_timeout: Duration::from_secs(5), orchestrator_lock_renewal_buffer: Duration::from_secs(2), worker_lock_timeout: Duration::from_secs(30),