Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tests/unit/_autoscaling/test_autoscaled_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from crawlee._autoscaling._types import LoadRatioInfo, SystemInfo
from crawlee._types import ConcurrencySettings
from crawlee._utils.time import measure_time
from tests.unit.utils import wait_for_condition
from tests.unit.utils import poll_until_condition

if TYPE_CHECKING:
from collections.abc import Awaitable
Expand Down Expand Up @@ -192,20 +192,20 @@ def get_historical_system_info() -> SystemInfo:

try:
# Wait until concurrency scales up above 1.
await wait_for_condition(lambda: pool.desired_concurrency > 1, timeout=5.0)
assert await poll_until_condition(lambda: pool.desired_concurrency > 1, timeout=5.0)

# Wait until concurrency reaches maximum.
await wait_for_condition(lambda: pool.desired_concurrency == 4, timeout=5.0)
assert await poll_until_condition(lambda: pool.desired_concurrency == 4, timeout=5.0)

# Multiple concurrent workers should have completed more tasks than a single worker could.
await wait_for_condition(lambda: done_count > 10, timeout=5.0)
assert await poll_until_condition(lambda: done_count > 10, timeout=5.0)

# Simulate CPU overload and wait for the pool to scale down.
overload_active = True
await wait_for_condition(lambda: pool.desired_concurrency < 4, timeout=5.0)
assert await poll_until_condition(lambda: pool.desired_concurrency < 4, timeout=5.0)

# Wait until the pool scales all the way down to minimum.
await wait_for_condition(lambda: pool.desired_concurrency == 1, timeout=5.0)
assert await poll_until_condition(lambda: pool.desired_concurrency == 1, timeout=5.0)
finally:
pool_run_task.cancel()
with suppress(asyncio.CancelledError):
Expand Down
7 changes: 2 additions & 5 deletions tests/unit/crawlers/_basic/test_basic_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from crawlee.statistics import FinalStatistics, StatisticsState
from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient
from crawlee.storages import Dataset, KeyValueStore, RequestQueue
from tests.unit.utils import poll_until_condition

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
Expand Down Expand Up @@ -1988,17 +1989,13 @@ async def test_crawler_intermediate_statistics() -> None:
crawler = BasicCrawler()
check_time = timedelta(seconds=0.1)

async def wait_for_statistics_initialization() -> None:
while not crawler.statistics.active: # noqa: ASYNC110 # It is ok for tests.
await asyncio.sleep(0.1)

@crawler.router.default_handler
async def handler(_: BasicCrawlingContext) -> None:
await asyncio.sleep(check_time.total_seconds() * 5)

# Start crawler and wait until statistics are initialized.
crawler_task = asyncio.create_task(crawler.run(['https://a.placeholder.com']))
await wait_for_statistics_initialization()
assert await poll_until_condition(lambda: crawler.statistics.active)

# Wait some time and check that runtime is updated.
await asyncio.sleep(check_time.total_seconds())
Expand Down
14 changes: 4 additions & 10 deletions tests/unit/request_loaders/test_sitemap_request_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import base64
import gzip
from typing import TYPE_CHECKING
Expand All @@ -10,6 +9,7 @@
from crawlee.http_clients._base import HttpClient
from crawlee.request_loaders._sitemap_request_loader import SitemapRequestLoader
from crawlee.storages import KeyValueStore
from tests.unit.utils import poll_until_condition

