diff --git a/src/aws_durable_execution_sdk_python/concurrency/executor.py b/src/aws_durable_execution_sdk_python/concurrency/executor.py index 24c7657..eb3a4d8 100644 --- a/src/aws_durable_execution_sdk_python/concurrency/executor.py +++ b/src/aws_durable_execution_sdk_python/concurrency/executor.py @@ -439,7 +439,6 @@ def run_in_child_handler() -> ResultType: is_virtual=is_virtual, ), ) - child_context.state.track_replay(operation_id=operation_id) return result def replay(self, execution_state: ExecutionState, executor_context: DurableContext): diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index bfded98..160b926 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -42,7 +42,7 @@ SerDes, deserialize, ) -from aws_durable_execution_sdk_python.state import ExecutionState # noqa: TCH001 +from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus # noqa: TCH001 from aws_durable_execution_sdk_python.threading import OrderedCounter from aws_durable_execution_sdk_python.types import Callback as CallbackProtocol from aws_durable_execution_sdk_python.types import ( @@ -238,6 +238,7 @@ def __init__( parent_id: str | None = None, logger: Logger | None = None, step_id_prefix: str | None = None, + replay_status: ReplayStatus = ReplayStatus.REPLAY, ) -> None: self.state: ExecutionState = state self.execution_context: ExecutionContext = execution_context @@ -252,15 +253,17 @@ def __init__( # cached at construction to make invariant even if parent/prefix mutates. self._is_virtual: bool = self._parent_id != self._step_id_prefix self._step_counter: OrderedCounter = OrderedCounter() + self._replay_status: ReplayStatus = replay_status + self._track_replay() log_info = LogInfo( - execution_state=state, parent_id=parent_id, ) self._log_info = log_info self.logger: Logger = logger or Logger.from_log_info( logger=logging.getLogger(), info=log_info, + context=self, ) @property @@ -275,6 +278,11 @@ def is_virtual(self) -> bool: """ return self._is_virtual + @property + def is_replaying(self) -> bool: + """True if this context is in replay mode""" + return self._replay_status is ReplayStatus.REPLAY + # region factories @staticmethod def from_lambda_context( @@ -323,9 +331,9 @@ def create_child_context( lambda_context=self.lambda_context, parent_id=child_parent_id, step_id_prefix=operation_id, + replay_status=self._replay_status, logger=self.logger.with_log_info( LogInfo( - execution_state=self.state, parent_id=child_parent_id, ) ), @@ -348,6 +356,7 @@ def set_logger(self, new_logger: LoggerInterface): self.logger = Logger.from_log_info( logger=new_logger, info=self._log_info, + context=self, ) def _create_step_id_for_logical_step(self, step: int) -> str: @@ -369,6 +378,19 @@ def _create_step_id(self) -> str: new_counter: int = self._step_counter.increment() return self._create_step_id_for_logical_step(new_counter) + def _track_replay(self) -> None: + """Transition replay status to NEW if the next operation has not been checkpointed""" + if self._replay_status is ReplayStatus.NEW: + return + # check if next operation exists + next_counter = self._step_counter.get_current() + 1 + next_step_id = self._create_step_id_for_logical_step(next_counter) + if not self.state.get_checkpoint_result(next_step_id).is_existent(): + # update the context replay status to NEW + self._replay_status = ReplayStatus.NEW + # update the execution replay status to NEW + self.state.transition_replay_status() + # region Operations def create_callback( @@ -400,6 +422,7 @@ def create_callback( ), config=config, ) + self._track_replay() callback_id: str = executor.process() result: Callback = Callback( callback_id=callback_id, @@ -407,7 +430,6 @@ def create_callback( state=self.state, serdes=config.serdes, ) - self.state.track_replay(operation_id=operation_id) return result def invoke( @@ -442,8 +464,8 @@ def invoke( ), config=config, ) + self._track_replay() result: R = executor.process() - self.state.track_replay(operation_id=operation_id) return result def map( @@ -478,6 +500,7 @@ def map_in_child_context() -> BatchResult[R]: operation_identifier=operation_identifier, ) + self._track_replay() result: BatchResult[R] = child_handler( func=map_in_child_context, state=self.state, @@ -491,7 +514,6 @@ def map_in_child_context() -> BatchResult[R]: item_serdes=None, ), ) - self.state.track_replay(operation_id=operation_id) return result def parallel( @@ -521,6 +543,7 @@ def parallel_in_child_context() -> BatchResult[T]: operation_identifier=operation_identifier, ) + self._track_replay() result: BatchResult[T] = child_handler( func=parallel_in_child_context, state=self.state, @@ -534,7 +557,6 @@ def parallel_in_child_context() -> BatchResult[T]: item_serdes=None, ), ) - self.state.track_replay(operation_id=operation_id) return result def run_in_child_context( @@ -568,6 +590,7 @@ def callable_with_child_context(): ) ) + self._track_replay() result: T = child_handler( func=callable_with_child_context, state=self.state, @@ -578,7 +601,6 @@ def callable_with_child_context(): ), config=config, ) - self.state.track_replay(operation_id=operation_id) return result def step( @@ -603,8 +625,8 @@ def step( ), context_logger=self.logger, ) + self._track_replay() result: T = executor.process() - self.state.track_replay(operation_id=operation_id) return result def wait(self, duration: Duration, name: str | None = None) -> None: @@ -629,8 +651,8 @@ def wait(self, duration: Duration, name: str | None = None) -> None: name=name, ), ) + self._track_replay() executor.process() - self.state.track_replay(operation_id=operation_id) def wait_for_callback( self, @@ -686,8 +708,8 @@ def wait_for_condition( context_logger=self.logger, ) ) + self._track_replay() result: T = executor.process() - self.state.track_replay(operation_id=operation_id) return result diff --git a/src/aws_durable_execution_sdk_python/logger.py b/src/aws_durable_execution_sdk_python/logger.py index c2a2be7..e02359d 100644 --- a/src/aws_durable_execution_sdk_python/logger.py +++ b/src/aws_durable_execution_sdk_python/logger.py @@ -10,13 +10,12 @@ if TYPE_CHECKING: from collections.abc import Callable, Mapping, MutableMapping - from aws_durable_execution_sdk_python.context import ExecutionState + from aws_durable_execution_sdk_python import DurableContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier @dataclass(frozen=True) class LogInfo: - execution_state: ExecutionState parent_id: str | None = None operation_id: str | None = None name: str | None = None @@ -25,13 +24,11 @@ class LogInfo: @classmethod def from_operation_identifier( cls, - execution_state: ExecutionState, op_id: OperationIdentifier, attempt: int | None = None, ) -> LogInfo: """Create new log info from an execution arn, OperationIdentifier and attempt.""" return cls( - execution_state=execution_state, parent_id=op_id.parent_id, operation_id=op_id.operation_id, name=op_id.name, @@ -41,7 +38,6 @@ def from_operation_identifier( def with_parent_id(self, parent_id: str) -> LogInfo: """Clone the log info with a new parent id.""" return LogInfo( - execution_state=self.execution_state, parent_id=parent_id, operation_id=self.operation_id, name=self.name, @@ -54,17 +50,19 @@ def __init__( self, logger: LoggerInterface, default_extra: Mapping[str, object], - execution_state: ExecutionState, + context: DurableContext, ) -> None: self._logger = logger self._default_extra = default_extra - self._execution_state = execution_state + self._context = context @classmethod - def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger: + def from_log_info( + cls, logger: LoggerInterface, info: LogInfo, context: DurableContext + ) -> Logger: """Create a new logger with the given LogInfo.""" extra: MutableMapping[str, object] = { - "executionArn": info.execution_state.durable_execution_arn + "executionArn": context.state.durable_execution_arn } if info.parent_id: extra["parentId"] = info.parent_id @@ -75,15 +73,14 @@ def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger: extra["attempt"] = info.attempt if info.operation_id: extra["operationId"] = info.operation_id - return cls( - logger=logger, default_extra=extra, execution_state=info.execution_state - ) + return cls(logger=logger, default_extra=extra, context=context) def with_log_info(self, info: LogInfo) -> Logger: """Clone the existing logger with new LogInfo.""" return Logger.from_log_info( logger=self._logger, info=info, + context=self._context, ) def get_logger(self) -> LoggerInterface: @@ -128,4 +125,4 @@ def _log( log_func(msg, *args, extra=merged_extra) def _should_log(self) -> bool: - return not self._execution_state.is_replaying() + return not self._context.is_replaying diff --git a/src/aws_durable_execution_sdk_python/operation/step.py b/src/aws_durable_execution_sdk_python/operation/step.py index 8a418fb..6517c51 100644 --- a/src/aws_durable_execution_sdk_python/operation/step.py +++ b/src/aws_durable_execution_sdk_python/operation/step.py @@ -210,7 +210,6 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: step_context: StepContext = StepContext( logger=self.context_logger.with_log_info( LogInfo.from_operation_identifier( - execution_state=self.state, op_id=self.operation_identifier, attempt=attempt, ) diff --git a/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py b/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py index 5c4f1c4..3f4eaeb 100644 --- a/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py +++ b/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py @@ -181,7 +181,6 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: check_context = WaitForConditionCheckContext( logger=self.context_logger.with_log_info( LogInfo.from_operation_identifier( - execution_state=self.state, op_id=self.operation_identifier, attempt=attempt, ) diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 8317550..fd63988 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -343,39 +343,12 @@ def get_execution_operation(self) -> Operation | None: return candidate - def track_replay(self, operation_id: str) -> None: - """Check if operation exists with completed status; if not, transition to NEW status. - - This method is called before each operation (step, wait, invoke, etc.) to determine - if we've reached the replay boundary. Once we encounter an operation that doesn't - exist or isn't completed, we transition from REPLAY to NEW status, which enables - logging for all subsequent code. - - Args: - operation_id: The operation ID to check - """ - with self._replay_status_lock: - if self._replay_status == ReplayStatus.REPLAY: - self._visited_operations.add(operation_id) - completed_ops = { - op_id - for op_id, op in self.operations.items() - if op.operation_type != OperationType.EXECUTION - and op.status - in { - OperationStatus.SUCCEEDED, - OperationStatus.FAILED, - OperationStatus.CANCELLED, - OperationStatus.STOPPED, - OperationStatus.TIMED_OUT, - } - } - if completed_ops.issubset(self._visited_operations): - logger.debug( - "Transitioning from REPLAY to NEW status at operation %s", - operation_id, - ) - self._replay_status = ReplayStatus.NEW + def transition_replay_status(self) -> None: + """Transition to NEW status""" + if self._replay_status is ReplayStatus.REPLAY: + with self._replay_status_lock: + logger.debug("Transitioning from REPLAY to NEW status") + self._replay_status = ReplayStatus.NEW def is_replaying(self) -> bool: """Check if execution is currently in replay mode. diff --git a/tests/context_test.py b/tests/context_test.py index af32a3e..689b065 100644 --- a/tests/context_test.py +++ b/tests/context_test.py @@ -4,6 +4,7 @@ import json import random from itertools import islice +from unittest import mock from unittest.mock import ANY, MagicMock, Mock, patch import pytest @@ -1930,6 +1931,7 @@ def test_execution_context_propagates_to_child_context(): assert child_context.execution_context.durable_execution_arn == parent_arn # Should be the same instance (not a copy) assert child_context.execution_context is parent_context.execution_context + assert child_context.is_replaying def test_from_lambda_context_creates_execution_context(): @@ -1957,6 +1959,7 @@ def test_execution_context_type(): context = create_test_context(state=mock_state) assert isinstance(context.execution_context, ExecutionContext) + assert context.is_replaying # endregion ExecutionContext tests @@ -2140,6 +2143,7 @@ def test_should_propagate_outer_parent_id_when_virtual_is_nested_in_virtual(): assert outer_branch._parent_id == "outer-parallel-op" # noqa: SLF001 assert outer_branch._step_id_prefix == "outer-branch-op" # noqa: SLF001 assert outer_branch.is_virtual is True + assert outer_branch.is_replaying # Second virtual layer: an inner FLAT map inside the outer branch, # whose per-item branch is also virtual. @@ -2151,6 +2155,7 @@ def test_should_propagate_outer_parent_id_when_virtual_is_nested_in_virtual(): assert inner_branch._parent_id == "outer-parallel-op" # noqa: SLF001 assert inner_branch._step_id_prefix == "inner-branch-op" # noqa: SLF001 assert inner_branch.is_virtual is True + assert inner_branch.is_replaying # Step ids inside the inner branch still prefix on the inner branch's # own operation id; they must not leak the outer ancestor into the @@ -2159,4 +2164,47 @@ def test_should_propagate_outer_parent_id_when_virtual_is_nested_in_virtual(): assert inner_branch._create_step_id_for_logical_step(1) == expected # noqa: SLF001 +def test_context_created_with_new_status_when_check_result_returns_nonexistent(): + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + mock_state.get_checkpoint_result.return_value = ( + CheckpointedResult.create_not_found() + ) + context = create_test_context(state=mock_state) + assert not context.is_replaying + + # op id for 1 + mock_state.get_checkpoint_result.assert_called_once_with( + "1ced8f5be2db23a6513eba4d819c73806424748a7bc6fa0d792cc1c7d1775a97" + ) + + +def test_transition_replay_status_when_check_result_returns_nonexistent(): + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + mock_checkpoint_result = Mock(spec=CheckpointedResult) + mock_state.get_checkpoint_result.return_value = mock_checkpoint_result + mock_checkpoint_result.is_existent.return_value = True + context = create_test_context(state=mock_state) + assert context.is_replaying + + context._track_replay() + assert context.is_replaying + + mock_state.get_checkpoint_result.return_value = ( + CheckpointedResult.create_not_found() + ) + context._track_replay() + assert not context.is_replaying + + # op id for 1 + mock_state.get_checkpoint_result.assert_called_with( + "1ced8f5be2db23a6513eba4d819c73806424748a7bc6fa0d792cc1c7d1775a97" + ) + + # endregion Virtual-context identity tests diff --git a/tests/execution_test.py b/tests/execution_test.py index db13b5a..343fa46 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -2690,8 +2690,9 @@ def _make_lambda_context(): def test_durable_execution_replays_when_paginated_state_has_prior_operations(): """Test paginated execution state starts in replay mode when prior operations exist.""" mock_client = Mock(spec=DurableServiceClient) + # step_operation with operation_id = hashed(1) step_operation = Operation( - operation_id="step1", + operation_id="1ced8f5be2db23a6513eba4d819c73806424748a7bc6fa0d792cc1c7d1775a97", operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, ) @@ -2704,7 +2705,7 @@ def test_durable_execution_replays_when_paginated_state_has_prior_operations(): @durable_execution def test_handler(event: Any, context: DurableContext) -> dict: - return {"is_replaying": context.state.is_replaying()} + return {"is_replaying": context.is_replaying} result = test_handler(invocation_input, _make_lambda_context()) diff --git a/tests/logger_test.py b/tests/logger_test.py index b6017fa..d758431 100644 --- a/tests/logger_test.py +++ b/tests/logger_test.py @@ -4,6 +4,8 @@ from collections.abc import Mapping from unittest.mock import Mock +from aws_durable_execution_sdk_python import DurableContext +from aws_durable_execution_sdk_python.context import ExecutionContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( Operation, @@ -84,6 +86,11 @@ def exception( operations={}, service_client=Mock(), ) +EXECUTION_CONTEXT = ExecutionContext("arn:aws:test") +DURABLE_CONTEXT = DurableContext( + state=EXECUTION_STATE, + execution_context=EXECUTION_CONTEXT, +) def test_powertools_logger_compatibility(): @@ -102,8 +109,8 @@ def accepts_logger_interface(logger: LoggerInterface) -> None: accepts_logger_interface(powertools_logger) # Test that our Logger can wrap the PowertoolsLoggerStub - log_info = LogInfo(EXECUTION_STATE) - wrapped_logger = Logger.from_log_info(powertools_logger, log_info) + log_info = LogInfo() + wrapped_logger = Logger.from_log_info(powertools_logger, log_info, DURABLE_CONTEXT) # Test all methods work wrapped_logger.debug("debug message") @@ -115,8 +122,7 @@ def accepts_logger_interface(logger: LoggerInterface) -> None: def test_log_info_creation(): """Test LogInfo creation with all parameters.""" - log_info = LogInfo(EXECUTION_STATE, "parent123", "operation123", "test_name", 5) - assert log_info.execution_state.durable_execution_arn == "arn:aws:test" + log_info = LogInfo("parent123", "operation123", "test_name", 5) assert log_info.parent_id == "parent123" assert log_info.operation_id == "operation123" assert log_info.name == "test_name" @@ -125,8 +131,7 @@ def test_log_info_creation(): def test_log_info_creation_minimal(): """Test LogInfo creation with minimal parameters.""" - log_info = LogInfo(EXECUTION_STATE) - assert log_info.execution_state.durable_execution_arn == "arn:aws:test" + log_info = LogInfo() assert log_info.parent_id is None assert log_info.operation_id is None assert log_info.name is None @@ -136,8 +141,7 @@ def test_log_info_creation_minimal(): def test_log_info_from_operation_identifier(): """Test LogInfo.from_operation_identifier.""" op_id = OperationIdentifier("op123", "parent456", "op_name") - log_info = LogInfo.from_operation_identifier(EXECUTION_STATE, op_id, 3) - assert log_info.execution_state.durable_execution_arn == "arn:aws:test" + log_info = LogInfo.from_operation_identifier(op_id, 3) assert log_info.parent_id == "parent456" assert log_info.operation_id == "op123" assert log_info.name == "op_name" @@ -147,8 +151,7 @@ def test_log_info_from_operation_identifier(): def test_log_info_from_operation_identifier_no_attempt(): """Test LogInfo.from_operation_identifier without attempt.""" op_id = OperationIdentifier("op123", "parent456", "op_name") - log_info = LogInfo.from_operation_identifier(EXECUTION_STATE, op_id) - assert log_info.execution_state.durable_execution_arn == "arn:aws:test" + log_info = LogInfo.from_operation_identifier(op_id) assert log_info.parent_id == "parent456" assert log_info.operation_id == "op123" assert log_info.name == "op_name" @@ -157,9 +160,8 @@ def test_log_info_from_operation_identifier_no_attempt(): def test_log_info_with_parent_id(): """Test LogInfo.with_parent_id.""" - original = LogInfo(EXECUTION_STATE, "old_parent", "op123", "test_name", 2) + original = LogInfo("old_parent", "op123", "test_name", 2) new_log_info = original.with_parent_id("new_parent") - assert new_log_info.execution_state.durable_execution_arn == "arn:aws:test" assert new_log_info.parent_id == "new_parent" assert new_log_info.operation_id == "op123" assert new_log_info.name == "test_name" @@ -169,8 +171,8 @@ def test_log_info_with_parent_id(): def test_logger_from_log_info_full(): """Test Logger.from_log_info with all LogInfo fields.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE, "parent123", "op123", "test_name", 5) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo("parent123", "op123", "test_name", 5) + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) expected_extra = { "executionArn": "arn:aws:test", @@ -188,20 +190,20 @@ def test_logger_from_log_info_partial_fields(): mock_logger = Mock() # Test with parent_id but no name or attempt - log_info = LogInfo(EXECUTION_STATE, "parent123") - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo("parent123") + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) expected_extra = {"executionArn": "arn:aws:test", "parentId": "parent123"} assert logger._default_extra == expected_extra # noqa: SLF001 # Test with name but no parent_id or attempt - log_info = LogInfo(EXECUTION_STATE, None, None, "test_name") - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo(None, None, "test_name") + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) expected_extra = {"executionArn": "arn:aws:test", "operationName": "test_name"} assert logger._default_extra == expected_extra # noqa: SLF001 # Test with attempt but no parent_id or name - log_info = LogInfo(EXECUTION_STATE, None, None, None, 5) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo(None, None, None, 5) + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) expected_extra = {"executionArn": "arn:aws:test", "attempt": 5} assert logger._default_extra == expected_extra # noqa: SLF001 @@ -209,8 +211,8 @@ def test_logger_from_log_info_partial_fields(): def test_logger_from_log_info_minimal(): """Test Logger.from_log_info with minimal LogInfo.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) expected_extra = {"executionArn": "arn:aws:test"} assert logger._default_extra == expected_extra # noqa: SLF001 @@ -219,8 +221,7 @@ def test_logger_from_log_info_minimal(): def test_logger_with_log_info(): """Test Logger.with_log_info.""" mock_logger = Mock() - original_info = LogInfo(EXECUTION_STATE, "parent1") - logger = Logger.from_log_info(mock_logger, original_info) + original_info = LogInfo("parent1") execution_state_new = ExecutionState( durable_execution_arn="arn:aws:new", @@ -228,7 +229,9 @@ def test_logger_with_log_info(): operations={}, service_client=Mock(), ) - new_info = LogInfo(execution_state_new, "parent2", "op123", "new_name") + durable_context = DurableContext(execution_state_new, EXECUTION_CONTEXT) + logger = Logger.from_log_info(mock_logger, original_info, durable_context) + new_info = LogInfo("parent2", "op123", "new_name") new_logger = logger.with_log_info(new_info) expected_extra = { @@ -244,16 +247,16 @@ def test_logger_with_log_info(): def test_logger_get_logger(): """Test Logger.get_logger.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) assert logger.get_logger() is mock_logger def test_logger_debug(): """Test Logger.debug method.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE, "parent123") - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo("parent123") + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.debug("test %s message", "arg1", extra={"custom": "value"}) @@ -270,8 +273,8 @@ def test_logger_debug(): def test_logger_info(): """Test Logger.info method.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.info("info message") @@ -282,8 +285,8 @@ def test_logger_info(): def test_logger_warning(): """Test Logger.warning method.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.warning("warning %s %s message", "arg1", "arg2") @@ -296,8 +299,8 @@ def test_logger_warning(): def test_logger_error(): """Test Logger.error method.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.error("error message", extra={"error_code": 500}) @@ -308,8 +311,8 @@ def test_logger_error(): def test_logger_exception(): """Test Logger.exception method.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.exception("exception message") @@ -322,8 +325,8 @@ def test_logger_exception(): def test_logger_methods_with_none_extra(): """Test logger methods handle None extra parameter.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.debug("debug", extra=None) logger.info("info", extra=None) @@ -342,8 +345,8 @@ def test_logger_methods_with_none_extra(): def test_logger_extra_override(): """Test that custom extra overrides default extra.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE, "parent123") - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo("parent123") + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.info("test", extra={"executionArn": "overridden", "newField": "value"}) @@ -357,8 +360,8 @@ def test_logger_extra_override(): def test_logger_without_mocked_logger(): """Test Logger methods without mocking the underlying logger.""" - log_info = LogInfo(EXECUTION_STATE, "parent123", "test_name", 5) - logger = Logger.from_log_info(logging.getLogger(), log_info) + log_info = LogInfo("parent123", "op1", "test_name", 5) + logger = Logger.from_log_info(logging.getLogger(), log_info, DURABLE_CONTEXT) logger.info("test", extra={"execution_arn": "overridden", "new_field": "value"}) logger.warning("test", extra={"execution_arn": "overridden", "new_field": "value"}) @@ -378,12 +381,13 @@ def test_logger_replay_no_logging(): service_client=Mock(), replay_status=ReplayStatus.REPLAY, ) - log_info = LogInfo(replay_execution_state, "parent123", "test_name", 5) + durable_context = Mock(DurableContext) + durable_context.is_replaying = True + durable_context.state = replay_execution_state + log_info = LogInfo("parent123", "op1", "test_name", 5) mock_logger = Mock() - logger = Logger.from_log_info(mock_logger, log_info) + logger = Logger.from_log_info(mock_logger, log_info, durable_context) logger.info("logging info") - replay_execution_state.track_replay(operation_id="op1") - mock_logger.info.assert_not_called() @@ -405,14 +409,16 @@ def test_logger_replay_then_new_logging(): service_client=Mock(), replay_status=ReplayStatus.REPLAY, ) - log_info = LogInfo(execution_state, "parent123", "test_name", 5) + durable_context = Mock(DurableContext) + durable_context.is_replaying = True + durable_context.state = execution_state + log_info = LogInfo("parent123", "op1", "test_name", 5) mock_logger = Mock() - logger = Logger.from_log_info(mock_logger, log_info) - execution_state.track_replay(operation_id="op1") + logger = Logger.from_log_info(mock_logger, log_info, durable_context) logger.info("logging info") mock_logger.info.assert_not_called() - execution_state.track_replay(operation_id="op2") + durable_context.is_replaying = False logger.info("logging info") mock_logger.info.assert_called_once() diff --git a/tests/operation/map_test.py b/tests/operation/map_test.py index b3c979d..db6342a 100644 --- a/tests/operation/map_test.py +++ b/tests/operation/map_test.py @@ -2,6 +2,7 @@ import importlib import json +from collections import defaultdict from unittest.mock import Mock, patch import pytest @@ -26,7 +27,7 @@ from aws_durable_execution_sdk_python.operation import child # PLC0415 from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler from aws_durable_execution_sdk_python.serdes import serialize -from aws_durable_execution_sdk_python.state import ExecutionState +from aws_durable_execution_sdk_python.state import ExecutionState, CheckpointedResult from tests.serdes_test import CustomStrSerDes @@ -846,13 +847,11 @@ def get_checkpoint(op_id): mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() - context_map = {} + context_map = defaultdict(set) def create_id(self, i): ctx_id = id(self) - if ctx_id not in context_map: - context_map[ctx_id] = [] - context_map[ctx_id].append(i) + context_map[ctx_id].add(i) return ( "parent" if len(context_map) == 1 and len(context_map[ctx_id]) == 1 @@ -908,13 +907,11 @@ def get_checkpoint(op_id): mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() - context_map = {} + context_map = defaultdict(set) def create_id(self, i): ctx_id = id(self) - if ctx_id not in context_map: - context_map[ctx_id] = [] - context_map[ctx_id].append(i) + context_map[ctx_id].add(i) return ( "parent" if len(context_map) == 1 and len(context_map[ctx_id]) == 1 diff --git a/tests/operation/parallel_test.py b/tests/operation/parallel_test.py index 1922207..02300a0 100644 --- a/tests/operation/parallel_test.py +++ b/tests/operation/parallel_test.py @@ -2,6 +2,7 @@ import importlib import json +from collections import defaultdict from collections.abc import Mapping from typing import Any from unittest.mock import Mock, patch @@ -811,13 +812,11 @@ def get_checkpoint(op_id): mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() - context_map = {} + context_map = defaultdict(set) def create_id(self, i): ctx_id = id(self) - if ctx_id not in context_map: - context_map[ctx_id] = [] - context_map[ctx_id].append(i) + context_map[ctx_id].add(i) return ( "parent" if len(context_map) == 1 and len(context_map[ctx_id]) == 1 @@ -872,13 +871,11 @@ def get_checkpoint(op_id): mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() - context_map = {} + context_map = defaultdict(set) def create_id(self, i): ctx_id = id(self) - if ctx_id not in context_map: - context_map[ctx_id] = [] - context_map[ctx_id].append(i) + context_map[ctx_id].add(i) return ( "parent" if len(context_map) == 1 and len(context_map[ctx_id]) == 1 diff --git a/tests/state_test.py b/tests/state_test.py index 0152ca6..343ed42 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -3385,60 +3385,15 @@ def test_create_checkpoint_sync_always_synchronous(): def test_state_replay_mode(): - operation1 = Operation( - operation_id="op1", - operation_type=OperationType.STEP, - status=OperationStatus.SUCCEEDED, - ) - operation2 = Operation( - operation_id="op2", - operation_type=OperationType.STEP, - status=OperationStatus.SUCCEEDED, - ) - execution_state = ExecutionState( - durable_execution_arn="arn:aws:test", - initial_checkpoint_token="test_token", # noqa: S106 - operations={"op1": operation1, "op2": operation2}, - service_client=Mock(), - replay_status=ReplayStatus.REPLAY, - ) - assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op1") - assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op2") - assert execution_state.is_replaying() is False - - -def test_state_replay_mode_with_timed_out(): - """Test that TIMED_OUT operations are treated as terminal states for replay tracking. - - This test verifies that when an operation has TIMED_OUT status, it is correctly - recognized as a completed/terminal state, allowing the replay status to transition - from REPLAY to NEW once all completed operations have been visited. - - Regression test for: https://github.com/aws/aws-durable-execution-sdk-python/issues/262 - """ - operation1 = Operation( - operation_id="op1", - operation_type=OperationType.STEP, - status=OperationStatus.TIMED_OUT, - ) - operation2 = Operation( - operation_id="op2", - operation_type=OperationType.STEP, - status=OperationStatus.SUCCEEDED, - ) execution_state = ExecutionState( durable_execution_arn="arn:aws:test", initial_checkpoint_token="test_token", # noqa: S106 - operations={"op1": operation1, "op2": operation2}, + operations={}, service_client=Mock(), replay_status=ReplayStatus.REPLAY, ) assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op1") - assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op2") + execution_state.transition_replay_status() assert execution_state.is_replaying() is False