diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index c85ac6ff94..7d82069b60 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -474,7 +474,11 @@ async def _run_node_async( with tracer.start_as_current_span('invocation'): # 1. Setup session = await self._get_or_create_session( - user_id=user_id, session_id=session_id + user_id=user_id, + session_id=session_id, + get_session_config=run_config.get_session_config + if run_config + else None, ) # Validate and resolve resume inputs @@ -1000,7 +1004,9 @@ async def run_async( if self.agent.mode == 'chat': session = await self._get_or_create_session( - user_id=user_id, session_id=session_id + user_id=user_id, + session_id=session_id, + get_session_config=run_config.get_session_config, ) # when the chat coordinator has task-mode sub-agents, # the wrapper handles delegation via ctx.run_node. Don't let diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 22f6ac55f4..10b2c52b15 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -1695,9 +1695,28 @@ async def test_run_async_passes_get_session_config(): ), ) + events_seen_by_agent = [] + + class EventCheckingAgent(BaseAgent): + + def __init__(self, name: str): + super().__init__(name=name, sub_agents=[]) + + async def _run_async_impl( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: + events_seen_by_agent.extend(invocation_context.session.events) + yield Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ), + ) + runner = Runner( app_name=TEST_APP_ID, - agent=MockAgent("test_agent"), + agent=EventCheckingAgent("test_agent"), session_service=session_service, artifact_service=InMemoryArtifactService(), ) @@ -1720,6 +1739,13 @@ async def test_run_async_passes_get_session_config(): assert len(events) >= 1 assert events[0].author == "test_agent" + # The agent should have only seen 3 historical events + 1 new message = 4 events. + assert len(events_seen_by_agent) == 4 + assert events_seen_by_agent[0].invocation_id == "inv_7" + assert events_seen_by_agent[1].invocation_id == "inv_8" + assert events_seen_by_agent[2].invocation_id == "inv_9" + assert events_seen_by_agent[3].content.parts[0].text == "hello" + @pytest.mark.asyncio async def test_run_async_teardown_on_aclose():