diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index f91d7a538..bdf5ee8ca 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -37,6 +37,7 @@ def __init__( should_truncate_results: bool = True, *, per_turn: bool | int = False, + protected_messages: int = 0, ): """Initialize the sliding window conversation manager. @@ -54,18 +55,29 @@ def __init__( manage message history and prevent the agent loop from slowing down. Start with per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed for performance tuning. + protected_messages: Number of messages at the start of the conversation that should + never be removed during trimming. Defaults to 0 (no protection). + + Use this when the first message(s) contain a task prompt or critical context that + the agent must retain throughout the entire conversation. For example, in batch + report generation, set ``protected_messages=1`` to ensure the initial user prompt + is never trimmed away during context overflow recovery. Raises: - ValueError: If per_turn is 0 or a negative integer. + ValueError: If per_turn is 0 or a negative integer, or if protected_messages is negative. """ if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0: raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}") + if protected_messages < 0: + raise ValueError(f"protected_messages must be non-negative, got {protected_messages}") + super().__init__() self.window_size = window_size self.should_truncate_results = should_truncate_results self.per_turn = per_turn + self.protected_messages = protected_messages self._model_call_count = 0 def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None: @@ -160,6 +172,10 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A - toolResult with no corresponding toolUse - toolUse with no corresponding toolResult + When ``protected_messages`` is set, the first N messages are preserved and + re-inserted after trimming so that critical context (e.g. the initial task + prompt) is never lost. + Args: agent: The agent whose messages will be reduce. This list is modified in-place. @@ -173,6 +189,11 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A """ messages = agent.messages + # Snapshot protected messages before any trimming + protected: list = [] + if self.protected_messages > 0 and len(messages) > self.protected_messages: + protected = [msg for msg in messages[: self.protected_messages]] + # Try to truncate the tool result first oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages) if oldest_message_idx_with_tool_results is not None and self.should_truncate_results: @@ -188,6 +209,10 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size + # Never trim into the protected region + if trim_index < self.protected_messages: + trim_index = self.protected_messages + # Find the next valid trim point that: # 1. Starts with a user message (required by most model providers) # 2. Does not start with an orphaned toolResult @@ -256,6 +281,18 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A # Overwrite message history messages[:] = messages[trim_index:] + # Re-insert protected messages that were trimmed away + if protected: + # Check which protected messages are no longer present + reinsert = [msg for msg in protected if msg not in messages] + if reinsert: + messages[:0] = reinsert + logger.info( + "protected_messages=<%d> | re-inserted %d protected message(s) after trim", + self.protected_messages, + len(reinsert), + ) + def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: """Truncate tool results and replace image blocks in a message to reduce context size. diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index c8b9df1cf..02d15b4d1 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -703,3 +703,107 @@ def test_boundary_text_in_tool_result_not_truncated(): assert not changed assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == boundary_text + + +# ── protected_messages tests ────────────────────────────────────────────── + + +def test_protected_messages_negative_raises(): + """protected_messages must be non-negative.""" + with pytest.raises(ValueError, match="non-negative"): + SlidingWindowConversationManager(protected_messages=-1) + + +def test_protected_messages_zero_is_default(): + """Default protected_messages=0 behaves identically to the original manager.""" + manager = SlidingWindowConversationManager(window_size=2, should_truncate_results=False) + assert manager.protected_messages == 0 + + +def test_protected_messages_preserves_first_message_on_trim(): + """When protected_messages=1, the first user message survives trimming.""" + manager = SlidingWindowConversationManager(window_size=2, should_truncate_results=False, protected_messages=1) + agent = MagicMock() + agent.messages = [ + {"role": "user", "content": [{"text": "Generate the report"}]}, + {"role": "assistant", "content": [{"text": "Step 1"}]}, + {"role": "user", "content": [{"text": "Follow-up"}]}, + {"role": "assistant", "content": [{"text": "Step 2"}]}, + {"role": "user", "content": [{"text": "Another question"}]}, + ] + + manager.apply_management(agent) + + # The first message must still be present + assert agent.messages[0]["content"][0]["text"] == "Generate the report" + # And the conversation should end with the most recent messages + assert agent.messages[-1]["content"][0]["text"] == "Another question" + + +def test_protected_messages_preserves_first_message_on_overflow(): + """protected_messages=1 preserves the prompt even during context overflow (reduce_context with e).""" + manager = SlidingWindowConversationManager(window_size=2, should_truncate_results=False, protected_messages=1) + agent = MagicMock() + agent.messages = [ + {"role": "user", "content": [{"text": "Task prompt"}]}, + {"role": "assistant", "content": [{"text": "Calling tools"}]}, + {"role": "user", "content": [{"text": "Tool results"}]}, + {"role": "assistant", "content": [{"text": "More work"}]}, + {"role": "user", "content": [{"text": "More results"}]}, + ] + + manager.reduce_context(agent, e=RuntimeError("context overflow")) + + assert agent.messages[0]["content"][0]["text"] == "Task prompt" + + +def test_protected_messages_multiple(): + """protected_messages=2 preserves the first two messages.""" + manager = SlidingWindowConversationManager(window_size=2, should_truncate_results=False, protected_messages=2) + agent = MagicMock() + agent.messages = [ + {"role": "user", "content": [{"text": "System context"}]}, + {"role": "assistant", "content": [{"text": "Acknowledged"}]}, + {"role": "user", "content": [{"text": "Question 1"}]}, + {"role": "assistant", "content": [{"text": "Answer 1"}]}, + {"role": "user", "content": [{"text": "Question 2"}]}, + ] + + manager.apply_management(agent) + + assert agent.messages[0]["content"][0]["text"] == "System context" + assert agent.messages[1]["content"][0]["text"] == "Acknowledged" + + +def test_protected_messages_no_trim_needed(): + """When messages fit in the window, protected_messages has no effect.""" + manager = SlidingWindowConversationManager(window_size=10, should_truncate_results=False, protected_messages=1) + agent = MagicMock() + agent.messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + ] + + manager.apply_management(agent) + + assert len(agent.messages) == 2 + + +def test_protected_messages_trim_index_skips_protected_region(): + """The trim index must never fall within the protected region.""" + manager = SlidingWindowConversationManager(window_size=3, should_truncate_results=False, protected_messages=1) + agent = MagicMock() + # 5 messages, window_size=3 → trim_index starts at 2 + # But protected_messages=1 means index 0 is protected + agent.messages = [ + {"role": "user", "content": [{"text": "Important prompt"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Q2"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + {"role": "user", "content": [{"text": "Q3"}]}, + ] + + manager.apply_management(agent) + + # First message must survive + assert agent.messages[0]["content"][0]["text"] == "Important prompt"