@@ -120,7 +120,11 @@ def __init__(
120120
121121 # Batching support - stores pre-processed messages: (session_id, messages, is_blob, timestamp)
122122 self ._message_buffer : list [tuple [str , list [tuple [str , str ]], bool , datetime ]] = []
123- self ._buffer_lock = threading .Lock ()
123+ self ._message_lock = threading .Lock ()
124+
125+ # Agent state buffering - stores all agent state updates: (session_id, agent)
126+ self ._agent_state_buffer : list [tuple [str , SessionAgent ]] = []
127+ self ._agent_state_lock = threading .Lock ()
124128
125129 # Cache for agent created_at timestamps to avoid fetching on every update
126130 self ._agent_created_at_cache : dict [str , datetime ] = {}
@@ -397,8 +401,14 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A
397401 raise SessionException (f"Agent { agent_id } in session { session_id } does not exist" )
398402 session_agent .created_at = self ._agent_created_at_cache [agent_id ]
399403
400- # Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent`
401- self .create_agent (session_id , session_agent )
404+ if self .config .batch_size > 1 :
405+ # Buffer the agent state update
406+ with self ._agent_state_lock :
407+ self ._agent_state_buffer .append ((session_id , session_agent ))
408+ else :
409+ # Immediate send create_event without buffering
410+ # Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent`
411+ self .create_agent (session_id , session_agent )
402412
403413 def create_message (
404414 self , session_id : str , agent_id : str , session_message : SessionMessage , ** kwargs : Any
@@ -452,7 +462,7 @@ def create_message(
452462 if self .config .batch_size > 1 :
453463 # Buffer the pre-processed message
454464 should_flush = False
455- with self ._buffer_lock :
465+ with self ._message_lock :
456466 self ._message_buffer .append ((session_id , messages , is_blob , monotonic_timestamp ))
457467 should_flush = len (self ._message_buffer ) >= self .config .batch_size
458468
@@ -702,27 +712,31 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
702712 # region Batching support
703713
704714 def _flush_messages (self ) -> list [dict [str , Any ]]:
705- """Flush all buffered messages to AgentCore Memory.
715+ """Flush all buffered messages and agent state to AgentCore Memory.
706716
707- Call this method to send any remaining buffered messages when batch_size > 1.
717+ Call this method to send any remaining buffered messages and agent state when batch_size > 1.
708718 This is automatically called when the buffer reaches batch_size, but should
709719 also be called explicitly when the session is complete (via close() or context manager).
710720
711721 Messages are batched by session_id - all conversational messages for the same
712722 session are combined into a single create_event() call to reduce API calls.
713723 Blob messages (>9KB) are sent individually as they require a different API path.
724+ Agent state updates are sent after messages.
714725
715726 Returns:
716727 list[dict[str, Any]]: List of created event responses from AgentCore Memory.
717728
718729 Raises:
719- SessionException: If any message creation fails. On failure, all messages
720- remain in the buffer to prevent data loss.
730+ SessionException: If any message or agent state creation fails. On failure, all messages
731+ and agent state remain in the buffer to prevent data loss.
721732 """
722- with self ._buffer_lock :
733+ with self ._message_lock :
723734 messages_to_send = list (self ._message_buffer )
724735
725- if not messages_to_send :
736+ with self ._agent_state_lock :
737+ agent_states_to_send = list (self ._agent_state_buffer )
738+
739+ if not messages_to_send and not agent_states_to_send :
726740 return []
727741
728742 # Group conversational messages by session_id, preserve order
@@ -772,13 +786,39 @@ def _flush_messages(self) -> list[dict[str, Any]]:
772786 results .append (event )
773787 logger .debug ("Flushed blob event for session %s: %s" , session_id , event .get ("eventId" ))
774788
775- # Clear buffer only after ALL messages succeed
776- with self ._buffer_lock :
789+ # Flush agent state updates after messages - batch all agent states into a single API call
790+ if agent_states_to_send :
791+ # Convert all agent states to payload format
792+ agent_state_payloads = []
793+ for _session_id , session_agent in agent_states_to_send :
794+ agent_state_payloads .append ({"blob" : json .dumps (session_agent .to_dict ())})
795+
796+ # Send all agent states in a single batched create_event call
797+ event = self .memory_client .gmdp_client .create_event (
798+ memoryId = self .config .memory_id ,
799+ actorId = self .config .actor_id ,
800+ sessionId = self .config .session_id ,
801+ payload = agent_state_payloads ,
802+ eventTimestamp = self ._get_monotonic_timestamp (),
803+ metadata = {
804+ STATE_TYPE_KEY : {"stringValue" : StateType .AGENT .value },
805+ },
806+ )
807+ results .append (event )
808+ logger .debug (
809+ "Flushed %d agent states in batched event: %s" , len (agent_states_to_send ), event .get ("eventId" )
810+ )
811+
812+ # Clear buffers only after ALL messages and agent state succeed
813+ with self ._message_lock :
777814 self ._message_buffer .clear ()
778815
816+ with self ._agent_state_lock :
817+ self ._agent_state_buffer .clear ()
818+
779819 except Exception as e :
780- logger .error ("Failed to flush messages to AgentCore Memory for session : %s" , e )
781- raise SessionException (f"Failed to flush messages: { e } " ) from e
820+ logger .error ("Failed to flush messages and agent state to AgentCore Memory: %s" , e )
821+ raise SessionException (f"Failed to flush messages and agent state : { e } " ) from e
782822
783823 logger .info ("Flushed %d events to AgentCore Memory" , len (results ))
784824 return results
@@ -789,9 +829,18 @@ def pending_message_count(self) -> int:
789829 Returns:
790830 int: Number of buffered messages waiting to be sent.
791831 """
792- with self ._buffer_lock :
832+ with self ._message_lock :
793833 return len (self ._message_buffer )
794834
835+ def pending_agent_state_count (self ) -> int :
836+ """Return the number of agent states pending in the buffer.
837+
838+ Returns:
839+ int: Number of buffered agent states waiting to be sent.
840+ """
841+ with self ._agent_state_lock :
842+ return len (self ._agent_state_buffer )
843+
795844 def close (self ) -> None :
796845 """Explicitly flush pending messages and close the session manager.
797846
@@ -860,16 +909,21 @@ def _start_flush_timer(self) -> None:
860909 def _interval_flush_callback (self ) -> None :
861910 """Callback executed by the flush timer.
862911
863- Flushes the buffer if it contains messages, then reschedules the timer.
912+ Flushes the buffer if it contains messages or agent states , then reschedules the timer.
864913 """
865914 try :
866- # Only flush if there are messages in the buffer
867- pending = self .pending_message_count ()
868- if pending > 0 :
869- logger .debug ("Interval flush triggered: %d message(s) pending" , pending )
915+ # Only flush if there are messages or agent states in the buffer
916+ pending_messages = self .pending_message_count ()
917+ pending_agent_states = self .pending_agent_state_count ()
918+ if pending_messages > 0 or pending_agent_states > 0 :
919+ logger .debug (
920+ "Interval flush triggered: %d message(s) and %d agent state(s) pending" ,
921+ pending_messages ,
922+ pending_agent_states ,
923+ )
870924 self ._flush_messages ()
871925 else :
872- logger .debug ("Interval flush skipped: buffer is empty" )
926+ logger .debug ("Interval flush skipped: buffers are empty" )
873927
874928 # Reschedule the timer (unless shutdown)
875929 if not self ._shutdown and self .config .flush_interval_seconds :
0 commit comments