Skip to content

Commit 2b047ff

Browse files
kevmyung, aidandaly24, jariy17, claude
authored
fix(memory): Improve pagination behavior in get_last_k_turns() and list_messages() (#209)
* feat(memory): Improve pagination behavior in get_last_k_turns() and list_messages()
  - get_last_k_turns(): Auto-calculate max_results based on k (max(100, k*3))
  - list_messages(): Add fetch_all parameter to fetch all messages (up to 10000)
  - Backward compatible: default behavior unchanged

* fix(memory): Address PR review comments for pagination behavior
  - Extract shared pagination logic into pagination.py helper
  - Fix test mocks to use _data_plane_client.list_events
  - Add MAX_FETCH_ALL_RESULTS constant (10000) in strands session_manager
  - Rename include_branches to include_parent_branches in client.py for consistency
  - Add comprehensive tests for pagination helper

* fix(memory): Address PR review comments
  - Remove fetch_all parameter from list_messages (misleading name)
  - Use MAX_FETCH_ALL_RESULTS (10000) as default when no limit specified
  - Remove pagination.py module, inline logic into client.py and session.py
  - Revert include_branches rename to avoid breaking change

* fix: Apply ruff formatting to pagination changes

Apply automatic formatting fixes identified by ruff format pre-commit hook.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

---------

Co-authored-by: Aidan Daly <aidandal@amazon.com>
Co-authored-by: T.J Ariyawansa <tjariy@amazon.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 43c6c3c commit 2b047ff

6 files changed

Lines changed: 337 additions & 211 deletions

File tree

src/bedrock_agentcore/memory/client.py

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,55 +1096,75 @@ def get_last_k_turns(
10961096
k: int = 5,
10971097
branch_name: Optional[str] = None,
10981098
include_branches: bool = False,
1099-
max_results: int = 100,
1099+
max_results: Optional[int] = None,
11001100
) -> List[List[Dict[str, Any]]]:
11011101
"""Get the last K conversation turns.
11021102
11031103
A "turn" typically consists of a user message followed by assistant response(s).
11041104
This method groups messages into logical turns for easier processing.
11051105
1106+
If max_results is specified, fetches up to that many events and finds turns within them
1107+
(backward compatible behavior).
1108+
If max_results is None, automatically paginates until k turns are found.
1109+
11061110
Returns:
11071111
List of turns, where each turn is a list of message dictionaries
11081112
"""
1113+
base_params = {
1114+
"memoryId": memory_id,
1115+
"actorId": actor_id,
1116+
"sessionId": session_id,
1117+
}
1118+
1119+
if branch_name and branch_name != "main":
1120+
base_params["filter"] = {"branch": {"name": branch_name, "includeParentBranches": include_branches}}
1121+
11091122
try:
1110-
# Use the new list_events method
1111-
events = self.list_events(
1112-
memory_id=memory_id,
1113-
actor_id=actor_id,
1114-
session_id=session_id,
1115-
branch_name=branch_name,
1116-
include_parent_branches=False,
1117-
max_results=max_results,
1118-
)
1123+
turns: List[List[Dict[str, Any]]] = []
1124+
current_turn: List[Dict[str, Any]] = []
1125+
next_token = None
1126+
total_fetched = 0
1127+
1128+
while len(turns) < k:
1129+
if max_results is not None:
1130+
remaining = max_results - total_fetched
1131+
if remaining <= 0:
1132+
break
1133+
batch_size = min(100, remaining)
1134+
else:
1135+
batch_size = 100
11191136

1120-
if not events:
1121-
return []
1137+
params = {**base_params, "maxResults": batch_size, "includePayloads": True}
1138+
if next_token:
1139+
params["nextToken"] = next_token
11221140

1123-
# Process events to group into turns
1124-
turns = []
1125-
current_turn = []
1141+
response = self.gmdp_client.list_events(**params)
1142+
events = response.get("events", [])
11261143

1127-
for event in events:
1128-
if len(turns) >= k:
1129-
break # Only need last K turns
1130-
for payload_item in event.get("payload", []):
1131-
if "conversational" in payload_item:
1132-
role = payload_item["conversational"].get("role")
1144+
if not events:
1145+
break
1146+
1147+
total_fetched += len(events)
11331148

1134-
# Start new turn on USER message
1135-
if role == Role.USER.value and current_turn:
1136-
turns.append(current_turn)
1137-
current_turn = []
1149+
for event in events:
1150+
if len(turns) >= k:
1151+
break
1152+
for payload_item in event.get("payload", []):
1153+
if "conversational" in payload_item:
1154+
role = payload_item["conversational"].get("role")
1155+
if role == Role.USER.value and current_turn:
1156+
turns.append(current_turn)
1157+
current_turn = []
1158+
current_turn.append(payload_item["conversational"])
11381159

1139-
current_turn.append(payload_item["conversational"])
1160+
next_token = response.get("nextToken")
1161+
if not next_token:
1162+
break
11401163

1141-
# Don't forget the last turn
1142-
if current_turn:
1164+
if current_turn and len(turns) < k:
11431165
turns.append(current_turn)
11441166

1145-
# Return the last k turns
1146-
return turns[:k] if len(turns) > k else turns
1147-
1167+
return turns[:k]
11481168
except ClientError as e:
11491169
logger.error("Failed to get last K turns: %s", e)
11501170
raise

src/bedrock_agentcore/memory/integrations/strands/session_manager.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
SESSION_PREFIX = "session_"
3232
AGENT_PREFIX = "agent_"
3333
MESSAGE_PREFIX = "message_"
34+
MAX_FETCH_ALL_RESULTS = 10000
3435

3536

3637
class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository):
@@ -427,7 +428,12 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
427428
)
428429

