Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,15 @@ pub trait Provider: Any + Send + Sync {
"0.0.0"
}

/// Whether this provider has an atomic, provider-native batched worker fetch implementation.
///
/// The runtime only requests more than one worker item at a time when this returns `true`.
/// This keeps external providers on legacy single-item fetch semantics until they explicitly
/// audit batching, session claiming, tag filtering, and attempt-count behavior.
fn supports_batched_work_item_fetch(&self) -> bool {
false
}

// ===== Core Atomic Orchestration Methods (REQUIRED) =====
// These three methods form the heart of reliable orchestration execution.
//
Expand Down Expand Up @@ -1739,6 +1748,50 @@ pub trait Provider: Any + Send + Sync {
tag_filter: &TagFilter,
) -> Result<Option<(WorkItem, String, u32)>, ProviderError>;

/// Fetch and peek-lock up to `max_items` worker queue items in one provider call.
///
/// `max_new_sessions` is a transaction-scoped cap for how many previously-unowned
/// sessions may be claimed by this fetch. Items for already-owned sessions and
/// non-session items do not consume this quota.
///
/// The default implementation preserves compatibility. It does not provide true
/// atomic batching and intentionally falls back to a single fetch when session
/// routing is enabled, so unaudited providers cannot claim multiple sessions in
/// separate transactions.
async fn fetch_work_items(
&self,
lock_timeout: Duration,
poll_timeout: Duration,
session: Option<&SessionFetchConfig>,
tag_filter: &TagFilter,
max_items: usize,
_max_new_sessions: usize,
) -> Result<Vec<(WorkItem, String, u32)>, ProviderError> {
if max_items == 0 || matches!(tag_filter, TagFilter::None) {
return Ok(Vec::new());
}

if max_items > 1 && session.is_some() {
return Ok(self
.fetch_work_item(lock_timeout, poll_timeout, session, tag_filter)
.await?
.into_iter()
.collect());
}

let mut items = Vec::with_capacity(max_items.min(16));
for _ in 0..max_items {
match self
.fetch_work_item(lock_timeout, poll_timeout, session, tag_filter)
.await?
{
Some(item) => items.push(item),
None => break,
}
}
Ok(items)
}

/// Acknowledge successful processing of a work item.
///
/// # What This Does
Expand Down
187 changes: 187 additions & 0 deletions src/providers/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
use sqlx::{Row, Sqlite, Transaction};
use std::collections::HashSet;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tracing::debug;

Expand Down Expand Up @@ -718,6 +719,10 @@ impl Provider for SqliteProvider {
env!("CARGO_PKG_VERSION")
}

fn supports_batched_work_item_fetch(&self) -> bool {
true
}

async fn fetch_orchestration_item(
&self,
lock_timeout: Duration,
Expand Down Expand Up @@ -1973,6 +1978,188 @@ impl Provider for SqliteProvider {
Ok(Some((work_item, lock_token, attempt_count)))
}

async fn fetch_work_items(
&self,
lock_timeout: Duration,
_poll_timeout: Duration,
session: Option<&SessionFetchConfig>,
tag_filter: &TagFilter,
max_items: usize,
max_new_sessions: usize,
) -> Result<Vec<(WorkItem, String, u32)>, ProviderError> {
if max_items == 0 || matches!(tag_filter, TagFilter::None) {
return Ok(Vec::new());
}

let max_items = max_items.min(8);
let scan_limit = max_items.saturating_mul(4).max(max_items);
let mut tx = self
.pool
.begin()
.await
.map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))?;

let now_ms = Self::now_millis();
let locked_until = Self::timestamp_after(lock_timeout);
let batch_token = Self::generate_lock_token();

let tag_start_param = if session.is_some() { 3 } else { 2 };
let tag_clause = Self::build_tag_clause(tag_filter, tag_start_param);
let tag_values = Self::collect_tag_values(tag_filter);
let limit_param = tag_start_param + tag_values.len();