if TYPE_CHECKING:
from crawlee._types import JsonSerializable
Expand Down Expand Up @@ -91,9 +91,7 @@ async def test_is_empty_does_not_depend_on_fetch_next_request(server_url: URL, h

assert await sitemap_loader.is_empty()

await asyncio.sleep(0.1)

assert await sitemap_loader.is_finished()
assert await poll_until_condition(sitemap_loader.is_finished)


async def test_abort_sitemap_loading(server_url: URL, http_client: HttpClient) -> None:
Expand Down Expand Up @@ -139,18 +137,14 @@ async def test_create_persist_state_for_sitemap_loading(
async def test_data_persistence_for_sitemap_loading(
server_url: URL, http_client: HttpClient, key_value_store: KeyValueStore
) -> None:
async def wait_for_sitemap_loader_not_empty(sitemap_loader: SitemapRequestLoader) -> None:
while await sitemap_loader.is_empty() and not await sitemap_loader.is_finished(): # noqa: ASYNC110
await asyncio.sleep(0.1)

sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(BASIC_SITEMAP.encode()))
persist_key = 'data_persist_state'
sitemap_loader = SitemapRequestLoader(
[str(sitemap_url)], http_client=http_client, persist_state_key=persist_key, enqueue_strategy='all'
)

# Give time to load
await asyncio.wait_for(wait_for_sitemap_loader_not_empty(sitemap_loader), timeout=10)
# Give time to load.
await poll_until_condition(sitemap_loader.is_empty, lambda is_empty: not is_empty, timeout=10)

await sitemap_loader.close()

Expand Down
10 changes: 5 additions & 5 deletions tests/unit/storages/test_request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, UnprocessedRequest
from crawlee.storages import RequestQueue
from crawlee.storages._storage_instance_manager import StorageInstanceManager
from tests.unit.utils import poll_until_condition

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Sequence
Expand Down Expand Up @@ -516,9 +517,8 @@ async def test_add_requests_wait_for_all(
# Immediately after adding, the total count may be less than 15 due to background processing
assert await rq.get_total_count() <= 15

# Wait for background tasks to complete
while await rq.get_total_count() < 15: # noqa: ASYNC110
await asyncio.sleep(0.1)
# Wait for background tasks to complete.
await poll_until_condition(rq.get_total_count, lambda count: count >= 15)

# Verify all requests were added
assert await rq.get_total_count() == 15
Expand All @@ -542,8 +542,8 @@ async def test_is_finished(rq: RequestQueue) -> None:
# Queue shouldn't be finished while background tasks are running
assert await rq.is_finished() is False

# Wait for background tasks to finish
await asyncio.sleep(0.2)
# Wait for the background add task to finish.
await poll_until_condition(lambda: not rq._add_requests_tasks)

# Process all requests
while True:
Expand Down
77 changes: 59 additions & 18 deletions tests/unit/utils.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,75 @@
from __future__ import annotations

import asyncio
import inspect
import sys
from typing import TYPE_CHECKING
import time
from typing import TYPE_CHECKING, TypeVar, cast, overload

import pytest

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Awaitable, Callable

T = TypeVar('T')

run_alone_on_mac = pytest.mark.run_alone if sys.platform == 'darwin' else lambda x: x


async def wait_for_condition(
condition: Callable[[], bool],
async def maybe_await(value: Awaitable[T] | T) -> T:
    """Resolve `value` to a plain result, awaiting it first when it is awaitable.

    This is what allows `poll_until_condition` to be given either a sync or an async callable.
    """
    # Plain values pass straight through; only awaitables need to be driven to completion.
    if not inspect.isawaitable(value):
        return cast('T', value)
    awaited = await cast('Awaitable[T]', value)
    return awaited


@overload
async def poll_until_condition(
    fn: Callable[[], Awaitable[T]],
    condition: Callable[[T], bool] = ...,
    *,
    timeout: float = ...,
    poll_interval: float = ...,
    backoff_factor: float = ...,
) -> T: ...
@overload
async def poll_until_condition(
    fn: Callable[[], T],
    condition: Callable[[T], bool] = ...,
    *,
    timeout: float = ...,
    poll_interval: float = ...,
    backoff_factor: float = ...,
) -> T: ...
async def poll_until_condition(
    fn: Callable[[], Awaitable[T] | T],
    condition: Callable[[T], bool] = bool,
    *,
    timeout: float = 5,
    poll_interval: float = 0.05,
    backoff_factor: float = 1,
) -> T:
    """Poll `fn` until `condition(result)` is True or the timeout expires.

    Returns the last polled result regardless of whether the condition was met, so the caller can run its own
    assertion (e.g. `assert await poll_until_condition(...)`). This helper never raises on timeout by itself.

    Use this instead of a fixed `asyncio.sleep` when waiting for some state to settle (e.g. autoscaling
    concurrency) that may take a variable amount of time. For highly variable waits, pass `backoff_factor` > 1
    to multiply the interval after each poll, covering a long timeout with few calls.

    Args:
        fn: Sync or async zero-argument callable producing the value to check. It is always called at least
            once, even when `timeout` is zero or negative.
        condition: Predicate applied to each polled result. Defaults to `bool`, i.e. a truthy result.
        timeout: Maximum time in seconds to keep polling. The time spent inside `fn` itself is not interrupted.
        poll_interval: Initial delay in seconds between polls.
        backoff_factor: Multiplier applied to the delay after each sleep. The default of 1 keeps it constant.

    Returns:
        The result of the last call to `fn`, whether or not it satisfied `condition`.
    """
    deadline = time.monotonic() + timeout
    delay = poll_interval
    result = await maybe_await(fn())
    while not condition(result):
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline, so the final poll happens right at the timeout.
        await asyncio.sleep(min(delay, remaining))
        delay *= backoff_factor
        result = await maybe_await(fn())
    return result
Loading