From a56e4b097db295c4c90dab598cbd5abb294b41e8 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Tue, 12 May 2026 13:14:14 -0700 Subject: [PATCH] feat: with retry helper --- .../src/with_retry/with_retry_callback.py | 65 +++ .../with_retry/test_with_retry_callback.py | 69 ++++ .../__init__.py | 6 + .../retries.py | 97 ++++- tests/with_retry_test.py | 377 ++++++++++++++++++ 5 files changed, 613 insertions(+), 1 deletion(-) create mode 100644 examples/src/with_retry/with_retry_callback.py create mode 100644 examples/test/with_retry/test_with_retry_callback.py create mode 100644 tests/with_retry_test.py diff --git a/examples/src/with_retry/with_retry_callback.py b/examples/src/with_retry/with_retry_callback.py new file mode 100644 index 00000000..cd5bb48a --- /dev/null +++ b/examples/src/with_retry/with_retry_callback.py @@ -0,0 +1,65 @@ +"""Demonstrates with_retry wrapping a wait_for_callback operation. + +The callback may fail multiple times before succeeding. The with_retry helper +retries the entire callback flow (including creating a new callback each attempt) +with exponential backoff between attempts. +""" + +from typing import Any + +from aws_durable_execution_sdk_python.config import Duration, WaitForCallbackConfig +from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.execution import durable_execution +from aws_durable_execution_sdk_python.retries import ( + RetryStrategyConfig, + WithRetryConfig, + with_retry, +) + + +@durable_execution +def handler(_event: Any, context: DurableContext) -> dict[str, Any]: + """Handler demonstrating with_retry around a wait_for_callback. + + The external system may fail to process the callback multiple times. + with_retry will re-create the callback and wait again on each retry, + with exponential backoff between attempts. + """ + + def retryable_callback_flow(ctx: DurableContext, attempt: int) -> str: + """The retryable block: create a callback and wait for the result.""" + + def submitter(callback_id: str, _callback_ctx) -> None: + """Submit the callback ID to an external system.""" + # In real usage, this would send the callback_id to an external + # system (e.g., via API call, SQS message, etc.) + pass + + config = WaitForCallbackConfig( + timeout=Duration.from_seconds(30), + heartbeat_timeout=Duration.from_seconds(60), + ) + + return ctx.wait_for_callback( + submitter, name=f"external-call-attempt-{attempt}", config=config + ) + + retry_config = WithRetryConfig( + retry_strategy_config=RetryStrategyConfig( + max_attempts=5, + initial_delay=Duration.from_seconds(2), + backoff_rate=1.0, + ), + ) + + result = with_retry( + context, + func=retryable_callback_flow, + config=retry_config, + name="callback-with-retry", + ) + + return { + "success": True, + "result": result, + } diff --git a/examples/test/with_retry/test_with_retry_callback.py b/examples/test/with_retry/test_with_retry_callback.py new file mode 100644 index 00000000..f3b22063 --- /dev/null +++ b/examples/test/with_retry/test_with_retry_callback.py @@ -0,0 +1,69 @@ +"""Tests for with_retry_callback example. + +Demonstrates that with_retry retries the entire wait_for_callback flow +when the callback fails. The external system fails 2 times before +succeeding on the 3rd attempt. +""" + +import pytest +from src.with_retry import with_retry_callback +from test.conftest import deserialize_operation_payload + +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import ErrorObject + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=with_retry_callback.handler, + lambda_function_name="With Retry Callback", +) +def test_with_retry_callback_fails_twice_then_succeeds(durable_runner): + """Test that with_retry retries the callback flow after failures. + + The external system sends callback failure 2 times, then succeeds + on the 3rd attempt. with_retry handles the failures and retries + the entire wait_for_callback block. + """ + with durable_runner: + execution_arn = durable_runner.run_async(input=None, timeout=60) + + # Attempt 1: external system fails + callback_id_1 = durable_runner.wait_for_callback( + execution_arn=execution_arn, + name="external-call-attempt-1 create callback id", + ) + durable_runner.send_callback_failure( + callback_id=callback_id_1, + error=ErrorObject.from_message("External system unavailable"), + ) + + # Attempt 2: external system fails again + callback_id_2 = durable_runner.wait_for_callback( + execution_arn=execution_arn, + name="external-call-attempt-2 create callback id", + ) + durable_runner.send_callback_failure( + callback_id=callback_id_2, + error=ErrorObject.from_message("External system timeout"), + ) + + # Attempt 3: external system succeeds + callback_id_3 = durable_runner.wait_for_callback( + execution_arn=execution_arn, + name="external-call-attempt-3 create callback id", + ) + durable_runner.send_callback_success( + callback_id=callback_id_3, + result="approval granted".encode(), + ) + + result = durable_runner.wait_for_result(execution_arn=execution_arn) + + assert result.status is InvocationStatus.SUCCEEDED + + result_data = deserialize_operation_payload(result.result) + assert result_data == { + "success": True, + "result": "approval granted", + } diff --git a/src/aws_durable_execution_sdk_python/__init__.py b/src/aws_durable_execution_sdk_python/__init__.py index 23a85cdf..0367d67a 100644 --- a/src/aws_durable_execution_sdk_python/__init__.py +++ b/src/aws_durable_execution_sdk_python/__init__.py @@ -24,9 +24,13 @@ # Core decorator - used in every durable function from aws_durable_execution_sdk_python.execution import durable_execution +# Retry helpers +from aws_durable_execution_sdk_python.retries import WithRetryConfig, with_retry + # Essential context types - passed to user functions from aws_durable_execution_sdk_python.types import StepContext + __all__ = [ "BatchResult", "DurableContext", @@ -34,9 +38,11 @@ "InvocationError", "StepContext", "ValidationError", + "WithRetryConfig", "__version__", "durable_execution", "durable_step", "durable_wait_for_callback", "durable_with_child_context", + "with_retry", ] diff --git a/src/aws_durable_execution_sdk_python/retries.py b/src/aws_durable_execution_sdk_python/retries.py index 5a09db21..f6f98f9d 100644 --- a/src/aws_durable_execution_sdk_python/retries.py +++ b/src/aws_durable_execution_sdk_python/retries.py @@ -5,13 +5,20 @@ import math import re from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar from aws_durable_execution_sdk_python.config import Duration, JitterStrategy +from aws_durable_execution_sdk_python.exceptions import SuspendExecution + if TYPE_CHECKING: from collections.abc import Callable + from aws_durable_execution_sdk_python.config import ChildConfig + from aws_durable_execution_sdk_python.types import DurableContext + +T = TypeVar("T") + Numeric = int | float # Default pattern that matches all error messages @@ -172,3 +179,91 @@ def critical(cls) -> Callable[[Exception, int], RetryDecision]: jitter_strategy=JitterStrategy.NONE, ) ) + + +@dataclass(frozen=True) +class WithRetryConfig: + """Configuration for with_retry. + + Wraps the existing RetryStrategyConfig (same config used for step + retries) and adds execution-mode options specific to with_retry. + + Attributes: + retry_strategy_config: RetryStrategyConfig controlling retry + behavior (max_attempts, initial_delay, backoff_rate, jitter, + error filtering). The same config used for step retries. + wrap_with_run_in_child_context: Whether to wrap the retry loop in + a child context for isolation. Default True. + child_context_config: Optional ChildConfig forwarded to + run_in_child_context when wrapping is enabled. Ignored when + wrap_with_run_in_child_context is False. + """ + + retry_strategy_config: RetryStrategyConfig + wrap_with_run_in_child_context: bool = True + child_context_config: ChildConfig | None = None + + +def with_retry( + context: DurableContext, + func: Callable[[DurableContext, int], T], + config: WithRetryConfig, + name: str | None = None, +) -> T: + """Retry a block of durable logic with configurable backoff. + + Semantically a run_in_child_context with a retry policy wrapped around + it — on failure the whole function body is re-run from the beginning + with configurable backoff. + + Unlike context.step() which retries a single atomic operation, + with_retry retries an entire function body that may contain multiple + durable operations (steps, waits, invokes, callbacks, etc.). + + Uses the existing RetryStrategyConfig (via WithRetryConfig), so retry + configuration is consistent across the SDK. + + Args: + context: The DurableContext to execute within. + func: A callable that accepts (DurableContext, attempt: int) and + returns T. The function body may contain multiple durable + operations. + config: WithRetryConfig containing a RetryStrategyConfig plus + execution-mode options. + name: Optional name for the child context and backoff waits. + When provided, backoff waits are named + "{name}-backoff-{attempt}". + + Returns: + The result of func on successful execution. + + Raises: + The exception from the last failed attempt when retries are + exhausted or the retry strategy returns should_retry=False. + SuspendExecution: Re-raised immediately (SDK control flow). + """ + retry_strategy = create_retry_strategy(config.retry_strategy_config) + + def run_loop(ctx: DurableContext) -> T: + attempt = 0 + while True: + attempt += 1 + try: + return func(ctx, attempt) + except SuspendExecution: + raise # SDK control flow - never intercept + except Exception as err: + decision = retry_strategy(err, attempt) + if not decision.should_retry: + raise + wait_name = f"{name}-backoff-{attempt}" if name else None + ctx.wait(duration=decision.delay, name=wait_name) + + if config.wrap_with_run_in_child_context: + return context.run_in_child_context( + run_loop, + name=name, + config=config.child_context_config, + ) + else: + return run_loop(context) diff --git a/tests/with_retry_test.py b/tests/with_retry_test.py new file mode 100644 index 00000000..fb8a8ef6 --- /dev/null +++ b/tests/with_retry_test.py @@ -0,0 +1,377 @@ +"""Unit tests for with_retry helper function.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, TypeVar +from unittest.mock import MagicMock + +import pytest + +from aws_durable_execution_sdk_python.config import Duration +from aws_durable_execution_sdk_python.exceptions import SuspendExecution +from aws_durable_execution_sdk_python.retries import ( + RetryStrategyConfig, + WithRetryConfig, + create_retry_strategy, + with_retry, +) + + +if TYPE_CHECKING: + from collections.abc import Callable + + from aws_durable_execution_sdk_python.config import ChildConfig + from aws_durable_execution_sdk_python.types import DurableContext + +_T = TypeVar("_T") + + +# region Mock DurableContext + + +@dataclass +class WaitCall: + """Record of a wait() call.""" + + duration: Duration + name: str | None + + +@dataclass +class RunInChildContextCall: + """Record of a run_in_child_context() call.""" + + name: str | None + config: ChildConfig | None + result: object = None + + +@dataclass +class MockDurableContext: + """A fake DurableContext that records wait() and run_in_child_context() calls.""" + + wait_calls: list[WaitCall] = field(default_factory=list) + child_context_calls: list[RunInChildContextCall] = field(default_factory=list) + + def wait(self, duration: Duration, name: str | None = None) -> None: + self.wait_calls.append(WaitCall(duration=duration, name=name)) + + def run_in_child_context( + self, + func: Callable[[DurableContext], _T], + name: str | None = None, + config: ChildConfig | None = None, + ) -> _T: + result: _T = func(self) # type: ignore[arg-type] + self.child_context_calls.append( + RunInChildContextCall(name=name, config=config, result=result) + ) + return result + + def step(self, *args, **kwargs): + raise NotImplementedError("step not used in with_retry tests") + + def map(self, *args, **kwargs): + raise NotImplementedError("map not used in with_retry tests") + + def parallel(self, *args, **kwargs): + raise NotImplementedError("parallel not used in with_retry tests") + + def create_callback(self, *args, **kwargs): + raise NotImplementedError("create_callback not used in with_retry tests") + + +# endregion + + +# region Helper fixtures + + +def _make_config( + max_attempts: int = 3, + initial_delay: Duration | None = None, + wrap_with_run_in_child_context: bool = True, + child_context_config: ChildConfig | None = None, +) -> WithRetryConfig: + """Create a WithRetryConfig with no jitter for deterministic tests.""" + from aws_durable_execution_sdk_python.config import JitterStrategy + + return WithRetryConfig( + retry_strategy_config=RetryStrategyConfig( + max_attempts=max_attempts, + initial_delay=initial_delay or Duration.from_seconds(1), + jitter_strategy=JitterStrategy.NONE, + ), + wrap_with_run_in_child_context=wrap_with_run_in_child_context, + child_context_config=child_context_config, + ) + + +# endregion + + +# region Tests + + +def test_success_on_first_attempt_returns_result_without_retry_strategy(): + """Function succeeds on first attempt returns result without invoking retry strategy.""" + ctx = MockDurableContext() + config = _make_config(wrap_with_run_in_child_context=False) + + strategy_calls: list = [] + original_create = create_retry_strategy + + # Patch to track strategy invocations + def tracking_func(ctx: DurableContext, attempt: int) -> str: + return "success" + + result = with_retry(ctx, tracking_func, config) + + assert result == "success" + assert len(ctx.wait_calls) == 0 + + +def test_function_fails_then_succeeds_returns_successful_result(): + """Function fails then succeeds returns result from successful attempt.""" + ctx = MockDurableContext() + config = _make_config(max_attempts=3, wrap_with_run_in_child_context=False) + + call_count = 0 + + def failing_then_succeeding(ctx: DurableContext, attempt: int) -> str: + nonlocal call_count + call_count += 1 + if attempt < 3: + raise ValueError(f"fail on attempt {attempt}") + return "eventual success" + + result = with_retry(ctx, failing_then_succeeding, config) + + assert result == "eventual success" + assert call_count == 3 + assert len(ctx.wait_calls) == 2 + + +def test_retry_strategy_returns_should_retry_false_reraises_exception(): + """Retry strategy returns should_retry=False re-raises exception.""" + ctx = MockDurableContext() + # max_attempts=1 means the strategy will return should_retry=False on first failure + config = _make_config(max_attempts=1, wrap_with_run_in_child_context=False) + + def always_fails(ctx: DurableContext, attempt: int) -> None: + raise RuntimeError("permanent failure") + + with pytest.raises(RuntimeError, match="permanent failure"): + with_retry(ctx, always_fails, config) + + assert len(ctx.wait_calls) == 0 + + +def test_suspend_execution_is_reraised_immediately(): + """SuspendExecution is re-raised immediately without invoking retry strategy.""" + ctx = MockDurableContext() + config = _make_config(max_attempts=5, wrap_with_run_in_child_context=False) + + def raises_suspend(ctx: DurableContext, attempt: int) -> None: + raise SuspendExecution("suspending") + + with pytest.raises(SuspendExecution, match="suspending"): + with_retry(ctx, raises_suspend, config) + + # No waits should have been called - strategy was never invoked + assert len(ctx.wait_calls) == 0 + + +def test_default_config_wraps_in_child_context(): + """Default config wraps in child context.""" + ctx = MockDurableContext() + config = _make_config(wrap_with_run_in_child_context=True) + + def simple_func(ctx: DurableContext, attempt: int) -> str: + return "child result" + + result = with_retry(ctx, simple_func, config) + + assert result == "child result" + assert len(ctx.child_context_calls) == 1 + + +def test_wrap_with_run_in_child_context_false_skips_child_context(): + """wrap_with_run_in_child_context=False skips child context.""" + ctx = MockDurableContext() + config = _make_config(wrap_with_run_in_child_context=False) + + def simple_func(ctx: DurableContext, attempt: int) -> str: + return "direct result" + + result = with_retry(ctx, simple_func, config) + + assert result == "direct result" + assert len(ctx.child_context_calls) == 0 + + +def test_no_name_creates_anonymous_child_context_and_anonymous_waits(): + """No name creates anonymous child context and anonymous waits.""" + ctx = MockDurableContext() + config = _make_config(max_attempts=3, wrap_with_run_in_child_context=True) + + call_count = 0 + + def fails_once(ctx: DurableContext, attempt: int) -> str: + nonlocal call_count + call_count += 1 + if attempt == 1: + raise ValueError("transient") + return "ok" + + result = with_retry(ctx, fails_once, config, name=None) + + assert result == "ok" + # Child context should have been called with name=None + assert len(ctx.child_context_calls) == 1 + assert ctx.child_context_calls[0].name is None + # Wait should have been called with name=None + assert len(ctx.wait_calls) == 1 + assert ctx.wait_calls[0].name is None + + +def test_name_is_forwarded_to_child_context_and_backoff_waits(): + """Name is forwarded to child context and backoff waits.""" + ctx = MockDurableContext() + config = _make_config(max_attempts=3, wrap_with_run_in_child_context=True) + + call_count = 0 + + def fails_twice(ctx: DurableContext, attempt: int) -> str: + nonlocal call_count + call_count += 1 + if attempt <= 2: + raise ValueError("transient") + return "done" + + result = with_retry(ctx, fails_twice, config, name="my-retry") + + assert result == "done" + # Child context should have been called with the name + assert len(ctx.child_context_calls) == 1 + assert ctx.child_context_calls[0].name == "my-retry" + # Waits should be named "{name}-backoff-{attempt}" + assert len(ctx.wait_calls) == 2 + assert ctx.wait_calls[0].name == "my-retry-backoff-1" + assert ctx.wait_calls[1].name == "my-retry-backoff-2" + + +def test_child_context_config_is_forwarded(): + """child_context_config is forwarded to run_in_child_context.""" + ctx = MockDurableContext() + + mock_child_config = MagicMock() + + config = WithRetryConfig( + retry_strategy_config=RetryStrategyConfig(max_attempts=3), + wrap_with_run_in_child_context=True, + child_context_config=mock_child_config, + ) + + def simple_func(ctx: DurableContext, attempt: int) -> str: + return "result" + + with_retry(ctx, simple_func, config, name="test") + + assert len(ctx.child_context_calls) == 1 + assert ctx.child_context_calls[0].config is mock_child_config + + +def test_attempt_number_starts_at_1_and_increments(): + """Attempt number starts at 1 and increments.""" + ctx = MockDurableContext() + config = _make_config(max_attempts=5, wrap_with_run_in_child_context=False) + + recorded_attempts: list[int] = [] + + def record_attempts(ctx: DurableContext, attempt: int) -> str: + recorded_attempts.append(attempt) + if attempt < 4: + raise ValueError("not yet") + return "done" + + result = with_retry(ctx, record_attempts, config) + + assert result == "done" + assert recorded_attempts == [1, 2, 3, 4] + + +def test_with_retry_and_config_importable_from_package(): + """with_retry and WithRetryConfig are importable from package.""" + from aws_durable_execution_sdk_python import WithRetryConfig as ImportedConfig + from aws_durable_execution_sdk_python import with_retry as imported_with_retry + + assert ImportedConfig is WithRetryConfig + assert imported_with_retry is with_retry + + +def test_integration_with_create_retry_strategy(): + """Integration with create_retry_strategy produces correct retry behavior.""" + ctx = MockDurableContext() + + # Use a real RetryStrategyConfig with specific settings + config = WithRetryConfig( + retry_strategy_config=RetryStrategyConfig( + max_attempts=4, + initial_delay=Duration.from_seconds(2), + backoff_rate=2.0, + jitter_strategy=__import__( + "aws_durable_execution_sdk_python.config", fromlist=["JitterStrategy"] + ).JitterStrategy.NONE, + ), + wrap_with_run_in_child_context=False, + ) + + call_count = 0 + + def fails_three_times(ctx: DurableContext, attempt: int) -> str: + nonlocal call_count + call_count += 1 + if attempt <= 3: + raise ValueError(f"fail {attempt}") + return "success after retries" + + result = with_retry(ctx, fails_three_times, config) + + assert result == "success after retries" + assert call_count == 4 + + # Verify backoff delays: 2*2^0=2, 2*2^1=4, 2*2^2=8 + assert len(ctx.wait_calls) == 3 + assert ctx.wait_calls[0].duration.to_seconds() == 2 + assert ctx.wait_calls[1].duration.to_seconds() == 4 + assert ctx.wait_calls[2].duration.to_seconds() == 8 + + +def test_integration_retries_exhausted_raises_last_exception(): + """When all retries are exhausted, the last exception is raised.""" + ctx = MockDurableContext() + + config = WithRetryConfig( + retry_strategy_config=RetryStrategyConfig( + max_attempts=3, + initial_delay=Duration.from_seconds(1), + jitter_strategy=__import__( + "aws_durable_execution_sdk_python.config", fromlist=["JitterStrategy"] + ).JitterStrategy.NONE, + ), + wrap_with_run_in_child_context=False, + ) + + def always_fails(ctx: DurableContext, attempt: int) -> None: + raise RuntimeError(f"error on attempt {attempt}") + + with pytest.raises(RuntimeError, match="error on attempt 3"): + with_retry(ctx, always_fails, config) + + # Should have waited between attempts 1->2 and 2->3 + assert len(ctx.wait_calls) == 2 + + +# endregion