let rows = if let Some(config) = session {
let sql = format!(
r#"
SELECT q.id, q.work_item, q.attempt_count, q.session_id, s.worker_id AS active_worker_id
FROM worker_queue q
LEFT JOIN sessions s ON s.session_id = q.session_id AND s.locked_until > ?1
WHERE q.visible_at <= ?1
AND (q.lock_token IS NULL OR q.locked_until <= ?1)
AND (
q.session_id IS NULL
OR s.worker_id = ?2
OR s.session_id IS NULL
)
AND ({tag_clause})
ORDER BY q.id
LIMIT ?{limit_param}
"#,
);
let mut query = sqlx::query(&sql).bind(now_ms).bind(&config.owner_id);
for val in &tag_values {
query = query.bind(val.as_str());
}
query
.bind(scan_limit as i64)
.fetch_all(&mut *tx)
.await
.map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))?
} else {
let sql = format!(
r#"
SELECT q.id, q.work_item, q.attempt_count, q.session_id, NULL AS active_worker_id
FROM worker_queue q
WHERE q.visible_at <= ?1
AND (q.lock_token IS NULL OR q.locked_until <= ?1)
AND q.session_id IS NULL
AND ({tag_clause})
ORDER BY q.id
LIMIT ?{limit_param}
"#,
);
let mut query = sqlx::query(&sql).bind(now_ms);
for val in &tag_values {
query = query.bind(val.as_str());
}
query
.bind(scan_limit as i64)
.fetch_all(&mut *tx)
.await
.map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))?
};

let mut out = Vec::with_capacity(max_items);
let mut new_sessions_claimed = 0usize;
let mut claimed_session_ids = HashSet::new();

for row in rows {
if out.len() >= max_items {
break;
}

let id: i64 = row
.try_get("id")
.map_err(|e| ProviderError::permanent("fetch_work_items", format!("Failed to get id: {e}")))?;
let work_item_str: String = row
.try_get("work_item")
.map_err(|e| ProviderError::permanent("fetch_work_items", format!("Failed to get work_item: {e}")))?;
let current_attempt_count: i64 = row.try_get("attempt_count").map_err(|e| {
ProviderError::permanent("fetch_work_items", format!("Failed to get attempt_count: {e}"))
})?;
let session_id: Option<String> = row
.try_get("session_id")
.map_err(|e| ProviderError::permanent("fetch_work_items", format!("Failed to get session_id: {e}")))?;
let active_worker_id: Option<String> = row.try_get("active_worker_id").unwrap_or(None);

let work_item: WorkItem = serde_json::from_str(&work_item_str)
.map_err(|e| ProviderError::permanent("fetch_work_items", format!("Deserialization error: {e}")))?;

if let (Some(sid), Some(config)) = (&session_id, session) {
let already_owned =
active_worker_id.as_deref() == Some(config.owner_id.as_str()) || claimed_session_ids.contains(sid);
if !already_owned {
if new_sessions_claimed >= max_new_sessions {
continue;
}
new_sessions_claimed += 1;
}

let session_locked_until = now_ms + config.lock_timeout.as_millis() as i64;
let upsert_result = sqlx::query(
r#"
INSERT INTO sessions (session_id, worker_id, locked_until, last_activity_at)
VALUES (?1, ?2, ?3, ?4)
ON CONFLICT (session_id) DO UPDATE
SET worker_id = ?2,
locked_until = ?3,
last_activity_at = ?4
WHERE sessions.locked_until <= ?4 OR sessions.worker_id = ?2
"#,
)
.bind(sid)
.bind(&config.owner_id)
.bind(session_locked_until)
.bind(now_ms)
.execute(&mut *tx)
.await
.map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))?;

if upsert_result.rows_affected() == 0 {
if !already_owned {
new_sessions_claimed = new_sessions_claimed.saturating_sub(1);
}
continue;
}
claimed_session_ids.insert(sid.clone());
}

let lock_token = format!("{batch_token}_{id}");
let update_result = sqlx::query(
r#"
UPDATE worker_queue
SET lock_token = ?1, locked_until = ?2, attempt_count = attempt_count + 1
WHERE id = ?3
AND (lock_token IS NULL OR locked_until <= ?4)
"#,
)
.bind(&lock_token)
.bind(locked_until)
.bind(id)
.bind(now_ms)
.execute(&mut *tx)
.await
.map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))?;

if update_result.rows_affected() == 0 {
if let Some(sid) = &session_id {
if claimed_session_ids.remove(sid) {
new_sessions_claimed = new_sessions_claimed.saturating_sub(1);
}
}
continue;
}

out.push((work_item, lock_token, (current_attempt_count + 1) as u32));
}

tx.commit()
.await
.map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))?;

Ok(out)
}

async fn ack_work_item(&self, token: &str, completion: Option<WorkItem>) -> Result<(), ProviderError> {
let mut tx = self
.pool
Expand Down
Loading
Loading