diff --git a/tests/test_tq_cli.py b/tests/test_tq_cli.py
index d44fc81..c04f203 100644
--- a/tests/test_tq_cli.py
+++ b/tests/test_tq_cli.py
@@ -11,12 +11,15 @@
 import sys
 import tempfile
 import time
+from collections.abc import Callable
 from pathlib import Path
+from typing import TypeVar
 
 import pytest
 
 # Path to tq.py
 TQ_PATH = Path(__file__).parent.parent / "tq.py"
+T = TypeVar("T")
 
 
 @pytest.fixture
@@ -43,37 +46,106 @@ def run_tq(*args, data_dir=None, cwd=None, timeout=30):
     return result
 
 
-def wait_for_queue_rows(db_path: Path, status: str, expected_count: int, timeout: float = 5.0):
-    """Poll until the queue has the expected number of rows with the given status."""
-    deadline = time.time() + timeout
-    last_error = None
-
-    while time.time() < deadline:
+def poll_with_linear_backoff(
+    operation: Callable[[], T],
+    *,
+    is_ready: Callable[[T], bool],
+    timeout: float,
+    description: str,
+    retriable_exceptions: tuple[type[Exception], ...] = (),
+    initial_delay: float = 0.05,
+    delay_step: float = 0.05,
+    max_delay: float = 0.5,
+) -> T:
+    """Poll until a condition is ready using a bounded linear backoff."""
+    deadline = time.monotonic() + timeout
+    delay = initial_delay
+    last_error: Exception | None = None
+
+    while True:
         try:
-            conn = sqlite3.connect(db_path, timeout=5.0)
-            conn.row_factory = sqlite3.Row
-            try:
-                count = conn.execute(
-                    "SELECT COUNT(*) as c FROM queue WHERE status = ?",
-                    (status,),
-                ).fetchone()["c"]
-            finally:
-                conn.close()
-
-            if count == expected_count:
-                return
-        except sqlite3.OperationalError as exc:
+            value = operation()
+            last_error = None
+            if is_ready(value):
+                return value
+        except retriable_exceptions as exc:
             last_error = exc
 
-        time.sleep(0.05)
+        remaining = deadline - time.monotonic()
+        if remaining <= 0:
+            break
+
+        time.sleep(min(delay, remaining))
+        delay = min(delay + delay_step, max_delay)
 
     if last_error is not None:
         raise last_error
-    raise AssertionError(
-        f"Timed out waiting for {expected_count} queue rows with status={status!r}"
+    raise AssertionError(f"Timed out waiting for {description}")
+
+
+def wait_for_queue_rows(db_path: Path, status: str, expected_count: int, timeout: float = 5.0):
+    """Poll until the queue has the expected number of rows with the given status."""
+
+    def read_count() -> int:
+        conn = sqlite3.connect(db_path, timeout=5.0)
+        conn.row_factory = sqlite3.Row
+        try:
+            return conn.execute(
+                "SELECT COUNT(*) as c FROM queue WHERE status = ?",
+                (status,),
+            ).fetchone()["c"]
+        finally:
+            conn.close()
+
+    poll_with_linear_backoff(
+        read_count,
+        is_ready=lambda count: count == expected_count,
+        timeout=timeout,
+        description=f"{expected_count} queue rows with status={status!r}",
+        retriable_exceptions=(sqlite3.OperationalError,),
     )
 
 
+def wait_for_queue_row(
+    db_path: Path,
+    query: str,
+    params: tuple = (),
+    *,
+    timeout: float = 5.0,
+    predicate: Callable[[sqlite3.Row], bool] | None = None,
+):
+    """Poll until a query returns a row, optionally requiring an additional predicate."""
+
+    def read_row() -> sqlite3.Row | None:
+        conn = sqlite3.connect(db_path, timeout=5.0)
+        conn.row_factory = sqlite3.Row
+        try:
+            return conn.execute(query, params).fetchone()
+        finally:
+            conn.close()
+
+    return poll_with_linear_backoff(
+        read_row,
+        is_ready=lambda row: row is not None and (predicate is None or predicate(row)),
+        timeout=timeout,
+        description=f"queue query: {query}",
+        retriable_exceptions=(sqlite3.OperationalError,),
+    )
+
+
+def wait_for_process_exit(proc: subprocess.Popen, timeout: float = 10.0):
+    """Poll until a subprocess exits so signal-handling tests do not depend on one wait tick."""
+    try:
+        return poll_with_linear_backoff(
+            proc.poll,
+            is_ready=lambda returncode: returncode is not None,
+            timeout=timeout,
+            description=f"process exit: {proc.args}",
+        )
+    except AssertionError as exc:
+        raise subprocess.TimeoutExpired(proc.args, timeout) from exc
+
+
 class TestTqRun:
     """Tests for the tq run command."""
 
@@ -649,20 +721,21 @@ def test_sigint_cleanup_running_task(self, temp_data_dir):
         # Wait for it to start running
         wait_for_queue_rows(db_path, "running", 1)
 
-        # Verify it's running
-        conn = sqlite3.connect(db_path, timeout=5.0)
-        conn.row_factory = sqlite3.Row
-        running = conn.execute("SELECT * FROM queue WHERE status = 'running'").fetchone()
+        # Wait for the stronger running condition this test depends on.
+        running = wait_for_queue_row(
+            db_path,
+            "SELECT * FROM queue WHERE status = 'running'",
+            predicate=lambda row: row["child_pid"] is not None,
+        )
         assert running is not None, "Task should be running"
         task_id = running["id"]
         child_pid = running["child_pid"]
-        assert child_pid is not None, "Child PID should be recorded"
 
         # Send SIGINT
         proc.send_signal(signal.SIGINT)
 
         # Wait for cleanup
-        proc.wait(timeout=10)
+        wait_for_process_exit(proc, timeout=10)
         assert proc.returncode == 130, "Should exit with 130 (128 + SIGINT)"
 
         # Verify task was cleaned up
@@ -733,7 +806,7 @@ def test_multiple_waiters_cancelled(self, temp_data_dir):
 
         # Wait for all to exit
         for waiter in waiters:
-            waiter.wait(timeout=5)
+            wait_for_process_exit(waiter, timeout=10)
 
         # Verify all waiting tasks were cleaned up
         conn = sqlite3.connect(db_path, timeout=5.0)