Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
should_truncate_results: bool = True,
*,
per_turn: bool | int = False,
protected_messages: int = 0,
):
"""Initialize the sliding window conversation manager.

Expand All @@ -54,18 +55,29 @@ def __init__(
manage message history and prevent the agent loop from slowing down. Start with
per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed
for performance tuning.
protected_messages: Number of messages at the start of the conversation that should
never be removed during trimming. Defaults to 0 (no protection).

Use this when the first message(s) contain a task prompt or critical context that
the agent must retain throughout the entire conversation. For example, in batch
report generation, set ``protected_messages=1`` to ensure the initial user prompt
is never trimmed away during context overflow recovery.

Raises:
ValueError: If per_turn is 0 or a negative integer.
ValueError: If per_turn is 0 or a negative integer, or if protected_messages is negative.
"""
if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0:
raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}")

if protected_messages < 0:
raise ValueError(f"protected_messages must be non-negative, got {protected_messages}")

super().__init__()

self.window_size = window_size
self.should_truncate_results = should_truncate_results
self.per_turn = per_turn
self.protected_messages = protected_messages
self._model_call_count = 0

def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None:
Expand Down Expand Up @@ -160,6 +172,10 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A
- toolResult with no corresponding toolUse
- toolUse with no corresponding toolResult

When ``protected_messages`` is set, the first N messages are preserved and
re-inserted after trimming so that critical context (e.g. the initial task
prompt) is never lost.

Args:
agent: The agent whose messages will be reduced.
This list is modified in-place.
Expand All @@ -173,6 +189,11 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A
"""
messages = agent.messages

# Snapshot protected messages before any trimming
protected: list = []
if self.protected_messages > 0 and len(messages) > self.protected_messages:
protected = [msg for msg in messages[: self.protected_messages]]

# Try to truncate the tool result first
oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages)
if oldest_message_idx_with_tool_results is not None and self.should_truncate_results:
Expand All @@ -188,6 +209,10 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A
# If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size
trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size

# Never trim into the protected region
if trim_index < self.protected_messages:
trim_index = self.protected_messages

# Find the next valid trim point that:
# 1. Starts with a user message (required by most model providers)
# 2. Does not start with an orphaned toolResult
Expand Down Expand Up @@ -256,6 +281,18 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A
# Overwrite message history
messages[:] = messages[trim_index:]

# Re-insert protected messages that were trimmed away
if protected:
# Check which protected messages are no longer present
reinsert = [msg for msg in protected if msg not in messages]
if reinsert:
messages[:0] = reinsert
logger.info(
"protected_messages=<%d> | re-inserted %d protected message(s) after trim",
self.protected_messages,
len(reinsert),
)

def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
"""Truncate tool results and replace image blocks in a message to reduce context size.

Expand Down
104 changes: 104 additions & 0 deletions tests/strands/agent/test_conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,3 +703,107 @@ def test_boundary_text_in_tool_result_not_truncated():

assert not changed
assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == boundary_text


# ── protected_messages tests ──────────────────────────────────────────────


def test_protected_messages_negative_raises():
    """Constructing the manager with a negative protected_messages raises ValueError."""
    invalid_count = -1
    # The error message is expected to mention "non-negative".
    with pytest.raises(ValueError, match="non-negative"):
        SlidingWindowConversationManager(protected_messages=invalid_count)


def test_protected_messages_zero_is_default():
    """Omitting protected_messages leaves the manager's attribute at 0."""
    default_manager = SlidingWindowConversationManager(
        should_truncate_results=False,
        window_size=2,
    )
    # Only the stored attribute is checked here; trimming behaviour is not exercised.
    assert default_manager.protected_messages == 0


