Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/bedrock_agentcore/memory/integrations/strands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ class AgentCoreMemoryConfig(BaseModel):
retrieval_config: Optional dictionary mapping namespaces to retrieval configurations
batch_size: Number of messages to batch before sending to AgentCore Memory.
Default of 1 means immediate sending (no batching). Max 100.
context_tag: XML tag name used to wrap retrieved memory context injected into messages.
Default is "user_context".
"""

memory_id: str = Field(min_length=1)
session_id: str = Field(min_length=1)
actor_id: str = Field(min_length=1)
retrieval_config: Optional[Dict[str, RetrievalConfig]] = None
batch_size: int = Field(default=1, ge=1, le=100)
context_tag: str = Field(default="user_context", min_length=1)
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig):
if all_context:
context_text = "\n".join(all_context)
event.agent.messages[-1]["content"].insert(
0, {"text": f"<retrieved_memory>{context_text}</retrieved_memory>"}
0, {"text": f"<{self.config.context_tag}>{context_text}</{self.config.context_tag}>"}
)
logger.info("Retrieved %s customer context items", len(all_context))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1937,7 +1937,7 @@ def test_retrieve_customer_context_does_not_append_assistant_message(
# Memory prepended, original query remains last
content = mock_agent.messages[0]["content"]
assert len(content) == 2
assert "<retrieved_memory>" in content[0]["text"]
assert "<user_context>" in content[0]["text"]
assert content[1]["text"] == "What are my preferences?"

def test_retrieve_customer_context_no_assistant_message_multi_turn(
Expand Down Expand Up @@ -1982,5 +1982,86 @@ def test_retrieve_customer_context_no_assistant_message_multi_turn(
# Memory injected into last user message
content = mock_agent.messages[-1]["content"]
assert len(content) == 2
assert "<retrieved_memory>" in content[0]["text"]
assert "<user_context>" in content[0]["text"]
assert content[1]["text"] == "What do I like to eat?"

def test_retrieve_customer_context_custom_context_tag(self, mock_memory_client):
"""Test that a custom context_tag is used when configured."""
custom_config = AgentCoreMemoryConfig(
memory_id="test-memory-123",
session_id="test-session-456",
actor_id="test-actor-789",
retrieval_config={"user_preferences/{actorId}/": RetrievalConfig(top_k=5, relevance_score=0.3)},
context_tag="retrieved_memory",
)

mock_memory_client.retrieve_memories.return_value = [
{"content": {"text": "User likes sushi"}},
]

with patch(
"bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient",
return_value=mock_memory_client,
):
with patch("boto3.Session") as mock_boto_session:
mock_session = Mock()
mock_session.region_name = "us-west-2"
mock_session.client.return_value = Mock()
mock_boto_session.return_value = mock_session

with patch(
"strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None
):
manager = AgentCoreMemorySessionManager(custom_config)

mock_agent = Mock()
mock_agent.messages = [{"role": "user", "content": [{"text": "What do I like?"}]}]

event = MessageAddedEvent(
agent=mock_agent, message={"role": "user", "content": [{"text": "What do I like?"}]}
)
manager.retrieve_customer_context(event)

content = mock_agent.messages[0]["content"]
assert "<retrieved_memory>" in content[0]["text"]
assert "</retrieved_memory>" in content[0]["text"]

def test_retrieve_customer_context_default_context_tag(self, mock_memory_client):
"""Test that the default context_tag is user_context."""
default_config = AgentCoreMemoryConfig(
memory_id="test-memory-123",
session_id="test-session-456",
actor_id="test-actor-789",
retrieval_config={"user_preferences/{actorId}/": RetrievalConfig(top_k=5, relevance_score=0.3)},
)

mock_memory_client.retrieve_memories.return_value = [
{"content": {"text": "User likes sushi"}},
]

with patch(
"bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient",
return_value=mock_memory_client,
):
with patch("boto3.Session") as mock_boto_session:
mock_session = Mock()
mock_session.region_name = "us-west-2"
mock_session.client.return_value = Mock()
mock_boto_session.return_value = mock_session

with patch(
"strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None
):
manager = AgentCoreMemorySessionManager(default_config)

mock_agent = Mock()
mock_agent.messages = [{"role": "user", "content": [{"text": "What do I like?"}]}]

event = MessageAddedEvent(
agent=mock_agent, message={"role": "user", "content": [{"text": "What do I like?"}]}
)
manager.retrieve_customer_context(event)

content = mock_agent.messages[0]["content"]
assert "<user_context>" in content[0]["text"]
assert "</user_context>" in content[0]["text"]
4 changes: 2 additions & 2 deletions tests_integ/memory/integrations/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_session_manager_with_retrieval_config_adds_context(self, test_memory_lt
response2 = agent("What do I like to eat?")
assert response2 is not None
assert "sushi" in str(agent.messages)
assert "<retrieved_memory>" in str(agent.messages)
assert "<user_context>" in str(agent.messages)

def test_multiple_namespace_retrieval_config(self, test_memory_ltm):
"""Test session manager with multiple namespace retrieval configurations."""
Expand Down Expand Up @@ -182,7 +182,7 @@ def test_multiple_namespace_retrieval_config(self, test_memory_ltm):
response2 = agent("What do I like to eat?")
assert response2 is not None
assert "sushi" in str(agent.messages)
assert "<retrieved_memory>" in str(agent.messages)
assert "<user_context>" in str(agent.messages)

def test_session_manager_error_handling(self):
"""Test session manager error handling with invalid configuration."""
Expand Down
Loading