Skip to content

Commit 9e865da

Browse files
authored
feat: add buffering for agent state events (#295)
1 parent bbc00a7 commit 9e865da

2 files changed

Lines changed: 313 additions & 23 deletions

File tree

src/bedrock_agentcore/memory/integrations/strands/session_manager.py

Lines changed: 75 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ def __init__(
120120

121121
# Batching support - stores pre-processed messages: (session_id, messages, is_blob, timestamp)
122122
self._message_buffer: list[tuple[str, list[tuple[str, str]], bool, datetime]] = []
123-
self._buffer_lock = threading.Lock()
123+
self._message_lock = threading.Lock()
124+
125+
# Agent state buffering - stores all agent state updates: (session_id, agent)
126+
self._agent_state_buffer: list[tuple[str, SessionAgent]] = []
127+
self._agent_state_lock = threading.Lock()
124128

125129
# Cache for agent created_at timestamps to avoid fetching on every update
126130
self._agent_created_at_cache: dict[str, datetime] = {}
@@ -397,8 +401,14 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A
397401
raise SessionException(f"Agent {agent_id} in session {session_id} does not exist")
398402
session_agent.created_at = self._agent_created_at_cache[agent_id]
399403

400-
# Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent`
401-
self.create_agent(session_id, session_agent)
404+
if self.config.batch_size > 1:
405+
# Buffer the agent state update
406+
with self._agent_state_lock:
407+
self._agent_state_buffer.append((session_id, session_agent))
408+
else:
409+
# Immediate send create_event without buffering
410+
# Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent`
411+
self.create_agent(session_id, session_agent)
402412

403413
def create_message(
404414
self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any
@@ -452,7 +462,7 @@ def create_message(
452462
if self.config.batch_size > 1:
453463
# Buffer the pre-processed message
454464
should_flush = False
455-
with self._buffer_lock:
465+
with self._message_lock:
456466
self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp))
457467
should_flush = len(self._message_buffer) >= self.config.batch_size
458468

@@ -702,27 +712,31 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
702712
# region Batching support
703713

704714
def _flush_messages(self) -> list[dict[str, Any]]:
705-
"""Flush all buffered messages to AgentCore Memory.
715+
"""Flush all buffered messages and agent state to AgentCore Memory.
706716
707-
Call this method to send any remaining buffered messages when batch_size > 1.
717+
Call this method to send any remaining buffered messages and agent state when batch_size > 1.
708718
This is automatically called when the buffer reaches batch_size, but should
709719
also be called explicitly when the session is complete (via close() or context manager).
710720
711721
Messages are batched by session_id - all conversational messages for the same
712722
session are combined into a single create_event() call to reduce API calls.
713723
Blob messages (>9KB) are sent individually as they require a different API path.
724+
Agent state updates are sent after messages.
714725
715726
Returns:
716727
list[dict[str, Any]]: List of created event responses from AgentCore Memory.
717728
718729
Raises:
719-
SessionException: If any message creation fails. On failure, all messages
720-
remain in the buffer to prevent data loss.
730+
SessionException: If any message or agent state creation fails. On failure, all messages
731+
and agent state remain in the buffer to prevent data loss.
721732
"""
722-
with self._buffer_lock:
733+
with self._message_lock:
723734
messages_to_send = list(self._message_buffer)
724735

725-
if not messages_to_send:
736+
with self._agent_state_lock:
737+
agent_states_to_send = list(self._agent_state_buffer)
738+
739+
if not messages_to_send and not agent_states_to_send:
726740
return []
727741

728742
# Group conversational messages by session_id, preserve order
@@ -772,13 +786,39 @@ def _flush_messages(self) -> list[dict[str, Any]]:
772786
results.append(event)
773787
logger.debug("Flushed blob event for session %s: %s", session_id, event.get("eventId"))
774788

775-
# Clear buffer only after ALL messages succeed
776-
with self._buffer_lock:
789+
# Flush agent state updates after messages - batch all agent states into a single API call
790+
if agent_states_to_send:
791+
# Convert all agent states to payload format
792+
agent_state_payloads = []
793+
for _session_id, session_agent in agent_states_to_send:
794+
agent_state_payloads.append({"blob": json.dumps(session_agent.to_dict())})
795+
796+
# Send all agent states in a single batched create_event call
797+
event = self.memory_client.gmdp_client.create_event(
798+
memoryId=self.config.memory_id,
799+
actorId=self.config.actor_id,
800+
sessionId=self.config.session_id,
801+
payload=agent_state_payloads,
802+
eventTimestamp=self._get_monotonic_timestamp(),
803+
metadata={
804+
STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value},
805+
},
806+
)
807+
results.append(event)
808+
logger.debug(
809+
"Flushed %d agent states in batched event: %s", len(agent_states_to_send), event.get("eventId")
810+
)
811+
812+
# Clear buffers only after ALL messages and agent state succeed
813+
with self._message_lock:
777814
self._message_buffer.clear()
778815

816+
with self._agent_state_lock:
817+
self._agent_state_buffer.clear()
818+
779819
except Exception as e:
780-
logger.error("Failed to flush messages to AgentCore Memory for session: %s", e)
781-
raise SessionException(f"Failed to flush messages: {e}") from e
820+
logger.error("Failed to flush messages and agent state to AgentCore Memory: %s", e)
821+
raise SessionException(f"Failed to flush messages and agent state: {e}") from e
782822

783823
logger.info("Flushed %d events to AgentCore Memory", len(results))
784824
return results
@@ -789,9 +829,18 @@ def pending_message_count(self) -> int:
789829
Returns:
790830
int: Number of buffered messages waiting to be sent.
791831
"""
792-
with self._buffer_lock:
832+
with self._message_lock:
793833
return len(self._message_buffer)
794834

835+
def pending_agent_state_count(self) -> int:
836+
"""Return the number of agent states pending in the buffer.
837+
838+
Returns:
839+
int: Number of buffered agent states waiting to be sent.
840+
"""
841+
with self._agent_state_lock:
842+
return len(self._agent_state_buffer)
843+
795844
def close(self) -> None:
796845
"""Explicitly flush pending messages and close the session manager.
797846
@@ -860,16 +909,21 @@ def _start_flush_timer(self) -> None:
860909
def _interval_flush_callback(self) -> None:
861910
"""Callback executed by the flush timer.
862911
863-
Flushes the buffer if it contains messages, then reschedules the timer.
912+
Flushes the buffer if it contains messages or agent states, then reschedules the timer.
864913
"""
865914
try:
866-
# Only flush if there are messages in the buffer
867-
pending = self.pending_message_count()
868-
if pending > 0:
869-
logger.debug("Interval flush triggered: %d message(s) pending", pending)
915+
# Only flush if there are messages or agent states in the buffer
916+
pending_messages = self.pending_message_count()
917+
pending_agent_states = self.pending_agent_state_count()
918+
if pending_messages > 0 or pending_agent_states > 0:
919+
logger.debug(
920+
"Interval flush triggered: %d message(s) and %d agent state(s) pending",
921+
pending_messages,
922+
pending_agent_states,
923+
)
870924
self._flush_messages()
871925
else:
872-
logger.debug("Interval flush skipped: buffer is empty")
926+
logger.debug("Interval flush skipped: buffers are empty")
873927

874928
# Reschedule the timer (unless shutdown)
875929
if not self._shutdown and self.config.flush_interval_seconds:

0 commit comments

Comments
 (0)