Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,6 @@ def run_in_child_handler() -> ResultType:
is_virtual=is_virtual,
),
)
child_context.state.track_replay(operation_id=operation_id)
return result

def replay(self, execution_state: ExecutionState, executor_context: DurableContext):
Expand Down
44 changes: 33 additions & 11 deletions src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
SerDes,
deserialize,
)
from aws_durable_execution_sdk_python.state import ExecutionState # noqa: TCH001
from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus # noqa: TCH001
from aws_durable_execution_sdk_python.threading import OrderedCounter
from aws_durable_execution_sdk_python.types import Callback as CallbackProtocol
from aws_durable_execution_sdk_python.types import (
Expand Down Expand Up @@ -238,6 +238,7 @@ def __init__(
parent_id: str | None = None,
logger: Logger | None = None,
step_id_prefix: str | None = None,
replay_status: ReplayStatus = ReplayStatus.REPLAY,
) -> None:
self.state: ExecutionState = state
self.execution_context: ExecutionContext = execution_context
Expand All @@ -252,15 +253,17 @@ def __init__(
# cached at construction to make invariant even if parent/prefix mutates.
self._is_virtual: bool = self._parent_id != self._step_id_prefix
self._step_counter: OrderedCounter = OrderedCounter()
self._replay_status: ReplayStatus = replay_status
self._track_replay()

log_info = LogInfo(
execution_state=state,
parent_id=parent_id,
)
self._log_info = log_info
self.logger: Logger = logger or Logger.from_log_info(
logger=logging.getLogger(),
info=log_info,
context=self,
)

@property
Expand All @@ -275,6 +278,11 @@ def is_virtual(self) -> bool:
"""
return self._is_virtual

@property
def is_replaying(self) -> bool:
    """Whether this context is still replaying previously checkpointed operations."""
    replaying: bool = self._replay_status is ReplayStatus.REPLAY
    return replaying

# region factories
@staticmethod
def from_lambda_context(
Expand Down Expand Up @@ -323,9 +331,9 @@ def create_child_context(
lambda_context=self.lambda_context,
parent_id=child_parent_id,
step_id_prefix=operation_id,
replay_status=self._replay_status,
logger=self.logger.with_log_info(
LogInfo(
execution_state=self.state,
parent_id=child_parent_id,
)
),
Expand All @@ -348,6 +356,7 @@ def set_logger(self, new_logger: LoggerInterface):
self.logger = Logger.from_log_info(
logger=new_logger,
info=self._log_info,
context=self,
)

def _create_step_id_for_logical_step(self, step: int) -> str:
Expand All @@ -369,6 +378,19 @@ def _create_step_id(self) -> str:
new_counter: int = self._step_counter.increment()
return self._create_step_id_for_logical_step(new_counter)

def _track_replay(self) -> None:
    """Flip this context's replay status to NEW once the upcoming operation
    has no checkpointed result (i.e. the replay boundary has been passed).
    """
    if self._replay_status is ReplayStatus.NEW:
        # Already past the replay boundary; nothing left to track.
        return
    upcoming_step: int = self._step_counter.get_current() + 1
    upcoming_step_id: str = self._create_step_id_for_logical_step(upcoming_step)
    if self.state.get_checkpoint_result(upcoming_step_id).is_existent():
        # The next operation was checkpointed, so we are still replaying.
        return
    # Mark this context as live, then propagate to the execution-wide state.
    self._replay_status = ReplayStatus.NEW
    self.state.transition_replay_status()

# region Operations

def create_callback(
Expand Down Expand Up @@ -400,14 +422,14 @@ def create_callback(
),
config=config,
)
self._track_replay()
callback_id: str = executor.process()
result: Callback = Callback(
callback_id=callback_id,
operation_id=operation_id,
state=self.state,
serdes=config.serdes,
)
self.state.track_replay(operation_id=operation_id)
return result

def invoke(
Expand Down Expand Up @@ -442,8 +464,8 @@ def invoke(
),
config=config,
)
self._track_replay()
result: R = executor.process()
self.state.track_replay(operation_id=operation_id)
return result

def map(
Expand Down Expand Up @@ -478,6 +500,7 @@ def map_in_child_context() -> BatchResult[R]:
operation_identifier=operation_identifier,
)

self._track_replay()
result: BatchResult[R] = child_handler(
func=map_in_child_context,
state=self.state,
Expand All @@ -491,7 +514,6 @@ def map_in_child_context() -> BatchResult[R]:
item_serdes=None,
),
)
self.state.track_replay(operation_id=operation_id)
return result

def parallel(
Expand Down Expand Up @@ -521,6 +543,7 @@ def parallel_in_child_context() -> BatchResult[T]:
operation_identifier=operation_identifier,
)

self._track_replay()
result: BatchResult[T] = child_handler(
func=parallel_in_child_context,
state=self.state,
Expand All @@ -534,7 +557,6 @@ def parallel_in_child_context() -> BatchResult[T]:
item_serdes=None,
),
)
self.state.track_replay(operation_id=operation_id)
return result

def run_in_child_context(
Expand Down Expand Up @@ -568,6 +590,7 @@ def callable_with_child_context():
)
)

self._track_replay()
result: T = child_handler(
func=callable_with_child_context,
state=self.state,
Expand All @@ -578,7 +601,6 @@ def callable_with_child_context():
),
config=config,
)
self.state.track_replay(operation_id=operation_id)
return result

def step(
Expand All @@ -603,8 +625,8 @@ def step(
),
context_logger=self.logger,
)
self._track_replay()
result: T = executor.process()
self.state.track_replay(operation_id=operation_id)
return result

def wait(self, duration: Duration, name: str | None = None) -> None:
Expand All @@ -629,8 +651,8 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
name=name,
),
)
self._track_replay()
executor.process()
self.state.track_replay(operation_id=operation_id)

def wait_for_callback(
self,
Expand Down Expand Up @@ -686,8 +708,8 @@ def wait_for_condition(
context_logger=self.logger,
)
)
self._track_replay()
result: T = executor.process()
self.state.track_replay(operation_id=operation_id)
return result


Expand Down
23 changes: 10 additions & 13 deletions src/aws_durable_execution_sdk_python/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
if TYPE_CHECKING:
from collections.abc import Callable, Mapping, MutableMapping

from aws_durable_execution_sdk_python.context import ExecutionState
from aws_durable_execution_sdk_python import DurableContext
from aws_durable_execution_sdk_python.identifier import OperationIdentifier


@dataclass(frozen=True)
class LogInfo:
execution_state: ExecutionState
parent_id: str | None = None
operation_id: str | None = None
name: str | None = None
Expand All @@ -25,13 +24,11 @@ class LogInfo:
@classmethod
def from_operation_identifier(
cls,
execution_state: ExecutionState,
op_id: OperationIdentifier,
attempt: int | None = None,
) -> LogInfo:
"""Create new log info from an execution arn, OperationIdentifier and attempt."""
return cls(
execution_state=execution_state,
parent_id=op_id.parent_id,
operation_id=op_id.operation_id,
name=op_id.name,
Expand All @@ -41,7 +38,6 @@ def from_operation_identifier(
def with_parent_id(self, parent_id: str) -> LogInfo:
"""Clone the log info with a new parent id."""
return LogInfo(
execution_state=self.execution_state,
parent_id=parent_id,
operation_id=self.operation_id,
name=self.name,
Expand All @@ -54,17 +50,19 @@ def __init__(
self,
logger: LoggerInterface,
default_extra: Mapping[str, object],
execution_state: ExecutionState,
context: DurableContext,
) -> None:
self._logger = logger
self._default_extra = default_extra
self._execution_state = execution_state
self._context = context

@classmethod
def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger:
def from_log_info(
cls, logger: LoggerInterface, info: LogInfo, context: DurableContext
) -> Logger:
"""Create a new logger with the given LogInfo."""
extra: MutableMapping[str, object] = {
"executionArn": info.execution_state.durable_execution_arn
"executionArn": context.state.durable_execution_arn
}
if info.parent_id:
extra["parentId"] = info.parent_id
Expand All @@ -75,15 +73,14 @@ def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger:
extra["attempt"] = info.attempt
if info.operation_id:
extra["operationId"] = info.operation_id
return cls(
logger=logger, default_extra=extra, execution_state=info.execution_state
)
return cls(logger=logger, default_extra=extra, context=context)

def with_log_info(self, info: LogInfo) -> Logger:
    """Clone this logger with new LogInfo, reusing the underlying logger and context."""
    cloned: Logger = Logger.from_log_info(
        logger=self._logger,
        info=info,
        context=self._context,
    )
    return cloned

def get_logger(self) -> LoggerInterface:
Expand Down Expand Up @@ -128,4 +125,4 @@ def _log(
log_func(msg, *args, extra=merged_extra)

def _should_log(self) -> bool:
return not self._execution_state.is_replaying()
return not self._context.is_replaying
1 change: 0 additions & 1 deletion src/aws_durable_execution_sdk_python/operation/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T:
step_context: StepContext = StepContext(
logger=self.context_logger.with_log_info(
LogInfo.from_operation_identifier(
execution_state=self.state,
op_id=self.operation_identifier,
attempt=attempt,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T:
check_context = WaitForConditionCheckContext(
logger=self.context_logger.with_log_info(
LogInfo.from_operation_identifier(
execution_state=self.state,
op_id=self.operation_identifier,
attempt=attempt,
)
Expand Down
39 changes: 6 additions & 33 deletions src/aws_durable_execution_sdk_python/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,39 +343,12 @@ def get_execution_operation(self) -> Operation | None:

return candidate

def track_replay(self, operation_id: str) -> None:
"""Check if operation exists with completed status; if not, transition to NEW status.

This method is called before each operation (step, wait, invoke, etc.) to determine
if we've reached the replay boundary. Once we encounter an operation that doesn't
exist or isn't completed, we transition from REPLAY to NEW status, which enables
logging for all subsequent code.

Args:
operation_id: The operation ID to check
"""
with self._replay_status_lock:
if self._replay_status == ReplayStatus.REPLAY:
self._visited_operations.add(operation_id)
completed_ops = {
op_id
for op_id, op in self.operations.items()
if op.operation_type != OperationType.EXECUTION
and op.status
in {
OperationStatus.SUCCEEDED,
OperationStatus.FAILED,
OperationStatus.CANCELLED,
OperationStatus.STOPPED,
OperationStatus.TIMED_OUT,
}
}
if completed_ops.issubset(self._visited_operations):
logger.debug(
"Transitioning from REPLAY to NEW status at operation %s",
operation_id,
)
self._replay_status = ReplayStatus.NEW
def transition_replay_status(self) -> None:
    """Transition the execution-wide replay status from REPLAY to NEW.

    Idempotent: has no effect once the status is already NEW.

    Thread-safe: the status is re-checked under ``_replay_status_lock``
    (double-checked pattern) so that concurrent callers perform the
    transition — and emit the debug log — exactly once. Checking only
    outside the lock would let two threads both pass the check and both
    mutate the shared status.
    """
    if self._replay_status is not ReplayStatus.REPLAY:
        # Fast path: already transitioned; avoid acquiring the lock.
        return
    with self._replay_status_lock:
        # Re-check under the lock: another thread may have transitioned
        # between the unlocked fast-path check and lock acquisition.
        if self._replay_status is ReplayStatus.REPLAY:
            logger.debug("Transitioning from REPLAY to NEW status")
            self._replay_status = ReplayStatus.NEW

def is_replaying(self) -> bool:
"""Check if execution is currently in replay mode.
Expand Down
Loading