def test_protected_messages_preserves_first_message_on_trim():
    """With protected_messages=1, apply_management keeps the opening user message."""
    conversation_manager = SlidingWindowConversationManager(
        window_size=2,
        should_truncate_results=False,
        protected_messages=1,
    )

    # Five alternating turns against a window of two, so trimming is required.
    turns = [
        ("user", "Generate the report"),
        ("assistant", "Step 1"),
        ("user", "Follow-up"),
        ("assistant", "Step 2"),
        ("user", "Another question"),
    ]
    fake_agent = MagicMock()
    fake_agent.messages = [{"role": role, "content": [{"text": text}]} for role, text in turns]

    conversation_manager.apply_management(fake_agent)

    # The protected opening prompt is still at the head of the history ...
    opening_text = fake_agent.messages[0]["content"][0]["text"]
    assert opening_text == "Generate the report"
    # ... and the newest turn is still at the tail.
    closing_text = fake_agent.messages[-1]["content"][0]["text"]
    assert closing_text == "Another question"


def test_protected_messages_preserves_first_message_on_overflow():
    """With protected_messages=1, reduce_context called with an exception keeps the prompt."""
    conversation_manager = SlidingWindowConversationManager(
        window_size=2,
        should_truncate_results=False,
        protected_messages=1,
    )

    turns = [
        ("user", "Task prompt"),
        ("assistant", "Calling tools"),
        ("user", "Tool results"),
        ("assistant", "More work"),
        ("user", "More results"),
    ]
    fake_agent = MagicMock()
    fake_agent.messages = [{"role": role, "content": [{"text": text}]} for role, text in turns]

    # Call reduce_context directly, passing an exception as the overflow path would.
    overflow_error = RuntimeError("context overflow")
    conversation_manager.reduce_context(fake_agent, e=overflow_error)

    opening_text = fake_agent.messages[0]["content"][0]["text"]
    assert opening_text == "Task prompt"


def test_protected_messages_multiple():
    """With protected_messages=2, the first two messages stay at the head after management."""
    conversation_manager = SlidingWindowConversationManager(
        window_size=2,
        should_truncate_results=False,
        protected_messages=2,
    )

    turns = [
        ("user", "System context"),
        ("assistant", "Acknowledged"),
        ("user", "Question 1"),
        ("assistant", "Answer 1"),
        ("user", "Question 2"),
    ]
    fake_agent = MagicMock()
    fake_agent.messages = [{"role": role, "content": [{"text": text}]} for role, text in turns]

    conversation_manager.apply_management(fake_agent)

    # Both protected messages must be present, in their original order.
    for position, expected_text in enumerate(("System context", "Acknowledged")):
        assert fake_agent.messages[position]["content"][0]["text"] == expected_text


def test_protected_messages_no_trim_needed():
    """A history shorter than the window keeps its length when protected_messages is set."""
    conversation_manager = SlidingWindowConversationManager(
        window_size=10,
        should_truncate_results=False,
        protected_messages=1,
    )

    # Two messages against a window of ten.
    turns = [("user", "Hello"), ("assistant", "Hi")]
    fake_agent = MagicMock()
    fake_agent.messages = [{"role": role, "content": [{"text": text}]} for role, text in turns]

    conversation_manager.apply_management(fake_agent)

    remaining = len(fake_agent.messages)
    assert remaining == 2


def test_protected_messages_trim_index_skips_protected_region():
    """The first (protected) message survives management with window_size=3 and five messages."""
    conversation_manager = SlidingWindowConversationManager(
        window_size=3,
        should_truncate_results=False,
        protected_messages=1,
    )

    # Five messages with window_size=3; index 0 is the single protected message.
    # NOTE(review): assuming apply_management delegates to reduce_context, the initial
    # trim index here would be 5 - 3 = 2, which already lies past the protected region,
    # so the "never trim into the protected region" clamp is presumably not triggered by
    # these sizes. Confirm, and consider a case where protected_messages exceeds it.
    turns = [
        ("user", "Important prompt"),
        ("assistant", "Response 1"),
        ("user", "Q2"),
        ("assistant", "Response 2"),
        ("user", "Q3"),
    ]
    fake_agent = MagicMock()
    fake_agent.messages = [{"role": role, "content": [{"text": text}]} for role, text in turns]

    conversation_manager.apply_management(fake_agent)

    # The protected prompt must still lead the history.
    opening_text = fake_agent.messages[0]["content"][0]["text"]
    assert opening_text == "Important prompt"