From 6aa3988d10a5c7c96565cba614abf4a08b660851 Mon Sep 17 00:00:00 2001 From: Padma Komarina Date: Thu, 5 Mar 2026 15:44:25 -0500 Subject: [PATCH] fix: batch create agent state alongwith update_agent state events, flush messages and agent states separately --- .../integrations/strands/session_manager.py | 246 +++++---- .../test_agentcore_memory_session_manager.py | 468 ++++++++++++++---- 2 files changed, 511 insertions(+), 203 deletions(-) diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index fc7eac3e..1baf234f 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -283,30 +283,48 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A if session_id != self.config.session_id: raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self.config.actor_id, - sessionId=self.session_id, - payload=[ - {"blob": json.dumps(session_agent.to_dict())}, - ], - eventTimestamp=self._get_monotonic_timestamp(), - metadata={ - STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value}, - AGENT_ID_KEY: {"stringValue": session_agent.agent_id}, - }, - ) - # Cache the created_at timestamp to avoid re-fetching on updates if session_agent.created_at: self._agent_created_at_cache[session_agent.agent_id] = session_agent.created_at - logger.info( - "Created agent: %s in session: %s with event %s", - session_agent.agent_id, - session_id, - event.get("event", {}).get("eventId"), - ) + if self.config.batch_size > 1: + # Buffer the agent state events + should_flush = False + with self._agent_state_lock: + self._agent_state_buffer.append((session_id, session_agent)) + should_flush = len(self._agent_state_buffer) >= self.config.batch_size + + # Flush only agent states outside the lock to prevent deadlock + if should_flush: + self._flush_agent_states_only() + + logger.info( + "Buffered agent creation: %s in session: %s", + session_agent.agent_id, + session_id, + ) + else: + # Immediate send when batching is disabled + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=self.session_id, + payload=[ + {"blob": json.dumps(session_agent.to_dict())}, + ], + eventTimestamp=self._get_monotonic_timestamp(), + metadata={ + STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value}, + AGENT_ID_KEY: {"stringValue": session_agent.agent_id}, + }, + ) + + logger.info( + "Created agent: %s in session: %s with event %s", + session_agent.agent_id, + session_id, + event.get("event", {}).get("eventId"), + ) def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: """Read agent data from AgentCore Memory events. @@ -395,20 +413,18 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A """ agent_id = session_agent.agent_id + # Verify agent exists and get created_at timestamp if not cached if agent_id not in self._agent_created_at_cache: previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id) if previous_agent is None: raise SessionException(f"Agent {agent_id} in session {session_id} does not exist") + + # Set created_at from cache before creating the update event session_agent.created_at = self._agent_created_at_cache[agent_id] - if self.config.batch_size > 1: - # Buffer the agent state update - with self._agent_state_lock: - self._agent_state_buffer.append((session_id, session_agent)) - else: - # Immediate send create_event without buffering - # Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent` - self.create_agent(session_id, session_agent) + # Create a new agent event (AgentCore Memory is immutable) + # create_agent will handle batching and caching appropriately + self.create_agent(session_id, session_agent) def create_message( self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any @@ -466,9 +482,9 @@ def create_message( self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp)) should_flush = len(self._message_buffer) >= self.config.batch_size - # Flush outside the lock to prevent deadlock + # Flush only messages outside the lock to prevent deadlock if should_flush: - self._flush_messages() + self._flush_messages_only() return {} # No eventId yet @@ -711,116 +727,148 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: # region Batching support - def _flush_messages(self) -> list[dict[str, Any]]: - """Flush all buffered messages and agent state to AgentCore Memory. - - Call this method to send any remaining buffered messages and agent state when batch_size > 1. - This is automatically called when the buffer reaches batch_size, but should - also be called explicitly when the session is complete (via close() or context manager). + def _flush_messages_only(self) -> list[dict[str, Any]]: + """Flush only buffered messages to AgentCore Memory. + Call this method to send any remaining buffered messages when batch_size > 1. + This is called when the message buffer reaches batch_size. Messages are batched by session_id - all conversational messages for the same session are combined into a single create_event() call to reduce API calls. Blob messages (>9KB) are sent individually as they require a different API path. - Agent state updates are sent after messages. Returns: list[dict[str, Any]]: List of created event responses from AgentCore Memory. Raises: - SessionException: If any message or agent state creation fails. On failure, all messages - and agent state remain in the buffer to prevent data loss. + SessionException: If message creation fails. On failure, messages remain in the buffer. """ with self._message_lock: messages_to_send = list(self._message_buffer) - with self._agent_state_lock: - agent_states_to_send = list(self._agent_state_buffer) - - if not messages_to_send and not agent_states_to_send: + if not messages_to_send: return [] - # Group conversational messages by session_id, preserve order - # Structure: {session_id: {"messages": [...], "timestamp": latest_timestamp}} + # Group all messages by session_id, combining conversational and blob messages + # Structure: {session_id: {"payload": [...], "timestamp": latest_timestamp}} session_groups: dict[str, dict[str, Any]] = {} - blob_messages: list[tuple[str, list[tuple[str, str]], datetime]] = [] for session_id, messages, is_blob, monotonic_timestamp in messages_to_send: + if session_id not in session_groups: + session_groups[session_id] = {"payload": [], "timestamp": monotonic_timestamp} + if is_blob: - # Blobs cannot be combined - collect them separately - blob_messages.append((session_id, messages, monotonic_timestamp)) + # Add blob messages to payload + for msg in messages: + session_groups[session_id]["payload"].append({"blob": json.dumps(msg)}) else: - # Group conversational messages by session_id - if session_id not in session_groups: - session_groups[session_id] = {"messages": [], "timestamp": monotonic_timestamp} - # Extend messages list to preserve order (earlier messages first) - session_groups[session_id]["messages"].extend(messages) - # Use the latest timestamp for the combined event - if monotonic_timestamp > session_groups[session_id]["timestamp"]: - session_groups[session_id]["timestamp"] = monotonic_timestamp + # Add conversational messages to payload + for text, role in messages: + session_groups[session_id]["payload"].append( + {"conversational": {"content": {"text": text}, "role": role.upper()}} + ) + + # Use the latest timestamp for the combined event + if monotonic_timestamp > session_groups[session_id]["timestamp"]: + session_groups[session_id]["timestamp"] = monotonic_timestamp results = [] try: - # Send one create_event per session_id with combined messages + # Send one create_event per session_id with all messages (conversational + blob) for session_id, group in session_groups.items(): - event = self.memory_client.create_event( - memory_id=self.config.memory_id, - actor_id=self.config.actor_id, - session_id=session_id, - messages=group["messages"], - event_timestamp=group["timestamp"], - ) - results.append(event) - logger.debug("Flushed batched event for session %s: %s", session_id, event.get("eventId")) - - # Send blob messages individually (they use a different API path) - for session_id, messages, monotonic_timestamp in blob_messages: event = self.memory_client.gmdp_client.create_event( memoryId=self.config.memory_id, actorId=self.config.actor_id, sessionId=session_id, - payload=[ - {"blob": json.dumps(messages[0])}, - ], - eventTimestamp=monotonic_timestamp, - ) - results.append(event) - logger.debug("Flushed blob event for session %s: %s", session_id, event.get("eventId")) - - # Flush agent state updates after messages - batch all agent states into a single API call - if agent_states_to_send: - # Convert all agent states to payload format - agent_state_payloads = [] - for _session_id, session_agent in agent_states_to_send: - agent_state_payloads.append({"blob": json.dumps(session_agent.to_dict())}) - - # Send all agent states in a single batched create_event call - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self.config.actor_id, - sessionId=self.config.session_id, - payload=agent_state_payloads, - eventTimestamp=self._get_monotonic_timestamp(), - metadata={ - STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value}, - }, + payload=group["payload"], + eventTimestamp=group["timestamp"], ) results.append(event) logger.debug( - "Flushed %d agent states in batched event: %s", len(agent_states_to_send), event.get("eventId") + "Flushed batched event for session %s with %d messages: %s", + session_id, + len(group["payload"]), + event.get("eventId"), ) - # Clear buffers only after ALL messages and agent state succeed + # Clear message buffer only after ALL messages succeed with self._message_lock: self._message_buffer.clear() + except Exception as e: + logger.error("Failed to flush messages to AgentCore Memory: %s", e) + raise SessionException(f"Failed to flush messages: {e}") from e + + logger.info("Flushed %d message events to AgentCore Memory", len(results)) + return results + + def _flush_agent_states_only(self) -> list[dict[str, Any]]: + """Flush only buffered agent states to AgentCore Memory. + + Call this method to send any remaining agent state when batch_size > 1. + This is called when the agent state buffer reaches batch_size. + All agent states are batched into a single create_event() call. + + Returns: + list[dict[str, Any]]: List of created event responses from AgentCore Memory. + + Raises: + SessionException: If agent state creation fails. On failure, agent states remain in the buffer. + """ + with self._agent_state_lock: + agent_states_to_send = list(self._agent_state_buffer) + + if not agent_states_to_send: + return [] + + results = [] + try: + # Convert all agent states to payload format + agent_state_payloads = [] + for _session_id, session_agent in agent_states_to_send: + agent_state_payloads.append({"blob": json.dumps(session_agent.to_dict())}) + + # Send all agent states in a single batched create_event call + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=self.config.session_id, + payload=agent_state_payloads, + eventTimestamp=self._get_monotonic_timestamp(), + metadata={ + STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value}, + }, + ) + results.append(event) + logger.debug( + "Flushed %d agent states in batched event: %s", len(agent_states_to_send), event.get("eventId") + ) + + # Clear agent state buffer only after success with self._agent_state_lock: self._agent_state_buffer.clear() except Exception as e: - logger.error("Failed to flush messages and agent state to AgentCore Memory: %s", e) - raise SessionException(f"Failed to flush messages and agent state: {e}") from e + logger.error("Failed to flush agent states to AgentCore Memory: %s", e) + raise SessionException(f"Failed to flush agent states: {e}") from e - logger.info("Flushed %d events to AgentCore Memory", len(results)) + logger.info("Flushed %d agent state events to AgentCore Memory", len(results)) + return results + + def _flush_messages(self) -> list[dict[str, Any]]: + """Flush all buffered messages and agent state to AgentCore Memory. + + Call this method to send any remaining buffered messages and agent state messages. + This is automatically called when the session is complete (via close() or context manager). + + Returns: + list[dict[str, Any]]: List of created event responses from AgentCore Memory. + + Raises: + SessionException: If any message or agent state creation fails. + """ + results = [] + results.extend(self._flush_messages_only()) + results.extend(self._flush_agent_states_only()) return results def pending_message_count(self) -> int: diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index c36cd7ff..b57fa0ae 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1361,7 +1361,7 @@ def test_pending_message_count_with_buffered_messages(self, batching_session_man def test_buffer_auto_flushes_at_batch_size(self, batching_session_manager, mock_memory_client): """Test buffer automatically flushes when reaching batch_size.""" - mock_memory_client.create_event.return_value = {"eventId": "event_123"} + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} # Add exactly batch_size messages (5) for i in range(5): @@ -1375,7 +1375,7 @@ def test_buffer_auto_flushes_at_batch_size(self, batching_session_manager, mock_ # Buffer should have been flushed assert batching_session_manager.pending_message_count() == 0 # One batched API call for all messages in the same session - assert mock_memory_client.create_event.call_count == 1 + assert mock_memory_client.gmdp_client.create_event.call_count == 1 def test_create_message_returns_empty_dict_when_buffered(self, batching_session_manager): """Test create_message returns empty dict when message is buffered.""" @@ -1408,12 +1408,12 @@ def test_pending_agent_state_count_with_buffered_states(self, batching_session_m agent.state["description"] = f"Updated description {i}" batching_session_manager.update_agent("test-session-456", agent) - # Should have 3 agent states in buffer (all updates preserved) - assert batching_session_manager.pending_agent_state_count() == 3 + # Should have 4 agent states in buffer (1 initial create + 3 updates) + assert batching_session_manager.pending_agent_state_count() == 4 # Verify no additional create_agent calls were made (still buffered) - assert mock_memory_client.gmdp_client.create_event.call_count == 1 # Only the initial create_agent + assert mock_memory_client.gmdp_client.create_event.call_count == 0 # All buffered, none flushed - def test_agent_state_buffer_keeps_latest_per_agent(self, batching_session_manager, mock_memory_client): + def test_agent_state_buffer_keeps_state_per_agent(self, batching_session_manager, mock_memory_client): """Test that agent state buffer preserves all agent state updates.""" # Create two agents agent1 = SessionAgent( @@ -1429,21 +1429,21 @@ def test_agent_state_buffer_keeps_latest_per_agent(self, batching_session_manage batching_session_manager.create_agent("test-session-456", agent1) batching_session_manager.create_agent("test-session-456", agent2) - # Update agent1 multiple times - for i in range(3): - agent1.state["description"] = f"Agent 1 update {i}" - batching_session_manager.update_agent("test-session-456", agent1) + # Update agent1 once + agent1.state["description"] = "Agent 1 update" + batching_session_manager.update_agent("test-session-456", agent1) # Update agent2 once agent2.state["description"] = "Agent 2 updated" batching_session_manager.update_agent("test-session-456", agent2) - # Should have 4 agent states in buffer (all updates preserved: 3 for agent1 + 1 for agent2) + # Total: 2 creates + 1 update + 1 update = 4 states in buffer (batch_size=5, so no auto-flush) + # Should have 4 agent states in buffer (all preserved: 2 creates + 1 for agent1 + 1 for agent2) assert batching_session_manager.pending_agent_state_count() == 4 def test_agent_state_flushed_with_messages(self, batching_session_manager, mock_memory_client): """Test that agent states are flushed along with messages.""" - mock_memory_client.create_event.return_value = {"eventId": "event_123"} + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} # Create agent agent = SessionAgent( @@ -1467,7 +1467,8 @@ def test_agent_state_flushed_with_messages(self, batching_session_manager, mock_ # Verify both are buffered assert batching_session_manager.pending_message_count() == 3 - assert batching_session_manager.pending_agent_state_count() == 1 + # Should have 2 agent states: 1 initial create + 1 update + assert batching_session_manager.pending_agent_state_count() == 2 # Flush batching_session_manager._flush_messages() @@ -1476,10 +1477,9 @@ def test_agent_state_flushed_with_messages(self, batching_session_manager, mock_ assert batching_session_manager.pending_message_count() == 0 assert batching_session_manager.pending_agent_state_count() == 0 - # Verify create_event was called for messages and agent state - # 1 initial create_agent + 1 batched message call + 1 agent state update - assert mock_memory_client.create_event.call_count == 1 # batched messages - assert mock_memory_client.gmdp_client.create_event.call_count == 2 # initial + update + # Verify create_event was called for messages and agent states + # 2 calls total: 1 for batched messages + 1 for batched agent states + assert mock_memory_client.gmdp_client.create_event.call_count == 2 def test_agent_state_preserved_on_flush_failure(self, batching_session_manager, mock_memory_client): """Test that agent states remain in buffer if flush fails.""" @@ -1495,7 +1495,8 @@ def test_agent_state_preserved_on_flush_failure(self, batching_session_manager, agent.state["description"] = "Updated" batching_session_manager.update_agent("test-session-456", agent) - assert batching_session_manager.pending_agent_state_count() == 1 + # Should have 2 states: 1 initial create + 1 update + assert batching_session_manager.pending_agent_state_count() == 2 # Make flush fail mock_memory_client.gmdp_client.create_event.side_effect = Exception("API Error") @@ -1504,8 +1505,8 @@ def test_agent_state_preserved_on_flush_failure(self, batching_session_manager, with pytest.raises(SessionException): batching_session_manager._flush_messages() - # Agent state should still be in buffer - assert batching_session_manager.pending_agent_state_count() == 1 + # Agent states should still be in buffer (2 states preserved) + assert batching_session_manager.pending_agent_state_count() == 2 class TestBatchingFlush: @@ -1518,7 +1519,7 @@ def test__flush_messages_empty_buffer(self, batching_session_manager): def test__flush_messages_sends_all_buffered(self, batching_session_manager, mock_memory_client): """Test _flush_messages sends all buffered messages in a single batched call.""" - mock_memory_client.create_event.return_value = {"eventId": "event_123"} + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} # Add 3 messages (below batch_size of 10) for i in range(3): @@ -1537,17 +1538,17 @@ def test__flush_messages_sends_all_buffered(self, batching_session_manager, mock # One batched API call for all messages in the same session assert len(results) == 1 assert batching_session_manager.pending_message_count() == 0 - assert mock_memory_client.create_event.call_count == 1 + assert mock_memory_client.gmdp_client.create_event.call_count == 1 def test__flush_messages_maintains_order(self, batching_session_manager, mock_memory_client): """Test _flush_messages maintains message order within batched payload.""" sent_payloads = [] def track_create_event(**kwargs): - sent_payloads.append(kwargs.get("messages")) + sent_payloads.append(kwargs.get("payload")) return {"eventId": f"event_{len(sent_payloads)}"} - mock_memory_client.create_event.side_effect = track_create_event + mock_memory_client.gmdp_client.create_event.side_effect = track_create_event # Add messages with distinct content for i in range(3): @@ -1562,14 +1563,14 @@ def track_create_event(**kwargs): # Should be one batched call with messages in order assert len(sent_payloads) == 1 - combined_messages = sent_payloads[0] - assert len(combined_messages) == 3 - for i, msg in enumerate(combined_messages): - assert f"Message_{i}" in msg[0] + combined_payload = sent_payloads[0] + assert len(combined_payload) == 3 + for i, item in enumerate(combined_payload): + assert f"Message_{i}" in item["conversational"]["content"]["text"] def test__flush_messages_clears_buffer(self, batching_session_manager, mock_memory_client): """Test _flush_messages clears the buffer after sending.""" - mock_memory_client.create_event.return_value = {"eventId": "event_123"} + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} message = SessionMessage( message={"role": "user", "content": [{"text": "Hello"}]}, @@ -1588,7 +1589,7 @@ def test__flush_messages_clears_buffer(self, batching_session_manager, mock_memo def test__flush_messages_exception_handling(self, batching_session_manager, mock_memory_client): """Test _flush_messages raises SessionException on failure.""" - mock_memory_client.create_event.side_effect = Exception("API Error") + mock_memory_client.gmdp_client.create_event.side_effect = Exception("API Error") message = SessionMessage( message={"role": "user", "content": [{"text": "Hello"}]}, @@ -1600,9 +1601,11 @@ def test__flush_messages_exception_handling(self, batching_session_manager, mock with pytest.raises(SessionException, match="Failed to flush messages"): batching_session_manager._flush_messages() - def test_partial_flush_failure_preserves_all_messages(self, batching_session_manager, mock_memory_client): + def test__flush_messages_partial_flush_failure_preserves_all_messages( + self, batching_session_manager, mock_memory_client + ): """Test that on flush failure, all messages remain in buffer to prevent data loss.""" - mock_memory_client.create_event.side_effect = Exception("API Error") + mock_memory_client.gmdp_client.create_event.side_effect = Exception("API Error") # Add multiple messages for i in range(3): @@ -1623,22 +1626,24 @@ def test_partial_flush_failure_preserves_all_messages(self, batching_session_man assert batching_session_manager.pending_message_count() == 3 # Fix the mock and retry - should succeed now - mock_memory_client.create_event.side_effect = None - mock_memory_client.create_event.return_value = {"eventId": "event_123"} + mock_memory_client.gmdp_client.create_event.side_effect = None + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} results = batching_session_manager._flush_messages() assert len(results) == 1 # One batched call for all messages assert batching_session_manager.pending_message_count() == 0 - def test_batching_combines_messages_for_same_session(self, batching_session_manager, mock_memory_client): + def test__flush_messages_batching_combines_messages_for_same_session( + self, batching_session_manager, mock_memory_client + ): """Test that multiple messages for the same session are combined into one API call.""" sent_payloads = [] def track_create_event(**kwargs): - sent_payloads.append(kwargs.get("messages")) + sent_payloads.append(kwargs.get("payload")) return {"eventId": f"event_{len(sent_payloads)}"} - mock_memory_client.create_event.side_effect = track_create_event + mock_memory_client.gmdp_client.create_event.side_effect = track_create_event # Add 5 messages to the same session for i in range(5): @@ -1652,15 +1657,17 @@ def track_create_event(**kwargs): batching_session_manager._flush_messages() # Should be ONE API call with all 5 messages combined - assert mock_memory_client.create_event.call_count == 1 + assert mock_memory_client.gmdp_client.create_event.call_count == 1 assert len(sent_payloads) == 1 # The combined payload should have all 5 messages assert len(sent_payloads[0]) == 5 # Messages should be in order for i in range(5): - assert f"Message_{i}" in sent_payloads[0][i][0] + assert f"Message_{i}" in sent_payloads[0][i]["conversational"]["content"]["text"] - def test_multiple_sessions_grouped_into_separate_api_calls(self, batching_session_manager, mock_memory_client): + def test__flush_messages_multiple_sessions_grouped_into_separate_api_calls( + self, batching_session_manager, mock_memory_client + ): """Test that messages to different sessions are grouped into separate API calls. Note: In normal usage, create_message enforces session_id == config.session_id, @@ -1672,12 +1679,12 @@ def test_multiple_sessions_grouped_into_separate_api_calls(self, batching_sessio calls_by_session = {} def track_create_event(**kwargs): - session_id = kwargs.get("session_id") - messages = kwargs.get("messages") - calls_by_session[session_id] = messages + session_id = kwargs.get("sessionId") + payload = kwargs.get("payload") + calls_by_session[session_id] = payload return {"eventId": f"event_{session_id}"} - mock_memory_client.create_event.side_effect = track_create_event + mock_memory_client.gmdp_client.create_event.side_effect = track_create_event # Directly populate buffer with messages for multiple sessions # Buffer format: (session_id, messages, is_blob, monotonic_timestamp) @@ -1694,31 +1701,33 @@ def track_create_event(**kwargs): batching_session_manager._flush_messages() # Should be TWO API calls - one per session - assert mock_memory_client.create_event.call_count == 2 + assert mock_memory_client.gmdp_client.create_event.call_count == 2 assert len(calls_by_session) == 2 # Session A should have 3 messages combined assert "session-A" in calls_by_session assert len(calls_by_session["session-A"]) == 3 - assert calls_by_session["session-A"][0] == ("SessionA_Message_0", "user") - assert calls_by_session["session-A"][1] == ("SessionA_Message_1", "user") - assert calls_by_session["session-A"][2] == ("SessionA_Message_2", "user") + assert calls_by_session["session-A"][0]["conversational"]["content"]["text"] == "SessionA_Message_0" + assert calls_by_session["session-A"][1]["conversational"]["content"]["text"] == "SessionA_Message_1" + assert calls_by_session["session-A"][2]["conversational"]["content"]["text"] == "SessionA_Message_2" # Session B should have 3 messages combined assert "session-B" in calls_by_session assert len(calls_by_session["session-B"]) == 3 for i in range(3): - assert calls_by_session["session-B"][i] == (f"SessionB_Message_{i}", "user") + assert calls_by_session["session-B"][i]["conversational"]["content"]["text"] == f"SessionB_Message_{i}" - def test_latest_timestamp_used_for_combined_events(self, batching_session_manager, mock_memory_client): + def test__flush_messages_latest_timestamp_used_for_combined_events( + self, batching_session_manager, mock_memory_client + ): """Test that the latest timestamp from grouped messages is used for the combined event.""" captured_timestamps = [] def track_create_event(**kwargs): - captured_timestamps.append(kwargs.get("event_timestamp")) + captured_timestamps.append(kwargs.get("eventTimestamp")) return {"eventId": "event_123"} - mock_memory_client.create_event.side_effect = track_create_event + mock_memory_client.gmdp_client.create_event.side_effect = track_create_event # Add messages with different timestamps (out of order) timestamps = ["2024-01-01T12:05:00Z", "2024-01-01T12:01:00Z", "2024-01-01T12:10:00Z"] @@ -1741,7 +1750,9 @@ def track_create_event(**kwargs): # Account for monotonic timestamp adjustment (may add microseconds) assert captured_timestamps[0] >= expected_latest - def test_partial_failure_multiple_sessions_preserves_buffer(self, batching_session_manager, mock_memory_client): + def test__flush_messages_partial_failure_multiple_sessions_preserves_buffer( + self, batching_session_manager, mock_memory_client + ): """Test that when one session fails, ALL messages remain in buffer. Note: Tests internal grouping logic by directly manipulating buffer. @@ -1749,12 +1760,12 @@ def test_partial_failure_multiple_sessions_preserves_buffer(self, batching_sessi from datetime import datetime, timezone def fail_on_second_session(**kwargs): - session_id = kwargs.get("session_id") + session_id = kwargs.get("sessionId") if session_id == "session-B": raise Exception("API Error for session B") return {"eventId": f"event_{session_id}"} - mock_memory_client.create_event.side_effect = fail_on_second_session + mock_memory_client.gmdp_client.create_event.side_effect = fail_on_second_session # Directly populate buffer with messages for multiple sessions base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) @@ -1775,8 +1786,8 @@ def fail_on_second_session(**kwargs): # This is because buffer is only cleared after ALL succeed assert batching_session_manager.pending_message_count() == 4 - def test_blob_messages_sent_individually_not_batched(self, batching_session_manager, mock_memory_client): - """Test that multiple blob messages are sent as individual API calls, not batched.""" + def test_blob_messages_sent_batched(self, batching_session_manager, mock_memory_client): + """Test that multiple blob messages are sent as batched.""" blob_calls = [] def track_blob_event(**kwargs): @@ -1784,7 +1795,6 @@ def track_blob_event(**kwargs): return {"event": {"eventId": f"blob_event_{len(blob_calls)}"}} mock_memory_client.gmdp_client.create_event.side_effect = track_blob_event - mock_memory_client.create_event.return_value = {"eventId": "conv_event"} # Add multiple blob messages (>9KB each) for i in range(3): @@ -1798,15 +1808,17 @@ def track_blob_event(**kwargs): batching_session_manager._flush_messages() - # Each blob should be sent individually (3 separate API calls) - assert mock_memory_client.gmdp_client.create_event.call_count == 3 - assert len(blob_calls) == 3 + # Blobs are now batched together in one call with multiple payloads + assert mock_memory_client.gmdp_client.create_event.call_count == 1 + assert len(blob_calls) == 1 - # Verify each blob was sent separately with correct content - for i, call in enumerate(blob_calls): - assert "payload" in call - assert "blob" in call["payload"][0] - assert f"blob_{i}_" in call["payload"][0]["blob"] + # Verify the batched call contains all 3 blobs + call = blob_calls[0] + assert "payload" in call + assert len(call["payload"]) == 3 + for i in range(3): + assert "blob" in call["payload"][i] + assert f"blob_{i}_" in call["payload"][i]["blob"] def test_mixed_sessions_with_blobs_and_conversational(self, batching_session_manager, mock_memory_client): """Test complex scenario: multiple sessions with both blob and conversational messages. @@ -1815,20 +1827,15 @@ def test_mixed_sessions_with_blobs_and_conversational(self, batching_session_man """ from datetime import datetime, timezone - conv_calls = {} - blob_calls = [] - - def track_conv_event(**kwargs): - session_id = kwargs.get("session_id") - conv_calls[session_id] = kwargs.get("messages") - return {"eventId": f"conv_event_{session_id}"} + calls_by_session = {} - def track_blob_event(**kwargs): - blob_calls.append(kwargs) - return {"event": {"eventId": f"blob_event_{len(blob_calls)}"}} + def track_create_event(**kwargs): + session_id = kwargs.get("sessionId") + payload = kwargs.get("payload") + calls_by_session[session_id] = payload + return {"eventId": f"event_{session_id}"} - mock_memory_client.create_event.side_effect = track_conv_event - mock_memory_client.gmdp_client.create_event.side_effect = track_blob_event + mock_memory_client.gmdp_client.create_event.side_effect = track_create_event # Directly populate buffer with mixed messages base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) @@ -1845,23 +1852,278 @@ def track_blob_event(**kwargs): batching_session_manager._flush_messages() - # Should have: - # - 2 conversational API calls (one per session) - # - 1 blob API call - assert mock_memory_client.create_event.call_count == 2 + # Should have 2 gmdp_client.create_event calls (one per session) + # Each session combines conversational and blob messages + assert mock_memory_client.gmdp_client.create_event.call_count == 2 + + # Session A should have 3 items in payload (2 conversational + 1 blob) + assert "session-A" in calls_by_session + assert len(calls_by_session["session-A"]) == 3 + + # Session B should have 1 conversational message + assert "session-B" in calls_by_session + assert len(calls_by_session["session-B"]) == 1 + + def test__flush_messages_calls_both_flush_methods(self, batching_session_manager, mock_memory_client): + """Test that _flush_messages() calls both _flush_messages_only() and _flush_agent_states_only().""" + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} + + # Add messages + for i in range(2): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + # Add agent state + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Test agent"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent) + + # Verify both buffers have content + assert batching_session_manager.pending_message_count() == 2 + assert batching_session_manager.pending_agent_state_count() == 1 + + # Flush all + results = batching_session_manager._flush_messages() + + # Should have 2 API calls: 1 for messages + 1 for agent states + assert mock_memory_client.gmdp_client.create_event.call_count == 2 + assert len(results) == 2 + + # Both buffers should be cleared + assert batching_session_manager.pending_message_count() == 0 + assert batching_session_manager.pending_agent_state_count() == 0 + + def test__flush_messages_with_only_messages(self, batching_session_manager, mock_memory_client): + """Test that _flush_messages() works when only messages are buffered.""" + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} + + # Add only messages (no agent states) + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert batching_session_manager.pending_message_count() == 3 + assert batching_session_manager.pending_agent_state_count() == 0 + + # Flush all + results = batching_session_manager._flush_messages() + + # Should have 1 API call for messages only + assert mock_memory_client.gmdp_client.create_event.call_count == 1 + assert len(results) == 1 + assert batching_session_manager.pending_message_count() == 0 + + def test__flush_messages_with_only_agent_states(self, batching_session_manager, mock_memory_client): + """Test that _flush_messages() works when only agent states are buffered.""" + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} + + # Add only agent states (no messages) + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Test agent"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent) + + agent.state["description"] = "Updated" + batching_session_manager.update_agent("test-session-456", agent) + + assert batching_session_manager.pending_message_count() == 0 + assert batching_session_manager.pending_agent_state_count() == 2 + + # Flush all + results = batching_session_manager._flush_messages() + + # Should have 1 API call for agent states only assert mock_memory_client.gmdp_client.create_event.call_count == 1 + assert len(results) == 1 + assert batching_session_manager.pending_agent_state_count() == 0 - # Session A conversational messages should be batched together - assert "session-A" in conv_calls - assert len(conv_calls["session-A"]) == 2 + def test__flush_agent_states_only_empty_buffer(self, batching_session_manager): + """Test _flush_agent_states_only with empty buffer returns empty list.""" + results = batching_session_manager._flush_agent_states_only() + assert results == [] - # Session B conversational message - assert "session-B" in conv_calls - assert len(conv_calls["session-B"]) == 1 + def test__flush_agent_states_only_sends_all_buffered(self, batching_session_manager, mock_memory_client): + """Test _flush_agent_states_only sends all buffered agent states in a single batched call.""" + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} - # Blob sent separately - assert len(blob_calls) == 1 - assert "blob_A_" in blob_calls[0]["payload"][0]["blob"] + # Create agent + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Initial"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent) + + # Update agent state twice + for i in range(2): + agent.state["description"] = f"Updated {i}" + batching_session_manager.update_agent("test-session-456", agent) + + # Should have 3 agent states: 1 create + 2 updates + assert batching_session_manager.pending_agent_state_count() == 3 + + # Flush agent states only + results = batching_session_manager._flush_agent_states_only() + + # One batched API call for all agent states + assert len(results) == 1 + assert batching_session_manager.pending_agent_state_count() == 0 + assert mock_memory_client.gmdp_client.create_event.call_count == 1 + + # Verify the call had metadata for agent state + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] + assert "metadata" in call_kwargs + assert "stateType" in call_kwargs["metadata"] + + def test__flush_agent_states_only_preserves_messages(self, batching_session_manager, mock_memory_client): + """Test _flush_agent_states_only preserves message buffer.""" + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} + + # Add messages + for i in range(2): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + # Add agent state + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Test agent"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent) + + assert batching_session_manager.pending_message_count() == 2 + assert batching_session_manager.pending_agent_state_count() == 1 + + # Flush only agent states + batching_session_manager._flush_agent_states_only() + + # Agent states should be flushed, messages should remain + assert batching_session_manager.pending_agent_state_count() == 0 + assert batching_session_manager.pending_message_count() == 2 + assert mock_memory_client.gmdp_client.create_event.call_count == 1 + + def test__flush_agent_states_only_clears_buffer(self, batching_session_manager, mock_memory_client): + """Test _flush_agent_states_only clears the agent state buffer after sending.""" + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} + + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Test agent"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent) + + # First flush + batching_session_manager._flush_agent_states_only() + assert batching_session_manager.pending_agent_state_count() == 0 + + # Second flush should be no-op + results = batching_session_manager._flush_agent_states_only() + assert results == [] + + def test__flush_agent_states_only_exception_handling(self, batching_session_manager, mock_memory_client): + """Test _flush_agent_states_only raises SessionException on failure.""" + mock_memory_client.gmdp_client.create_event.side_effect = Exception("API Error") + + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Test agent"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent) + + with pytest.raises(SessionException, match="Failed to flush agent states"): + batching_session_manager._flush_agent_states_only() + + def test__flush_agent_states_only_failure_preserves_agent_states( + self, batching_session_manager, mock_memory_client + ): + """Test that on flush failure, all agent states remain in buffer to prevent data loss.""" + mock_memory_client.gmdp_client.create_event.side_effect = Exception("API Error") + + # Create agent and update twice + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Initial"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent) + + agent.state["description"] = "Updated 1" + batching_session_manager.update_agent("test-session-456", agent) + + agent.state["description"] = "Updated 2" + batching_session_manager.update_agent("test-session-456", agent) + + assert batching_session_manager.pending_agent_state_count() == 3 + + # Flush should fail + with pytest.raises(SessionException): + batching_session_manager._flush_agent_states_only() + + # All agent states should still be in buffer (not cleared on failure) + assert batching_session_manager.pending_agent_state_count() == 3 + + # Fix the mock and retry - should succeed now + mock_memory_client.gmdp_client.create_event.side_effect = None + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} + + results = batching_session_manager._flush_agent_states_only() + assert len(results) == 1 + assert batching_session_manager.pending_agent_state_count() == 0 + + def test__flush_agent_states_only_batches_multiple_states(self, batching_session_manager, mock_memory_client): + """Test that multiple agent states are batched into a single API call.""" + sent_payloads = [] + + def track_create_event(**kwargs): + sent_payloads.append(kwargs.get("payload")) + return {"eventId": f"event_{len(sent_payloads)}"} + + mock_memory_client.gmdp_client.create_event.side_effect = track_create_event + + # Create agent and update 4 times + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Initial"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent) + + for i in range(4): + agent.state["description"] = f"Updated {i}" + batching_session_manager.update_agent("test-session-456", agent) + + # Should have 5 agent states: 1 create + 4 updates + assert batching_session_manager.pending_agent_state_count() == 5 + + batching_session_manager._flush_agent_states_only() + + # Should be ONE API call with all 5 agent states combined + assert mock_memory_client.gmdp_client.create_event.call_count == 1 + assert len(sent_payloads) == 1 + # The combined payload should have all 5 agent states as blobs + assert len(sent_payloads[0]) == 5 + for item in sent_payloads[0]: + assert "blob" in item class TestBatchingBackwardsCompatibility: @@ -1912,7 +2174,7 @@ def test_context_manager_returns_self(self, batching_session_manager): def test_context_manager_flushes_on_exit(self, batching_session_manager, mock_memory_client): """Test __exit__ flushes pending messages.""" - mock_memory_client.create_event.return_value = {"eventId": "event_123"} + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} with batching_session_manager: message = SessionMessage( @@ -1927,11 +2189,11 @@ def test_context_manager_flushes_on_exit(self, batching_session_manager, mock_me # After exiting context, should have flushed assert batching_session_manager.pending_message_count() == 0 - mock_memory_client.create_event.assert_called_once() + mock_memory_client.gmdp_client.create_event.assert_called_once() def test_context_manager_flushes_on_exception(self, batching_session_manager, mock_memory_client): """Test __exit__ flushes even when exception occurs.""" - mock_memory_client.create_event.return_value = {"eventId": "event_123"} + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} try: with batching_session_manager: @@ -1947,13 +2209,13 @@ def test_context_manager_flushes_on_exception(self, batching_session_manager, mo # Should have flushed despite exception assert batching_session_manager.pending_message_count() == 0 - mock_memory_client.create_event.assert_called_once() + mock_memory_client.gmdp_client.create_event.assert_called_once() def test_exit_preserves_original_exception_when_flush_fails( self, batching_session_manager, mock_memory_client, caplog ): """Test __exit__ logs flush failure and preserves the original exception.""" - mock_memory_client.create_event.side_effect = RuntimeError("flush failed") + mock_memory_client.gmdp_client.create_event.side_effect = RuntimeError("flush failed") with caplog.at_level(logging.ERROR): with pytest.raises(ValueError, match="original error"): @@ -1975,7 +2237,7 @@ def test_exit_raises_flush_exception_when_no_original_exception( self, batching_session_manager, mock_memory_client, caplog ): """Test __exit__ still raises flush exceptions when no original exception.""" - mock_memory_client.create_event.side_effect = RuntimeError("flush failed") + mock_memory_client.gmdp_client.create_event.side_effect = RuntimeError("flush failed") with caplog.at_level(logging.ERROR): with pytest.raises(SessionException, match="flush failed"): @@ -1997,7 +2259,7 @@ class TestBatchingClose: def test_close_flushes_pending_messages(self, batching_session_manager, mock_memory_client): """Test close() flushes all pending messages in a batched call.""" - mock_memory_client.create_event.return_value = {"eventId": "event_123"} + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"} # Add messages for i in range(3): @@ -2015,7 +2277,7 @@ def test_close_flushes_pending_messages(self, batching_session_manager, mock_mem assert batching_session_manager.pending_message_count() == 0 # One batched API call for all messages in the same session - assert mock_memory_client.create_event.call_count == 1 + assert mock_memory_client.gmdp_client.create_event.call_count == 1 def test_close_with_empty_buffer(self, batching_session_manager, mock_memory_client): """Test close() with empty buffer is a no-op.""" @@ -2051,8 +2313,7 @@ def test_blob_message_sent_via_gmdp_client(self, batching_session_manager, mock_ def test_mixed_conversational_and_blob_messages(self, batching_session_manager, mock_memory_client): """Test batching correctly handles mix of conversational and blob messages.""" - mock_memory_client.create_event.return_value = {"eventId": "conv_event"} - mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "blob_event"}} + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "conv_event"} # Add small (conversational) message small_message = SessionMessage( @@ -2074,9 +2335,8 @@ def test_mixed_conversational_and_blob_messages(self, batching_session_manager, # Flush batching_session_manager._flush_messages() - # Verify both paths were used - assert mock_memory_client.create_event.call_count == 1 # Conversational - assert mock_memory_client.gmdp_client.create_event.call_count == 1 # Blob + # Both messages should be sent via gmdp_client.create_event (batched together) + assert mock_memory_client.gmdp_client.create_event.call_count == 1 class TestThinkingModeCompatibility: