Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 102 additions & 29 deletions tests/test_tq_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
import sys
import tempfile
import time
from collections.abc import Callable
from pathlib import Path
from typing import TypeVar

import pytest

# Path to tq.py
TQ_PATH = Path(__file__).parent.parent / "tq.py"
T = TypeVar("T")


@pytest.fixture
Expand All @@ -43,37 +46,106 @@ def run_tq(*args, data_dir=None, cwd=None, timeout=30):
return result


def wait_for_queue_rows(db_path: Path, status: str, expected_count: int, timeout: float = 5.0):
"""Poll until the queue has the expected number of rows with the given status."""
deadline = time.time() + timeout
last_error = None

while time.time() < deadline:
def poll_with_linear_backoff(
operation: Callable[[], T],
*,
is_ready: Callable[[T], bool],
timeout: float,
description: str,
retriable_exceptions: tuple[type[Exception], ...] = (),
initial_delay: float = 0.05,
delay_step: float = 0.05,
max_delay: float = 0.5,
) -> T:
"""Poll until a condition is ready using a bounded linear backoff."""
deadline = time.monotonic() + timeout
delay = initial_delay
last_error: Exception | None = None

while True:
try:
conn = sqlite3.connect(db_path, timeout=5.0)
conn.row_factory = sqlite3.Row
try:
count = conn.execute(
"SELECT COUNT(*) as c FROM queue WHERE status = ?",
(status,),
).fetchone()["c"]
finally:
conn.close()

if count == expected_count:
return
except sqlite3.OperationalError as exc:
value = operation()
last_error = None
if is_ready(value):
return value
except retriable_exceptions as exc:
last_error = exc

time.sleep(0.05)
remaining = deadline - time.monotonic()
if remaining <= 0:
break

time.sleep(min(delay, remaining))
delay = min(delay + delay_step, max_delay)

if last_error is not None:
raise last_error
raise AssertionError(
f"Timed out waiting for {expected_count} queue rows with status={status!r}"
raise AssertionError(f"Timed out waiting for {description}")


def wait_for_queue_rows(db_path: Path, status: str, expected_count: int, timeout: float = 5.0):
"""Poll until the queue has the expected number of rows with the given status."""

def read_count() -> int:
conn = sqlite3.connect(db_path, timeout=5.0)
conn.row_factory = sqlite3.Row
try:
return conn.execute(
"SELECT COUNT(*) as c FROM queue WHERE status = ?",
(status,),
).fetchone()["c"]
finally:
conn.close()

poll_with_linear_backoff(
read_count,
is_ready=lambda count: count == expected_count,
timeout=timeout,
description=f"{expected_count} queue rows with status={status!r}",
retriable_exceptions=(sqlite3.OperationalError,),
)


def wait_for_queue_row(
db_path: Path,
query: str,
params: tuple = (),
*,
timeout: float = 5.0,
predicate: Callable[[sqlite3.Row], bool] | None = None,
):
"""Poll until a query returns a row, optionally requiring an additional predicate."""

def read_row() -> sqlite3.Row | None:
conn = sqlite3.connect(db_path, timeout=5.0)
conn.row_factory = sqlite3.Row
try:
return conn.execute(query, params).fetchone()
finally:
conn.close()

return poll_with_linear_backoff(
read_row,
is_ready=lambda row: row is not None and (predicate is None or predicate(row)),
timeout=timeout,
description=f"queue query: {query}",
retriable_exceptions=(sqlite3.OperationalError,),
)


def wait_for_process_exit(proc: subprocess.Popen, timeout: float = 10.0):
"""Poll until a subprocess exits so signal-handling tests do not depend on one wait tick."""
try:
return poll_with_linear_backoff(
proc.poll,
is_ready=lambda returncode: returncode is not None,
timeout=timeout,
description=f"process exit: {proc.args}",
)
except AssertionError as exc:
raise subprocess.TimeoutExpired(proc.args, timeout) from exc


class TestTqRun:
"""Tests for the tq run command."""

Expand Down Expand Up @@ -649,20 +721,21 @@ def test_sigint_cleanup_running_task(self, temp_data_dir):
# Wait for it to start running
wait_for_queue_rows(db_path, "running", 1)

# Verify it's running
conn = sqlite3.connect(db_path, timeout=5.0)
conn.row_factory = sqlite3.Row
running = conn.execute("SELECT * FROM queue WHERE status = 'running'").fetchone()
# Wait for the stronger running condition this test depends on.
running = wait_for_queue_row(
db_path,
"SELECT * FROM queue WHERE status = 'running'",
predicate=lambda row: row["child_pid"] is not None,
)
assert running is not None, "Task should be running"
task_id = running["id"]
child_pid = running["child_pid"]
assert child_pid is not None, "Child PID should be recorded"

# Send SIGINT
proc.send_signal(signal.SIGINT)

# Wait for cleanup
proc.wait(timeout=10)
wait_for_process_exit(proc, timeout=10)
assert proc.returncode == 130, "Should exit with 130 (128 + SIGINT)"

# Verify task was cleaned up
Expand Down Expand Up @@ -733,7 +806,7 @@ def test_multiple_waiters_cancelled(self, temp_data_dir):

# Wait for all to exit
for waiter in waiters:
waiter.wait(timeout=5)
wait_for_process_exit(waiter, timeout=10)

# Verify all waiting tasks were cleaned up
conn = sqlite3.connect(db_path, timeout=5.0)
Expand Down