diff --git a/README.md b/README.md index 2f7190f..cb19463 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,33 @@ run_task("npm run build", queue_name="web", ...) Both agents block until their respective builds complete. The server handles sequencing automatically. +**Hierarchical queues** - Use `/`-delimited queue names plus `--queue-capacity` when you need +parallelism with a shared cap: + +```bash +uvx agent-task-queue@latest \ + --queue-capacity=gradle=2 \ + --queue-capacity=gradle/emu-5557=1 \ + --queue-capacity=gradle/emu-5559=1 +``` + +```python +run_task("./gradlew assembleDebug assembleDebugAndroidTest", queue_name="gradle/build", ...) +run_task("./gradlew connectedDebugAndroidTest -x assembleDebug -x assembleDebugAndroidTest", queue_name="gradle/emu-5557", env_vars="ANDROID_SERIAL=127.0.0.1:5557", ...) +run_task("./gradlew connectedDebugAndroidTest -x assembleDebug -x assembleDebugAndroidTest", queue_name="gradle/emu-5559", env_vars="ANDROID_SERIAL=127.0.0.1:5559", ...) +``` + +Queue capacities apply to each command for its entire lifetime. For Android-style workflows, that +usually means queueing shared Gradle prep/build first, then fan out emulator-specific commands that +reuse those outputs; see [Android Multi-Emulator Pattern](#android-multi-emulator-pattern). + +Configured capacities apply to a scope and all of its descendants. In the example above, the +shared `gradle` scope allows at most two concurrent Gradle-backed tasks, while each emulator leaf +queue remains exclusive. If multiple servers or `tq` CLI invocations share the same data +directory, start them with matching `--queue-capacity` flags; these overrides are process-local and +are not persisted in `queue.db`. If you do not configure any capacities, behavior is unchanged: +each exact `queue_name` is still a FIFO queue with capacity 1. + ## Demo: Two Agents, One Build Queue **Terminal A** - First agent requests an Android build: @@ -86,7 +113,7 @@ With the queue: ## Key Features -- **FIFO Queuing**: Strict first-in-first-out ordering +- **FIFO Queuing**: Strict first-in-first-out ordering within each exact `queue_name` - **No Queue Timeouts**: MCP keeps connection alive while waiting in queue. The `timeout_seconds` parameter only applies to execution time—tasks can wait in queue indefinitely without timing out. (see [Why MCP?](#why-mcp-instead-of-a-cli-tool)) - **Environment Variables**: Pass `env_vars="ANDROID_SERIAL=emulator-5560"` - **Multiple Queues**: Isolate different workloads with `queue_name` @@ -242,6 +269,12 @@ Agents use the `run_task` MCP tool for expensive operations: | `timeout_seconds` | No | Max **execution** time before kill (default: 1200). Queue wait time doesn't count. | | `env_vars` | No | Environment variables: `"KEY=val,KEY2=val2"` | +`queue_name` may be hierarchical, such as `gradle/emu-5557`, when the server is configured with +`--queue-capacity` scopes. + +Sibling queues that share a parent scope compete for that parent capacity on a best-effort basis; +FIFO ordering is guaranteed within each exact queue, not across sibling queues. + ### Example ``` @@ -253,6 +286,52 @@ run_task( ) ``` +### Android Multi-Emulator Pattern + +If your machine can safely run a small number of Gradle-backed device tests in parallel, use a +shared Gradle scope plus one queue per emulator. Because queue capacities only see whole commands, +the practical pattern is to split shared Gradle prep/build from emulator-specific execution. + +First, queue the shared Gradle prep/build once: + +```python +run_task( + command="./gradlew assembleDebug assembleDebugAndroidTest", + working_directory="/project", + queue_name="gradle/build", +) +``` + +Then fan out one task per emulator using a command that reuses those prebuilt outputs: + +```bash +uvx agent-task-queue@latest \ + --queue-capacity=gradle=2 \ + --queue-capacity=gradle/emu-5557=1 \ + --queue-capacity=gradle/emu-5559=1 \ + --queue-capacity=gradle/emu-5561=1 +``` + +Then pin each task to the matching queue and `ANDROID_SERIAL`: + +```python +run_task( + command="./gradlew connectedDebugAndroidTest -x assembleDebug -x assembleDebugAndroidTest", + working_directory="/project", + queue_name="gradle/emu-5557", + env_vars="ANDROID_SERIAL=127.0.0.1:5557", +) +``` + +Adapt the exact Gradle tasks and `-x` exclusions to your project. The key is that the second step +must reuse the outputs from the shared prep step instead of rebuilding them in every emulator queue. + +If your emulator execution phase no longer needs Gradle at all, queue it outside the shared +`gradle` scope entirely so only the build/prep step consumes shared Gradle capacity. + +When multiple entrypoints share this queue database, they must all use the same +`--queue-capacity` configuration for the shared parent caps to mean the same thing. + ### Agent Configuration Notes Some agents need additional configuration to use the queue instead of built-in shell commands. @@ -295,6 +374,7 @@ The server supports the following command-line options: | `--max-output-files` | `50` | Number of task output files to retain | | `--tail-lines` | `50` | Lines of output to include on failure | | `--lock-timeout` | `120` | Minutes before stale locks are cleared | +| `--queue-capacity` | none | Repeatable `scope=capacity` override for hierarchical queue names | Pass options via the `args` property in your MCP config: @@ -306,7 +386,9 @@ Pass options via the `args` property in your MCP config: "args": [ "agent-task-queue@latest", "--max-output-files=100", - "--lock-timeout=60" + "--lock-timeout=60", + "--queue-capacity=gradle=2", + "--queue-capacity=gradle/emu-5557=1" ] } } diff --git a/queue_core.py b/queue_core.py index a846e18..f671903 100644 --- a/queue_core.py +++ b/queue_core.py @@ -9,6 +9,7 @@ import json import os import signal +import shlex import sqlite3 from contextlib import contextmanager from dataclasses import dataclass @@ -19,9 +20,9 @@ # --- Configuration --- DEFAULT_DATA_DIR = Path(os.environ.get("TASK_QUEUE_DATA_DIR", "/tmp/agent-task-queue")) POLL_INTERVAL_WAITING = float(os.environ.get("TASK_QUEUE_POLL_WAITING", "1")) -POLL_INTERVAL_READY = float(os.environ.get("TASK_QUEUE_POLL_READY", "1")) DEFAULT_MAX_LOCK_AGE_MINUTES = 120 DEFAULT_MAX_METRICS_SIZE_MB = 5 +QUEUE_SCOPE_SEPARATOR = "/" @dataclass @@ -105,6 +106,221 @@ def init_db(paths: QueuePaths): pass # Column already exists +def normalize_queue_name(queue_name: str) -> str: + """Collapse redundant separators and whitespace in queue names.""" + parts = [part.strip() for part in queue_name.split(QUEUE_SCOPE_SEPARATOR) if part.strip()] + if not parts: + raise ValueError("queue_name must contain at least one non-empty segment") + return QUEUE_SCOPE_SEPARATOR.join(parts) + + +def parse_queue_capacities(capacity_args: list[str] | None) -> dict[str, int]: + """Parse repeated scope=capacity CLI arguments into a normalized map.""" + capacities: dict[str, int] = {} + for arg in capacity_args or []: + if "=" not in arg: + raise ValueError( + f"Invalid --queue-capacity value '{arg}'. Expected SCOPE=CAPACITY." + ) + + scope, raw_capacity = arg.split("=", 1) + scope = normalize_queue_name(scope) + try: + capacity = int(raw_capacity) + except ValueError as exc: + raise ValueError( + f"Invalid capacity '{raw_capacity}' for scope '{scope}'. Expected a positive integer." + ) from exc + + if capacity < 1: + raise ValueError( + f"Invalid capacity '{capacity}' for scope '{scope}'. Capacity must be >= 1." + ) + + capacities[scope] = capacity + + return capacities + + +def queue_scopes(queue_name: str) -> list[str]: + """Return hierarchical scopes from broadest to most specific.""" + normalized = normalize_queue_name(queue_name) + parts = normalized.split(QUEUE_SCOPE_SEPARATOR) + return [QUEUE_SCOPE_SEPARATOR.join(parts[:idx]) for idx in range(1, len(parts) + 1)] + + +def queue_names_in_scope(conn, scope: str) -> list[str]: + """Return queue names in a scope, including descendant queues.""" + normalized_scope = normalize_queue_name(scope) + escaped_scope = escape_like_pattern(normalized_scope) + rows = conn.execute( + """SELECT DISTINCT queue_name + FROM queue + WHERE queue_name = ? + OR queue_name LIKE ? ESCAPE '\\' + ORDER BY queue_name""", + (normalized_scope, f"{escaped_scope}{QUEUE_SCOPE_SEPARATOR}%"), + ).fetchall() + return [row["queue_name"] for row in rows] + + +def cleanup_targets_for_queue( + conn, + queue_name: str, + queue_capacities: dict[str, int] | None, +) -> list[str]: + """Return queue names that should be reaped before capacity checks.""" + normalized_queue = normalize_queue_name(queue_name) + targets = [normalized_queue] + + if not queue_capacities: + return targets + + for scope in queue_scopes(normalized_queue): + if scope not in queue_capacities: + continue + targets.extend(queue_names_in_scope(conn, scope)) + + seen: set[str] = set() + unique_targets: list[str] = [] + for target in targets: + if target in seen: + continue + seen.add(target) + unique_targets.append(target) + return unique_targets + + +def escape_like_pattern(value: str) -> str: + """Escape SQLite LIKE wildcards in a literal scope name.""" + return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + +def count_running_in_scope(conn, scope: str) -> int: + """Count running tasks in a scope, including descendant queues.""" + escaped_scope = escape_like_pattern(scope) + row = conn.execute( + """SELECT COUNT(*) AS c + FROM queue + WHERE status = 'running' + AND (queue_name = ? OR queue_name LIKE ? ESCAPE '\\')""", + (scope, f"{escaped_scope}{QUEUE_SCOPE_SEPARATOR}%"), + ).fetchone() + return row["c"] + + +def count_running_in_queue(conn, queue_name: str) -> int: + """Count running tasks in the exact queue only.""" + row = conn.execute( + "SELECT COUNT(*) AS c FROM queue WHERE queue_name = ? AND status = 'running'", + (queue_name,), + ).fetchone() + return row["c"] + + +def count_waiting_ahead(conn, queue_name: str, task_id: int) -> int: + """Count older waiting tasks in the exact queue.""" + row = conn.execute( + """SELECT COUNT(*) AS c + FROM queue + WHERE queue_name = ? + AND status = 'waiting' + AND id < ?""", + (queue_name, task_id), + ).fetchone() + return row["c"] + + +def can_acquire_task( + conn, + task_id: int, + queue_name: str, + queue_capacities: dict[str, int], + waiting_ahead: int | None = None, +) -> bool: + """Return True when the task can start without violating queue capacities.""" + normalized_queue = normalize_queue_name(queue_name) + scopes = queue_scopes(normalized_queue) + + exact_capacity = queue_capacities.get(normalized_queue, 1) + if normalized_queue in queue_capacities: + available_slots = exact_capacity - count_running_in_scope(conn, normalized_queue) + else: + available_slots = exact_capacity - count_running_in_queue(conn, normalized_queue) + + for scope in scopes: + if scope not in queue_capacities or scope == normalized_queue: + continue + available_slots = min( + available_slots, + queue_capacities[scope] - count_running_in_scope(conn, scope), + ) + + if available_slots <= 0: + return False + + if waiting_ahead is None: + waiting_ahead = count_waiting_ahead(conn, normalized_queue, task_id) + + if waiting_ahead >= available_slots: + return False + + return True + + +def attempt_task_start( + conn, + task_id: int, + queue_name: str, + queue_capacities: dict[str, int], + owner_pid: int, +) -> tuple[bool, int]: + """Attempt to transition a waiting task into running state. + + Returns: + Tuple of (started, queue_position). queue_position is only meaningful when started is False. + """ + previous_busy_timeout = None + + try: + previous_busy_timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0] + conn.execute("PRAGMA busy_timeout=100") + conn.execute("BEGIN IMMEDIATE") + waiting_ahead = count_waiting_ahead(conn, queue_name, task_id) + + if not can_acquire_task( + conn, + task_id, + queue_name, + queue_capacities, + waiting_ahead=waiting_ahead, + ): + position = waiting_ahead + 1 + conn.commit() + return False, position + + cursor = conn.execute( + """UPDATE queue SET status = 'running', updated_at = ?, pid = ? + WHERE id = ? AND status = 'waiting'""", + (datetime.now().isoformat(), owner_pid, task_id), + ) + if cursor.rowcount > 0: + conn.commit() + return True, 0 + + position = waiting_ahead + 1 + conn.commit() + return False, position + except Exception: + if conn.in_transaction: + conn.rollback() + raise + finally: + if previous_busy_timeout is not None: + # SQLite PRAGMA statements do not support bound parameters. + conn.execute("PRAGMA busy_timeout=" + str(int(previous_busy_timeout))) + + def ensure_db(paths: QueuePaths): """Ensure database exists and is valid. Recreates if corrupted.""" try: @@ -159,12 +375,28 @@ def is_task_queue_process(pid: int) -> bool: return False cmdline = result.stdout.strip().lower() - return ( - "task_queue" in cmdline - or "agent-task-queue" in cmdline - or "tq.py" in cmdline - or "pytest" in cmdline # For pytest running tests - ) + known_entrypoints = { + "task_queue", + "task_queue.py", + "agent-task-queue", + "tq", + "tq.py", + "pytest", + } + + try: + argv = shlex.split(cmdline) + except ValueError: + argv = cmdline.split() + + # Installed entrypoints appear as basenames like `tq`, while module launches often + # look like `python .../tq.py`, so inspect the first two argv entries before falling + # back to the broader historical substring checks. + for token in argv[:2]: + if Path(token).name in known_entrypoints: + return True + + return "task_queue" in cmdline or "agent-task-queue" in cmdline except Exception: # If we can't check, assume valid (conservative - avoid false orphan cleanup) return True diff --git a/task_queue.py b/task_queue.py index 2362b1c..e44076a 100644 --- a/task_queue.py +++ b/task_queue.py @@ -12,6 +12,7 @@ import resource import signal import sqlite3 +import sys import time import threading import uuid @@ -31,12 +32,15 @@ init_db as _init_db, ensure_db as _ensure_db, cleanup_queue as _cleanup_queue, + cleanup_targets_for_queue, log_metric as _log_metric, log_fmt, is_process_alive, kill_process_tree, + normalize_queue_name, + parse_queue_capacities, + attempt_task_start, POLL_INTERVAL_WAITING, - POLL_INTERVAL_READY, ) # Unique identifier for this server instance - used to detect orphaned tasks @@ -84,16 +88,37 @@ def parse_args(): default=120, help="Minutes before stale locks are cleared (default: 120)", ) + parser.add_argument( + "--queue-capacity", + action="append", + default=[], + metavar="SCOPE=CAPACITY", + help=( + "Hierarchical queue capacity override. Repeatable. " + "Example: --queue-capacity=gradle=2 --queue-capacity=gradle/emu-5557=1" + ), + ) return parser.parse_args() +def _should_parse_module_args(argv0: str | None = None, module_name: str | None = None) -> bool: + """Return True when the module is being launched as the task queue server.""" + module_name = module_name or __name__ + if module_name == "__main__": + return True + + executable = Path(argv0 or sys.argv[0]).name + return executable in {"agent-task-queue", "task_queue", "task_queue.py"} + + # Parse args at module load (before MCP server starts) -_args = parse_args() if __name__ == "__main__" else argparse.Namespace( +_args = parse_args() if _should_parse_module_args() else argparse.Namespace( data_dir="/tmp/agent-task-queue", max_log_size=5, max_output_files=50, tail_lines=50, lock_timeout=120, + queue_capacity=[], ) # --- Configuration --- @@ -104,6 +129,7 @@ def parse_args(): TAIL_LINES_ON_FAILURE = _args.tail_lines SERVER_NAME = "Task Queue" MAX_LOCK_AGE_MINUTES = _args.lock_timeout +QUEUE_CAPACITIES = parse_queue_capacities(_args.queue_capacity) mcp = FastMCP(SERVER_NAME) @@ -130,67 +156,74 @@ def log_metric(event: str, **kwargs): _log_metric(PATHS.metrics_path, event, MAX_METRICS_SIZE_MB, **kwargs) -def cleanup_queue(conn, queue_name: str): +def cleanup_queue(conn, queue_name: str, queue_capacities: dict[str, int] | None = None): """Clean up queue using configured paths and detect orphaned tasks.""" - _cleanup_queue( - conn, - queue_name, - PATHS.metrics_path, - MAX_LOCK_AGE_MINUTES, - log_fn=lambda msg: print(log_fmt(msg)), - ) - - my_pid = os.getpid() - - # Cleanup 1: Tasks with our PID but DIFFERENT server_id (from old server instance) - # This handles the edge case where PID is reused after server restart - stale_server_tasks = conn.execute( - "SELECT id, status, child_pid, server_id FROM queue WHERE queue_name = ? AND pid = ? AND server_id IS NOT NULL AND server_id != ?", - (queue_name, my_pid, SERVER_INSTANCE_ID), - ).fetchall() - - for task in stale_server_tasks: - if task["child_pid"] and is_process_alive(task["child_pid"]): - print(log_fmt(f"WARNING: Killing orphaned subprocess {task['child_pid']} from old server")) - kill_process_tree(task["child_pid"]) - - conn.execute("DELETE FROM queue WHERE id = ?", (task["id"],)) - log_metric( - "orphan_cleared", - task_id=task["id"], - queue_name=queue_name, - status=task["status"], - old_server_id=task["server_id"], - reason="stale_server_instance", + if queue_capacities is None: + queue_capacities = QUEUE_CAPACITIES + + for target_queue in cleanup_targets_for_queue(conn, queue_name, queue_capacities): + _cleanup_queue( + conn, + target_queue, + PATHS.metrics_path, + MAX_LOCK_AGE_MINUTES, + log_fn=lambda msg: print(log_fmt(msg)), ) - print(log_fmt(f"WARNING: Cleared task from old server instance (ID: {task['id']}, old_server: {task['server_id']})")) - # Cleanup 2: Tasks with our PID AND server_id but not in active tracking set - # This catches tasks left behind when clients disconnect without proper cleanup - our_tasks = conn.execute( - "SELECT id, status, child_pid FROM queue WHERE queue_name = ? AND pid = ? AND (server_id = ? OR server_id IS NULL)", - (queue_name, my_pid, SERVER_INSTANCE_ID), - ).fetchall() + my_pid = os.getpid() - with _active_task_ids_lock: - active_ids = _active_task_ids.copy() + # Cleanup 1: Tasks with our PID but DIFFERENT server_id (from old server instance) + # This handles the edge case where PID is reused after server restart + stale_server_tasks = conn.execute( + "SELECT id, status, child_pid, server_id FROM queue WHERE queue_name = ? AND pid = ? AND server_id IS NOT NULL AND server_id != ?", + (target_queue, my_pid, SERVER_INSTANCE_ID), + ).fetchall() - for orphan in our_tasks: - if orphan["id"] not in active_ids: - # This task belongs to us but we're not tracking it - it's orphaned - if orphan["child_pid"] and is_process_alive(orphan["child_pid"]): - print(log_fmt(f"WARNING: Killing orphaned subprocess {orphan['child_pid']}")) - kill_process_tree(orphan["child_pid"]) + for task in stale_server_tasks: + if task["child_pid"] and is_process_alive(task["child_pid"]): + print(log_fmt(f"WARNING: Killing orphaned subprocess {task['child_pid']} from old server")) + kill_process_tree(task["child_pid"]) - conn.execute("DELETE FROM queue WHERE id = ?", (orphan["id"],)) + conn.execute("DELETE FROM queue WHERE id = ?", (task["id"],)) log_metric( "orphan_cleared", - task_id=orphan["id"], - queue_name=queue_name, - status=orphan["status"], - reason="not_in_active_set", + task_id=task["id"], + queue_name=target_queue, + status=task["status"], + old_server_id=task["server_id"], + reason="stale_server_instance", ) - print(log_fmt(f"WARNING: Cleared orphaned task (ID: {orphan['id']}, status: {orphan['status']})")) + print(log_fmt(f"WARNING: Cleared task from old server instance (ID: {task['id']}, old_server: {task['server_id']})")) + + # Cleanup 2: Tasks with our PID AND server_id but not in active tracking set + # This catches tasks left behind when clients disconnect without proper cleanup + our_tasks = conn.execute( + "SELECT id, status, child_pid FROM queue WHERE queue_name = ? AND pid = ? AND (server_id = ? OR server_id IS NULL)", + (target_queue, my_pid, SERVER_INSTANCE_ID), + ).fetchall() + + with _active_task_ids_lock: + active_ids = _active_task_ids.copy() + + for orphan in our_tasks: + if orphan["id"] not in active_ids: + # This task belongs to us but we're not tracking it - it's orphaned + if orphan["child_pid"] and is_process_alive(orphan["child_pid"]): + print(log_fmt(f"WARNING: Killing orphaned subprocess {orphan['child_pid']}")) + kill_process_tree(orphan["child_pid"]) + + conn.execute("DELETE FROM queue WHERE id = ?", (orphan["id"],)) + log_metric( + "orphan_cleared", + task_id=orphan["id"], + queue_name=target_queue, + status=orphan["status"], + reason="not_in_active_set", + ) + print(log_fmt(f"WARNING: Cleared orphaned task (ID: {orphan['id']}, status: {orphan['status']})")) + + if conn.in_transaction: + conn.commit() # --- Output File Management --- @@ -238,13 +271,15 @@ def get_memory_mb() -> float: # --- Core Queue Logic --- async def wait_for_turn(queue_name: str, command: str | None = None) -> int: """Register task, wait for turn, return task ID when acquired.""" + queue_name = normalize_queue_name(queue_name) + # Ensure database exists and is valid ensure_db() # Run cleanup BEFORE inserting - this clears orphaned tasks that would otherwise # block the queue forever (since cleanup only runs during polling) with get_db() as conn: - cleanup_queue(conn, queue_name) + cleanup_queue(conn, queue_name, QUEUE_CAPACITIES) my_pid = os.getpid() ctx = None @@ -277,23 +312,30 @@ async def wait_for_turn(queue_name: str, command: str | None = None) -> int: try: while True: - with get_db() as conn: - cleanup_queue(conn, queue_name) - - runner = conn.execute( - "SELECT id FROM queue WHERE queue_name = ? AND status = 'running'", - (queue_name,), - ).fetchone() - - if runner: - pos = ( - conn.execute( - "SELECT COUNT(*) as c FROM queue WHERE queue_name = ? AND status = 'waiting' AND id < ?", - (queue_name, task_id), - ).fetchone()["c"] - + 1 + try: + with get_db() as conn: + cleanup_queue(conn, queue_name, QUEUE_CAPACITIES) + + started, pos = attempt_task_start( + conn, + task_id, + queue_name, + QUEUE_CAPACITIES, + my_pid, ) + if started: + wait_time = time.time() - queued_at + log_metric( + "task_started", + task_id=task_id, + queue_name=queue_name, + wait_time_seconds=round(wait_time, 2), + ) + if ctx: + await ctx.info(log_fmt("Lock ACQUIRED. Starting execution.")) + return task_id + wait_ticks += 1 if pos != last_pos: @@ -306,38 +348,11 @@ async def wait_for_turn(queue_name: str, command: str | None = None) -> int: f"Still waiting... Position #{pos} ({int(wait_ticks * POLL_INTERVAL_WAITING)}s elapsed)" ) ) + except sqlite3.OperationalError as exc: + if "database is locked" not in str(exc).lower(): + raise - await asyncio.sleep(POLL_INTERVAL_WAITING) - continue - - # Atomic lock acquisition: UPDATE only succeeds if we're the first - # waiting task AND no one is currently running. This prevents race - # conditions where two tasks both think they're next. - cursor = conn.execute( - """UPDATE queue SET status = 'running', updated_at = ?, pid = ? - WHERE id = ? AND status = 'waiting' - AND NOT EXISTS ( - SELECT 1 FROM queue WHERE queue_name = ? AND status = 'running' - ) - AND id = ( - SELECT MIN(id) FROM queue WHERE queue_name = ? AND status = 'waiting' - )""", - (datetime.now().isoformat(), my_pid, task_id, queue_name, queue_name), - ) - - if cursor.rowcount > 0: - wait_time = time.time() - queued_at - log_metric( - "task_started", - task_id=task_id, - queue_name=queue_name, - wait_time_seconds=round(wait_time, 2), - ) - if ctx: - await ctx.info(log_fmt("Lock ACQUIRED. Starting execution.")) - return task_id - - await asyncio.sleep(POLL_INTERVAL_READY) + await asyncio.sleep(POLL_INTERVAL_WAITING) except asyncio.CancelledError: # Client disconnected (e.g., sub-agent cancelled) - clean up our queue entry with _active_task_ids_lock: @@ -438,6 +453,8 @@ async def run_task( command: The full shell command to run. working_directory: ABSOLUTE path to the execution root. queue_name: Queue identifier for grouping tasks (default: "global"). + Queue names may be hierarchical (for example `gradle/emu-5557`) when the server + is configured with `--queue-capacity` scopes. timeout_seconds: Max **execution** time before killing the task (default: 1200 = 20 mins). Queue wait time does NOT count against this timeout. env_vars: Environment variables to set, format: "KEY1=value1,KEY2=value2" @@ -451,6 +468,11 @@ async def run_task( if not os.path.exists(working_directory): return f"ERROR: Working directory does not exist: {working_directory}" + try: + queue_name = normalize_queue_name(queue_name) + except ValueError as exc: + return f"ERROR: {str(exc)}" + # Parse environment variables env = os.environ.copy() if env_vars: diff --git a/tests/test_queue.py b/tests/test_queue.py index 115ecb9..d2f4339 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -6,6 +6,7 @@ import pytest import asyncio import os +import subprocess import time from pathlib import Path @@ -15,6 +16,8 @@ from datetime import datetime, timedelta from fastmcp import Client +import queue_core +import task_queue from task_queue import ( mcp, PATHS, @@ -25,6 +28,11 @@ cleanup_queue, MAX_LOCK_AGE_MINUTES, ) +from queue_core import ( + attempt_task_start, + cleanup_queue as cleanup_queue_core, + parse_queue_capacities, +) # Use PATHS for database path DB_PATH = PATHS.db_path @@ -49,6 +57,12 @@ def clean_db(): DB_PATH.unlink() +@pytest.fixture(autouse=True) +def reset_queue_capacities(monkeypatch): + """Reset queue capacity overrides between tests.""" + monkeypatch.setattr(task_queue, "QUEUE_CAPACITIES", {}) + + @pytest.fixture def client(): """Create FastMCP client connected to our server.""" @@ -259,6 +273,153 @@ async def test_different_queues_isolation(): assert "Queue Alpha Again" in read_output_file(str(result3)) +@pytest.mark.asyncio +async def test_parent_capacity_blocks_different_child_queues(monkeypatch): + """A parent scope with capacity 1 should serialize its child queues.""" + monkeypatch.setattr(task_queue, "QUEUE_CAPACITIES", parse_queue_capacities(["gradle=1"])) + + results = {} + end_times = {} + overall_start = time.time() + + async def run_task_a(): + client = Client(mcp) + async with client: + result = await client.call_tool( + "run_task", + { + "command": "sleep 2 && echo 'EMU 5557 done'", + "working_directory": "/tmp", + "queue_name": "gradle/emu-5557", + }, + ) + end_times["A"] = time.time() + results["A"] = str(result) + + async def run_task_b(): + await asyncio.sleep(0.3) + client = Client(mcp) + async with client: + result = await client.call_tool( + "run_task", + { + "command": "echo 'EMU 5559 done'", + "working_directory": "/tmp", + "queue_name": "gradle/emu-5559", + }, + ) + end_times["B"] = time.time() + results["B"] = str(result) + + await asyncio.gather(run_task_a(), run_task_b()) + + assert "SUCCESS" in results["A"] + assert "SUCCESS" in results["B"] + assert "EMU 5557 done" in read_output_file(results["A"]) + assert "EMU 5559 done" in read_output_file(results["B"]) + assert time.time() - overall_start >= 1.8 + assert end_times["B"] >= end_times["A"] - 0.3 + + +@pytest.mark.asyncio +async def test_parent_capacity_preserves_fifo_within_child_queue(monkeypatch): + """A tighter parent scope should not let a younger child task jump the queue.""" + monkeypatch.setattr( + task_queue, + "QUEUE_CAPACITIES", + parse_queue_capacities(["gradle=1", "gradle/emu-5557=2"]), + ) + + results = {} + end_times = {} + + async def run_parent_blocker(): + client = Client(mcp) + async with client: + result = await client.call_tool( + "run_task", + { + "command": "sleep 2 && echo 'parent done'", + "working_directory": "/tmp", + "queue_name": "gradle/emu-5559", + }, + ) + results["parent"] = str(result) + + async def run_older_child(): + await asyncio.sleep(0.2) + client = Client(mcp) + async with client: + result = await client.call_tool( + "run_task", + { + "command": "sleep 1 && echo 'older child done'", + "working_directory": "/tmp", + "queue_name": "gradle/emu-5557", + }, + ) + end_times["older"] = time.time() + results["older"] = str(result) + + async def run_younger_child(): + await asyncio.sleep(0.4) + client = Client(mcp) + async with client: + result = await client.call_tool( + "run_task", + { + "command": "echo 'younger child done'", + "working_directory": "/tmp", + "queue_name": "gradle/emu-5557", + }, + ) + end_times["younger"] = time.time() + results["younger"] = str(result) + + await asyncio.gather(run_parent_blocker(), run_older_child(), run_younger_child()) + + assert "SUCCESS" in results["older"] + assert "SUCCESS" in results["younger"] + assert "older child done" in read_output_file(results["older"]) + assert "younger child done" in read_output_file(results["younger"]) + assert end_times["older"] <= end_times["younger"] + + +@pytest.mark.asyncio +async def test_parent_capacity_allows_parallel_child_queues(monkeypatch): + """A parent scope with capacity 2 should allow two child queues to run together.""" + monkeypatch.setattr(task_queue, "QUEUE_CAPACITIES", parse_queue_capacities(["gradle=2"])) + + results = {} + end_times = {} + overall_start = time.time() + + async def run_child(queue_name: str, result_key: str): + client = Client(mcp) + async with client: + result = await client.call_tool( + "run_task", + { + "command": f"sleep 2 && echo '{queue_name} done'", + "working_directory": "/tmp", + "queue_name": queue_name, + }, + ) + end_times[result_key] = time.time() + results[result_key] = str(result) + + await asyncio.gather( + run_child("gradle/emu-5557", "A"), + run_child("gradle/emu-5559", "B"), + ) + + total_elapsed = time.time() - overall_start + assert "SUCCESS" in results["A"] + assert "SUCCESS" in results["B"] + assert total_elapsed < 3.5 + assert abs(end_times["A"] - end_times["B"]) < 1.0 + + @pytest.mark.asyncio async def test_tool_available(client): """Test that the run_task tool is available.""" @@ -814,6 +975,131 @@ def test_orphan_cleanup_removes_untracked_task(): assert count_after == 0, "Untracked task for our PID should be cleaned up" +def test_is_task_queue_process_accepts_installed_tq_entrypoint(monkeypatch): + """Installed tq entrypoints should be treated as live queue owners.""" + monkeypatch.setattr(queue_core, "is_process_alive", lambda pid: True) + + def fake_run(*args, **kwargs): + return subprocess.CompletedProcess( + args[0], + 0, + stdout="/Users/test/.local/bin/tq run echo hi\n", + stderr="", + ) + + monkeypatch.setattr(subprocess, "run", fake_run) + + assert queue_core.is_task_queue_process(12345) is True + + +def test_attempt_task_start_after_core_cleanup_commit_on_same_connection(): + """Callers can reuse the same connection after committing cleanup work.""" + dead_pid = 999999999 + my_pid = os.getpid() + + with get_db() as conn: + conn.execute( + """INSERT INTO queue (queue_name, status, pid, child_pid, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?)""", + ( + "cleanup_transaction_test", + "running", + dead_pid, + None, + datetime.now().isoformat(), + datetime.now().isoformat(), + ), + ) + cursor = conn.execute( + """INSERT INTO queue (queue_name, status, pid, child_pid, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?)""", + ( + "cleanup_transaction_test", + "waiting", + my_pid, + None, + datetime.now().isoformat(), + datetime.now().isoformat(), + ), + ) + task_id = cursor.lastrowid + + cleanup_queue_core(conn, "cleanup_transaction_test", PATHS.metrics_path) + conn.commit() + + started, queue_position = attempt_task_start( + conn, + task_id, + "cleanup_transaction_test", + {}, + my_pid, + ) + + assert started is True + assert queue_position == 0 + + +def test_parent_scope_cleanup_reaps_stale_sibling_runner(monkeypatch): + """A stale sibling runner should not keep a parent scope permanently full.""" + capacities = parse_queue_capacities(["gradle=1"]) + monkeypatch.setattr(task_queue, "QUEUE_CAPACITIES", capacities) + + dead_pid = 999999999 + my_pid = os.getpid() + + with get_db() as conn: + conn.execute( + """INSERT INTO queue (queue_name, status, pid, created_at, updated_at) + VALUES (?, ?, ?, ?, ?)""", + ( + "gradle/emu-5557", + "running", + dead_pid, + datetime.now().isoformat(), + datetime.now().isoformat(), + ), + ) + cursor = conn.execute( + """INSERT INTO queue (queue_name, status, pid, server_id, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?)""", + ( + "gradle/emu-5559", + "waiting", + my_pid, + task_queue.SERVER_INSTANCE_ID, + datetime.now().isoformat(), + datetime.now().isoformat(), + ), + ) + task_id = cursor.lastrowid + + with task_queue._active_task_ids_lock: + task_queue._active_task_ids.add(task_id) + + try: + cleanup_queue(conn, "gradle/emu-5559") + + started, queue_position = attempt_task_start( + conn, + task_id, + "gradle/emu-5559", + capacities, + my_pid, + ) + + remaining_sibling_runners = conn.execute( + "SELECT COUNT(*) AS c FROM queue WHERE queue_name = ? AND status = 'running'", + ("gradle/emu-5557",), + ).fetchone()["c"] + + assert started is True + assert queue_position == 0 + assert remaining_sibling_runners == 0 + finally: + with task_queue._active_task_ids_lock: + task_queue._active_task_ids.discard(task_id) + + def test_stale_server_instance_cleanup(): """Test that cleanup removes tasks from old server instances even if PID is reused. @@ -883,6 +1169,13 @@ def test_parse_args_defaults(): sys.argv = original_argv +def test_should_parse_module_args_for_console_script(): + """Installed entrypoints should parse module args; library imports should not.""" + assert task_queue._should_parse_module_args("agent-task-queue", "task_queue") is True + assert task_queue._should_parse_module_args("task_queue.py", "task_queue") is True + assert task_queue._should_parse_module_args("pytest", "task_queue") is False + + def test_parse_args_data_dir(): """Test --data-dir argument parsing.""" import sys diff --git a/tests/test_tq_cli.py b/tests/test_tq_cli.py index 15c0290..d44fc81 100644 --- a/tests/test_tq_cli.py +++ b/tests/test_tq_cli.py @@ -6,6 +6,7 @@ import json import os import signal +import sqlite3 import subprocess import sys import tempfile @@ -42,6 +43,37 @@ def run_tq(*args, data_dir=None, cwd=None, timeout=30): return result +def wait_for_queue_rows(db_path: Path, status: str, expected_count: int, timeout: float = 5.0): + """Poll until the queue has the expected number of rows with the given status.""" + deadline = time.time() + timeout + last_error = None + + while time.time() < deadline: + try: + conn = sqlite3.connect(db_path, timeout=5.0) + conn.row_factory = sqlite3.Row + try: + count = conn.execute( + "SELECT COUNT(*) as c FROM queue WHERE status = ?", + (status,), + ).fetchone()["c"] + finally: + conn.close() + + if count == expected_count: + return + except sqlite3.OperationalError as exc: + last_error = exc + + time.sleep(0.05) + + if last_error is not None: + raise last_error + raise AssertionError( + f"Timed out waiting for {expected_count} queue rows with status={status!r}" + ) + + class TestTqRun: """Tests for the tq run command.""" @@ -82,6 +114,33 @@ def test_queue_option_long_form(self, temp_data_dir): assert result.returncode == 0 assert "queued in 'longqueue'" in result.stdout + def test_queue_capacity_option(self, temp_data_dir): + """Test global --queue-capacity option with implicit run mode.""" + result = run_tq( + "--queue-capacity=gradle=2", + "-q", + "gradle/emu-5557", + "echo", + "capacity test", + data_dir=temp_data_dir, + ) + + assert result.returncode == 0 + assert "queued in 'gradle/emu-5557'" in result.stdout + assert "[tq] SUCCESS" in result.stdout + + def test_invalid_queue_capacity_option(self, temp_data_dir): + """Test invalid --queue-capacity value is rejected.""" + result = run_tq( + "--queue-capacity=gradle=abc", + "echo", + "oops", + data_dir=temp_data_dir, + ) + + assert result.returncode == 1 + assert "Invalid capacity 'abc'" in result.stderr + def test_working_directory_option(self, temp_data_dir): """Test -C/--dir option.""" result = run_tq("-C", "/tmp", "pwd", data_dir=temp_data_dir) @@ -508,7 +567,6 @@ def test_sigint_cleanup_waiting_task(self, temp_data_dir): import os import signal import sqlite3 - import time db_path = Path(temp_data_dir) / "queue.db" @@ -577,7 +635,6 @@ def test_sigint_cleanup_running_task(self, temp_data_dir): import os import signal import sqlite3 - import time db_path = Path(temp_data_dir) / "queue.db" @@ -590,7 +647,7 @@ def test_sigint_cleanup_running_task(self, temp_data_dir): ) # Wait for it to start running - time.sleep(0.5) + wait_for_queue_rows(db_path, "running", 1) # Verify it's running conn = sqlite3.connect(db_path, timeout=5.0) @@ -647,7 +704,7 @@ def test_multiple_waiters_cancelled(self, temp_data_dir): start_new_session=True, ) - time.sleep(0.5) + wait_for_queue_rows(db_path, "running", 1) # Start multiple waiters waiters = [] @@ -662,6 +719,7 @@ def test_multiple_waiters_cancelled(self, temp_data_dir): time.sleep(0.2) # Stagger registration # Verify all waiters are in queue + wait_for_queue_rows(db_path, "waiting", 3) conn = sqlite3.connect(db_path, timeout=5.0) conn.row_factory = sqlite3.Row waiting_count = conn.execute( diff --git a/tq.py b/tq.py index 9ab1e50..caef9d7 100644 --- a/tq.py +++ b/tq.py @@ -25,10 +25,14 @@ init_db, ensure_db, cleanup_queue as _cleanup_queue, + cleanup_targets_for_queue, log_metric as _log_metric, release_lock, is_process_alive, kill_process_tree, + normalize_queue_name, + parse_queue_capacities, + attempt_task_start, POLL_INTERVAL_WAITING, DEFAULT_MAX_LOCK_AGE_MINUTES, DEFAULT_MAX_METRICS_SIZE_MB, @@ -48,6 +52,11 @@ def get_paths(args) -> QueuePaths: return QueuePaths.from_data_dir(data_dir) +def get_queue_capacities(args) -> dict[str, int]: + """Parse queue capacity overrides from CLI args.""" + return parse_queue_capacities(getattr(args, "queue_capacity", [])) + + def cmd_list(args): """List all tasks in the queue.""" paths = get_paths(args) @@ -249,34 +258,43 @@ def log_metric(paths: QueuePaths, event: str, **kwargs): _log_metric(paths.metrics_path, event, DEFAULT_MAX_METRICS_SIZE_MB, **kwargs) -def cleanup_queue(conn, queue_name: str, paths: QueuePaths): +def cleanup_queue( + conn, + queue_name: str, + paths: QueuePaths, + queue_capacities: dict[str, int] | None = None, +): """Clean up queue (wrapper for CLI).""" - _cleanup_queue(conn, queue_name, paths.metrics_path, DEFAULT_MAX_LOCK_AGE_MINUTES) + for target_queue in cleanup_targets_for_queue(conn, queue_name, queue_capacities): + _cleanup_queue(conn, target_queue, paths.metrics_path, DEFAULT_MAX_LOCK_AGE_MINUTES) + + # Additional cleanup: Tasks with our PID but DIFFERENT instance_id (from old CLI instance) + # This handles the edge case where PID is reused after CLI crash + my_pid = os.getpid() + stale_tasks = conn.execute( + "SELECT id, status, child_pid, server_id FROM queue WHERE queue_name = ? AND pid = ? AND server_id IS NOT NULL AND server_id != ?", + (target_queue, my_pid, CLI_INSTANCE_ID), + ).fetchall() - # Additional cleanup: Tasks with our PID but DIFFERENT instance_id (from old CLI instance) - # This handles the edge case where PID is reused after CLI crash - my_pid = os.getpid() - stale_tasks = conn.execute( - "SELECT id, status, child_pid, server_id FROM queue WHERE queue_name = ? AND pid = ? AND server_id IS NOT NULL AND server_id != ?", - (queue_name, my_pid, CLI_INSTANCE_ID), - ).fetchall() + for task in stale_tasks: + if task["child_pid"] and is_process_alive(task["child_pid"]): + print(f"[tq] WARNING: Killing orphaned subprocess {task['child_pid']} from old CLI instance") + kill_process_tree(task["child_pid"]) - for task in stale_tasks: - if task["child_pid"] and is_process_alive(task["child_pid"]): - print(f"[tq] WARNING: Killing orphaned subprocess {task['child_pid']} from old CLI instance") - kill_process_tree(task["child_pid"]) + conn.execute("DELETE FROM queue WHERE id = ?", (task["id"],)) + log_metric( + paths, + "orphan_cleared", + task_id=task["id"], + queue_name=target_queue, + status=task["status"], + old_instance_id=task["server_id"], + reason="stale_cli_instance", + ) + print(f"[tq] WARNING: Cleared task from old CLI instance (ID: {task['id']}, old_instance: {task['server_id']})") - conn.execute("DELETE FROM queue WHERE id = ?", (task["id"],)) - log_metric( - paths, - "orphan_cleared", - task_id=task["id"], - queue_name=queue_name, - status=task["status"], - old_instance_id=task["server_id"], - reason="stale_cli_instance", - ) - print(f"[tq] WARNING: Cleared task from old CLI instance (ID: {task['id']}, old_instance: {task['server_id']})") + if conn.in_transaction: + conn.commit() def register_task(conn, queue_name: str, paths: QueuePaths, command: str = None) -> int: @@ -295,7 +313,7 @@ def register_task(conn, queue_name: str, paths: QueuePaths, command: str = None) return task_id -def wait_for_turn(conn, queue_name: str, task_id: int, paths: QueuePaths) -> None: +def wait_for_turn(conn, queue_name: str, task_id: int, paths: QueuePaths, queue_capacities: dict[str, int]) -> None: """Wait for the task's turn to run. Task must already be registered.""" my_pid = os.getpid() queued_at = time.time() @@ -303,41 +321,26 @@ def wait_for_turn(conn, queue_name: str, task_id: int, paths: QueuePaths) -> Non last_pos = -1 while True: - cleanup_queue(conn, queue_name, paths) - - runner = conn.execute( - "SELECT id FROM queue WHERE queue_name = ? AND status = 'running'", - (queue_name,), - ).fetchone() - - if runner: - pos = conn.execute( - "SELECT COUNT(*) as c FROM queue WHERE queue_name = ? AND status = 'waiting' AND id < ?", - (queue_name, task_id), - ).fetchone()["c"] + 1 + try: + cleanup_queue(conn, queue_name, paths, queue_capacities) + + started, pos = attempt_task_start( + conn, + task_id, + queue_name, + queue_capacities, + my_pid, + ) - if pos != last_pos: - print(f"[tq] Position #{pos} in queue. Waiting...") - last_pos = pos + if not started: - time.sleep(POLL_INTERVAL_WAITING) - continue + if pos != last_pos: + print(f"[tq] Position #{pos} in queue. Waiting...") + last_pos = pos - # Try to acquire lock atomically - cursor = conn.execute( - """UPDATE queue SET status = 'running', updated_at = ?, pid = ? - WHERE id = ? AND status = 'waiting' - AND NOT EXISTS ( - SELECT 1 FROM queue WHERE queue_name = ? AND status = 'running' - ) - AND id = ( - SELECT MIN(id) FROM queue WHERE queue_name = ? AND status = 'waiting' - )""", - (datetime.now().isoformat(), my_pid, task_id, queue_name, queue_name), - ) - conn.commit() + time.sleep(POLL_INTERVAL_WAITING) + continue - if cursor.rowcount > 0: wait_time = time.time() - queued_at log_metric( paths, @@ -351,6 +354,9 @@ def wait_for_turn(conn, queue_name: str, task_id: int, paths: QueuePaths) -> Non else: print("[tq] Lock acquired") return # Lock acquired, task_id was passed in + except sqlite3.OperationalError as exc: + if "database is locked" not in str(exc).lower(): + raise time.sleep(POLL_INTERVAL_WAITING) @@ -364,7 +370,13 @@ def cmd_run(args): # Use shlex.join to properly quote arguments with spaces command = shlex.join(args.run_command) working_dir = os.path.abspath(args.dir) if args.dir else os.getcwd() - queue_name = args.queue + try: + queue_name = normalize_queue_name(args.queue) + queue_capacities = get_queue_capacities(args) + except ValueError as exc: + print(f"Error: {exc}", file=sys.stderr) + sys.exit(1) + timeout = args.timeout if not os.path.exists(working_dir): @@ -426,11 +438,11 @@ def cleanup_handler(signum, frame): try: # Run cleanup BEFORE inserting - this clears orphaned tasks that would otherwise # block the queue forever (since cleanup only runs during polling) - cleanup_queue(conn, queue_name, paths) + cleanup_queue(conn, queue_name, paths, queue_capacities) # Register task first so task_id is available for cleanup if interrupted task_id = register_task(conn, queue_name, paths, command=command) - wait_for_turn(conn, queue_name, task_id, paths) + wait_for_turn(conn, queue_name, task_id, paths, queue_capacities) print(f"[tq] Running: {command}") print(f"[tq] Directory: {working_dir}") @@ -531,6 +543,16 @@ def main(): "--data-dir", help="Data directory (default: $TASK_QUEUE_DATA_DIR or /tmp/agent-task-queue)", ) + parser.add_argument( + "--queue-capacity", + action="append", + default=[], + metavar="SCOPE=CAPACITY", + help=( + "Hierarchical queue capacity override. Repeatable. " + "Example: --queue-capacity=gradle=2 --queue-capacity=gradle/emu-5557=1" + ), + ) subparsers = parser.add_subparsers(dest="command", help="Commands") @@ -564,8 +586,8 @@ def main(): i = 0 while i < len(args_list): arg = args_list[i] - if arg.startswith("--data-dir"): - # Skip --data-dir=value or --data-dir value + if arg.startswith("--data-dir") or arg.startswith("--queue-capacity"): + # Skip --data-dir=value, --queue-capacity=value, or the following value. if "=" not in arg: i += 1 # Skip the next arg (value) i += 1