diff --git a/tests/unit/_autoscaling/test_autoscaled_pool.py b/tests/unit/_autoscaling/test_autoscaled_pool.py index c7d31a6bea..c77a1d8926 100644 --- a/tests/unit/_autoscaling/test_autoscaled_pool.py +++ b/tests/unit/_autoscaling/test_autoscaled_pool.py @@ -15,7 +15,7 @@ from crawlee._autoscaling._types import LoadRatioInfo, SystemInfo from crawlee._types import ConcurrencySettings from crawlee._utils.time import measure_time -from tests.unit.utils import wait_for_condition +from tests.unit.utils import poll_until_condition if TYPE_CHECKING: from collections.abc import Awaitable @@ -192,20 +192,20 @@ def get_historical_system_info() -> SystemInfo: try: # Wait until concurrency scales up above 1. - await wait_for_condition(lambda: pool.desired_concurrency > 1, timeout=5.0) + assert await poll_until_condition(lambda: pool.desired_concurrency > 1, timeout=5.0) # Wait until concurrency reaches maximum. - await wait_for_condition(lambda: pool.desired_concurrency == 4, timeout=5.0) + assert await poll_until_condition(lambda: pool.desired_concurrency == 4, timeout=5.0) # Multiple concurrent workers should have completed more tasks than a single worker could. - await wait_for_condition(lambda: done_count > 10, timeout=5.0) + assert await poll_until_condition(lambda: done_count > 10, timeout=5.0) # Simulate CPU overload and wait for the pool to scale down. overload_active = True - await wait_for_condition(lambda: pool.desired_concurrency < 4, timeout=5.0) + assert await poll_until_condition(lambda: pool.desired_concurrency < 4, timeout=5.0) # Wait until the pool scales all the way down to minimum. - await wait_for_condition(lambda: pool.desired_concurrency == 1, timeout=5.0) + assert await poll_until_condition(lambda: pool.desired_concurrency == 1, timeout=5.0) finally: pool_run_task.cancel() with suppress(asyncio.CancelledError): diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py index 94a8b2dbe7..a3008777c3 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -32,6 +32,7 @@ from crawlee.statistics import FinalStatistics, StatisticsState from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient from crawlee.storages import Dataset, KeyValueStore, RequestQueue +from tests.unit.utils import poll_until_condition if TYPE_CHECKING: from collections.abc import Callable, Sequence @@ -1988,17 +1989,13 @@ async def test_crawler_intermediate_statistics() -> None: crawler = BasicCrawler() check_time = timedelta(seconds=0.1) - async def wait_for_statistics_initialization() -> None: - while not crawler.statistics.active: # noqa: ASYNC110 # It is ok for tests. - await asyncio.sleep(0.1) - @crawler.router.default_handler async def handler(_: BasicCrawlingContext) -> None: await asyncio.sleep(check_time.total_seconds() * 5) # Start crawler and wait until statistics are initialized. crawler_task = asyncio.create_task(crawler.run(['https://a.placeholder.com'])) - await wait_for_statistics_initialization() + assert await poll_until_condition(lambda: crawler.statistics.active) # Wait some time and check that runtime is updated. await asyncio.sleep(check_time.total_seconds()) diff --git a/tests/unit/request_loaders/test_sitemap_request_loader.py b/tests/unit/request_loaders/test_sitemap_request_loader.py index 98741578f1..0fd77cae59 100644 --- a/tests/unit/request_loaders/test_sitemap_request_loader.py +++ b/tests/unit/request_loaders/test_sitemap_request_loader.py @@ -1,4 +1,3 @@ -import asyncio import base64 import gzip from typing import TYPE_CHECKING @@ -10,6 +9,7 @@ from crawlee.http_clients._base import HttpClient from crawlee.request_loaders._sitemap_request_loader import SitemapRequestLoader from crawlee.storages import KeyValueStore +from tests.unit.utils import poll_until_condition if TYPE_CHECKING: from crawlee._types import JsonSerializable @@ -91,9 +91,7 @@ async def test_is_empty_does_not_depend_on_fetch_next_request(server_url: URL, h assert await sitemap_loader.is_empty() - await asyncio.sleep(0.1) - - assert await sitemap_loader.is_finished() + assert await poll_until_condition(sitemap_loader.is_finished) async def test_abort_sitemap_loading(server_url: URL, http_client: HttpClient) -> None: @@ -139,18 +137,14 @@ async def test_create_persist_state_for_sitemap_loading( async def test_data_persistence_for_sitemap_loading( server_url: URL, http_client: HttpClient, key_value_store: KeyValueStore ) -> None: - async def wait_for_sitemap_loader_not_empty(sitemap_loader: SitemapRequestLoader) -> None: - while await sitemap_loader.is_empty() and not await sitemap_loader.is_finished(): # noqa: ASYNC110 - await asyncio.sleep(0.1) - sitemap_url = (server_url / 'sitemap.xml').with_query(base64=encode_base64(BASIC_SITEMAP.encode())) persist_key = 'data_persist_state' sitemap_loader = SitemapRequestLoader( [str(sitemap_url)], http_client=http_client, persist_state_key=persist_key, enqueue_strategy='all' ) - # Give time to load - await asyncio.wait_for(wait_for_sitemap_loader_not_empty(sitemap_loader), timeout=10) + # Give time to load. + await poll_until_condition(sitemap_loader.is_empty, lambda is_empty: not is_empty, timeout=10) await sitemap_loader.close() diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index d77d524150..56704ebbed 100644 --- a/tests/unit/storages/test_request_queue.py +++ b/tests/unit/storages/test_request_queue.py @@ -12,6 +12,7 @@ from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, UnprocessedRequest from crawlee.storages import RequestQueue from crawlee.storages._storage_instance_manager import StorageInstanceManager +from tests.unit.utils import poll_until_condition if TYPE_CHECKING: from collections.abc import AsyncGenerator, Sequence @@ -516,9 +517,8 @@ async def test_add_requests_wait_for_all( # Immediately after adding, the total count may be less than 15 due to background processing assert await rq.get_total_count() <= 15 - # Wait for background tasks to complete - while await rq.get_total_count() < 15: # noqa: ASYNC110 - await asyncio.sleep(0.1) + # Wait for background tasks to complete. + await poll_until_condition(rq.get_total_count, lambda count: count >= 15) # Verify all requests were added assert await rq.get_total_count() == 15 @@ -542,8 +542,8 @@ async def test_is_finished(rq: RequestQueue) -> None: # Queue shouldn't be finished while background tasks are running assert await rq.is_finished() is False - # Wait for background tasks to finish - await asyncio.sleep(0.2) + # Wait for the background add task to finish. + await poll_until_condition(lambda: not rq._add_requests_tasks) # Process all requests while True: diff --git a/tests/unit/utils.py b/tests/unit/utils.py index cda78e9c14..a965d3cc7b 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -1,34 +1,75 @@ from __future__ import annotations import asyncio +import inspect import sys -from typing import TYPE_CHECKING +import time +from typing import TYPE_CHECKING, TypeVar, cast, overload import pytest if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Awaitable, Callable + +T = TypeVar('T') run_alone_on_mac = pytest.mark.run_alone if sys.platform == 'darwin' else lambda x: x -async def wait_for_condition( - condition: Callable[[], bool], +async def maybe_await(value: Awaitable[T] | T) -> T: + """Await `value` if it is awaitable, otherwise return it unchanged. + + Lets `poll_until_condition` accept both sync and async callables. + """ + if inspect.isawaitable(value): + return await cast('Awaitable[T]', value) + return cast('T', value) + + +@overload +async def poll_until_condition( + fn: Callable[[], Awaitable[T]], + condition: Callable[[T], bool] = ..., + *, + timeout: float = ..., + poll_interval: float = ..., + backoff_factor: float = ..., +) -> T: ... +@overload +async def poll_until_condition( + fn: Callable[[], T], + condition: Callable[[T], bool] = ..., *, - timeout: float = 5.0, + timeout: float = ..., + poll_interval: float = ..., + backoff_factor: float = ..., +) -> T: ... +async def poll_until_condition( + fn: Callable[[], Awaitable[T] | T], + condition: Callable[[T], bool] = bool, + *, + timeout: float = 5, poll_interval: float = 0.05, -) -> None: - """Poll `condition` until it returns True, or raise `AssertionError` on timeout. + backoff_factor: float = 1, +) -> T: + """Poll `fn` until `condition(result)` is True or the timeout expires. + + Polls `fn` at `poll_interval`-second intervals until `condition` is satisfied or `timeout` seconds have elapsed. + Returns the last polled result regardless of whether the condition was met, so the caller can run its own + assertion. The default condition checks for a truthy result. - Args: - condition: A callable that returns True when the desired state is reached. - timeout: Maximum time in seconds to wait before raising. - poll_interval: Time in seconds between condition checks. + Use this instead of a fixed `asyncio.sleep` when waiting for some state to settle (e.g. autoscaling + concurrency) that may take a variable amount of time. For highly variable waits, pass `backoff_factor` > 1 + to multiply the interval after each poll, covering a long timeout with few calls. """ - loop = asyncio.get_running_loop() - deadline = loop.time() + timeout - while loop.time() < deadline: - if condition(): - return - await asyncio.sleep(poll_interval) - raise AssertionError(f'Condition not met within {timeout}s') + deadline = time.monotonic() + timeout + delay = poll_interval + result = await maybe_await(fn()) + while not condition(result): + remaining = deadline - time.monotonic() + if remaining <= 0: + break + await asyncio.sleep(min(delay, remaining)) + delay *= backoff_factor + result = await maybe_await(fn()) + return result