429430
def list_messages(
430-
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any
431+
self,
432+
session_id: str,
433+
agent_id: str,
434+
limit: Optional[int] = None,
435+
offset: int = 0,
436+
**kwargs: Any,
431437
) -> list[SessionMessage]:
432438
"""List messages for an agent from AgentCore Memory with pagination.
433439
@@ -448,7 +454,8 @@ def list_messages(
448454
raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}")
449455

450456
try:
451-
max_results = (limit + offset) if limit else 100
457+
max_results = (limit + offset) if limit else MAX_FETCH_ALL_RESULTS
458+
452459
events = self.memory_client.list_events(
453460
memory_id=self.config.memory_id,
454461
actor_id=self.config.actor_id,
@@ -512,7 +519,8 @@ def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig):
512519
)
513520
if retrieval_config.relevance_score:
514521
memories = [
515-
m for m in memories
522+
m
523+
for m in memories
516524
if m.get("relevanceScore", retrieval_config.relevance_score) >= retrieval_config.relevance_score
517525
]
518526
context_items = []

src/bedrock_agentcore/memory/session.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -786,53 +786,75 @@ def get_last_k_turns(
786786
k: int = 5,
787787
branch_name: Optional[str] = None,
788788
include_parent_branches: bool = False,
789-
max_results: int = 100,
789+
max_results: Optional[int] = None,
790790
) -> List[List[EventMessage]]:
791791
"""Get the last K conversation turns.
792792
793793
A "turn" typically consists of a user message followed by assistant response(s).
794794
This method groups messages into logical turns for easier processing.
795795
796+
If max_results is specified, fetches up to that many events and finds turns within them
797+
(backward compatible behavior).
798+
If max_results is None, automatically paginates until k turns are found.
799+
796800
Returns:
797801
List of turns, where each turn is a list of message dictionaries
798802
"""
803+
base_params = {
804+
"memoryId": self._memory_id,
805+
"actorId": actor_id,
806+
"sessionId": session_id,
807+
}
808+
809+
if branch_name and branch_name != "main":
810+
base_params["filter"] = {"branch": {"name": branch_name, "includeParentBranches": include_parent_branches}}
811+
799812
try:
800-
events = self.list_events(
801-
actor_id=actor_id,
802-
session_id=session_id,
803-
branch_name=branch_name,
804-
include_parent_branches=include_parent_branches,
805-
max_results=max_results,
806-
)
813+
turns: List[List[EventMessage]] = []
814+
current_turn: List[EventMessage] = []
815+
next_token = None
816+
total_fetched = 0
817+
818+
while len(turns) < k:
819+
if max_results is not None:
820+
remaining = max_results - total_fetched
821+
if remaining <= 0:
822+
break
823+
batch_size = min(100, remaining)
824+
else:
825+
batch_size = 100
807826

808-
if not events:
809-
return []
827+
params = {**base_params, "maxResults": batch_size, "includePayloads": True}
828+
if next_token:
829+
params["nextToken"] = next_token
830+
831+
response = self._data_plane_client.list_events(**params)
832+
events = response.get("events", [])
810833

811-
# Process events to group into turns
812-
turns = []
813-
current_turn = []
834+
if not events:
835+
break
814836

815-
for event in events:
816-
if len(turns) >= k:
817-
break # Only need last K turns
818-
for payload_item in event.get("payload", []):
819-
if "conversational" in payload_item:
820-
role = payload_item["conversational"].get("role")
837+
total_fetched += len(events)
821838

822-
# Start new turn on USER message
823-
if role == MessageRole.USER.value and current_turn:
824-
turns.append(current_turn)
825-
current_turn = []
839+
for event in events:
840+
if len(turns) >= k:
841+
break
842+
for payload_item in event.get("payload", []):
843+
if "conversational" in payload_item:
844+
role = payload_item["conversational"].get("role")
845+
if role == MessageRole.USER.value and current_turn:
846+
turns.append(current_turn)
847+
current_turn = []
848+
current_turn.append(EventMessage(payload_item["conversational"]))
826849

827-
current_turn.append(EventMessage(payload_item["conversational"]))
850+
next_token = response.get("nextToken")
851+
if not next_token:
852+
break
828853

829-
# Don't forget the last turn
830-
if current_turn:
854+
if current_turn and len(turns) < k:
831855
turns.append(current_turn)
832856

833-
# Return the last k turns
834-
return turns[:k] if len(turns) > k else turns
835-
857+
return turns[:k]
836858
except ClientError as e:
837859
logger.error("Failed to get last K turns: %s", e)
838860
raise
@@ -1153,7 +1175,7 @@ def get_last_k_turns(
11531175
k: int = 5,
11541176
branch_name: Optional[str] = None,
11551177
include_parent_branches: Optional[bool] = None,
1156-
max_results: int = 100,
1178+
max_results: Optional[int] = None,
11571179
) -> List[List[EventMessage]]:
11581180
"""Delegates to manager.get_last_k_turns."""
11591181
return self._manager.get_last_k_turns(

tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,3 +1096,23 @@ def test_retrieve_customer_context_filters_by_relevance_score(self, mock_memory_
10961096
assert "High relevance 2" in injected_context
10971097
assert "Low relevance 1" not in injected_context
10981098
assert "Low relevance 2" not in injected_context
1099+
1100+
def test_list_messages_default_max_results(self, session_manager, mock_memory_client):
1101+
"""Test listing messages without limit uses default max_results=10000."""
1102+
mock_memory_client.list_events.return_value = []
1103+
1104+
session_manager.list_messages("test-session-456", "test-agent-123")
1105+
1106+
mock_memory_client.list_events.assert_called_once()
1107+
call_kwargs = mock_memory_client.list_events.call_args[1]
1108+
assert call_kwargs["max_results"] == 10000
1109+
1110+
def test_list_messages_with_limit_calculates_max_results(self, session_manager, mock_memory_client):
1111+
"""Test listing messages with limit calculates max_results correctly."""
1112+
mock_memory_client.list_events.return_value = []
1113+
1114+
session_manager.list_messages("test-session-456", "test-agent-123", limit=500, offset=50)
1115+
1116+
mock_memory_client.list_events.assert_called_once()
1117+
call_kwargs = mock_memory_client.list_events.call_args[1]
1118+
assert call_kwargs["max_results"] == 550 # limit + offset

tests/bedrock_agentcore/memory/test_client.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3064,3 +3064,67 @@ def test_wrap_configuration_custom_episodic_override():
30643064
wrapped["reflection"]["customReflectionConfiguration"]["episodicReflectionOverride"]["appendToPrompt"]
30653065
== "Reflect on episodes"
30663066
)
3067+
3068+
3069+
def test_get_last_k_turns_auto_pagination():
3070+
"""Test get_last_k_turns automatically paginates until k turns are found."""
3071+
with patch("boto3.client"):
3072+
client = MemoryClient()
3073+
3074+
mock_gmdp = MagicMock()
3075+
client.gmdp_client = mock_gmdp
3076+
3077+
# First call returns events but not enough turns, with next_token
3078+
# Second call returns more events, no next_token
3079+
mock_gmdp.list_events.side_effect = [
3080+
{
3081+
"events": [
3082+
{"payload": [{"conversational": {"role": "USER", "content": {"text": "Hi"}}}]},
3083+
{"payload": [{"conversational": {"role": "ASSISTANT", "content": {"text": "Hello"}}}]},
3084+
],
3085+
"nextToken": "token-123",
3086+
},
3087+
{
3088+
"events": [
3089+
{"payload": [{"conversational": {"role": "USER", "content": {"text": "How are you?"}}}]},
3090+
{"payload": [{"conversational": {"role": "ASSISTANT", "content": {"text": "Good"}}}]},
3091+
],
3092+
"nextToken": None,
3093+
},
3094+
]
3095+
3096+
# Request 2 turns without max_results - should paginate automatically
3097+
turns = client.get_last_k_turns(memory_id="mem-123", actor_id="user-123", session_id="session-456", k=2)
3098+
3099+
assert len(turns) == 2
3100+
assert mock_gmdp.list_events.call_count == 2
3101+
3102+
3103+
def test_get_last_k_turns_explicit_max_results():
3104+
"""Test get_last_k_turns respects explicitly provided max_results (backward compatible)."""
3105+
with patch("boto3.client"):
3106+
client = MemoryClient()
3107+
3108+
mock_gmdp = MagicMock()
3109+
client.gmdp_client = mock_gmdp
3110+
3111+
# Return events with next_token, but max_results should limit fetching
3112+
mock_gmdp.list_events.return_value = {
3113+
"events": [
3114+
{"payload": [{"conversational": {"role": "USER", "content": {"text": "Hi"}}}]},
3115+
],
3116+
"nextToken": "token-123",
3117+
}
3118+
3119+
# Request with explicit max_results=50 - should respect limit
3120+
client.get_last_k_turns(
3121+
memory_id="mem-123", actor_id="user-123", session_id="session-456", k=200, max_results=50
3122+
)
3123+
3124+
# First call should request up to max_results (min of 100 and 50 = 50)
3125+
first_call_args = mock_gmdp.list_events.call_args_list[0]
3126+
assert first_call_args[1]["maxResults"] == 50
3127+
3128+
# Total events fetched should not exceed max_results
3129+
total_fetched = sum(1 for _ in mock_gmdp.list_events.call_args_list)
3130+
assert total_fetched <= 50 # Should stop after fetching 50 events worth of calls

0 commit comments

Comments
 (0)