diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index d6ad38561..0004f1891 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -60,6 +60,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -532,12 +533,34 @@ private Flowable runAgentWithUpdatedSession( contextWithUpdatedSession .agent() .runAsync(contextWithUpdatedSession) + .map( + agentEvent -> { + // We create a temporary shallow copy of the session to pass to the persistence + // service. + // This copy is created BEFORE we add the agentEvent to the in-memory session. + Session sessionForService = + Session.builder(updatedSession.id()) + .appName(updatedSession.appName()) + .userId(updatedSession.userId()) + .state(new HashMap<>(updatedSession.state())) + .events(new ArrayList<>(updatedSession.events())) + .build(); + + // Unblock the in-memory session synchronously as soon as the event is emitted! + // This allows the agent's internal loop (llmFlow) to see the event immediately + // for its next turn without waiting for previous DB writes to complete. + updatedSession.events().add(agentEvent); + + return new EventWithSession(sessionForService, agentEvent); + }) .concatMap( - agentEvent -> + wrapper -> this.sessionService - .appendEvent(updatedSession, agentEvent) + .appendEvent(wrapper.sessionForService(), wrapper.event()) .flatMap( registeredEvent -> { + // Sync state changes back from isolated copy to our primary session + copySessionStates(wrapper.sessionForService(), updatedSession); // TODO: remove this hack after deprecating runAsync with Session. copySessionStates(updatedSession, initialContext.session()); return contextWithUpdatedSession @@ -804,5 +827,8 @@ private static EventsCompactionConfig createEventsCompactionConfig( config.eventRetentionSize()); } + /** A record to wrap the isolated session and the event for sequential persistence. */ + private static record EventWithSession(Session sessionForService, Event event) {} + // TODO: run statelessly } diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 36530faf2..be8755542 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -35,6 +35,7 @@ import static org.mockito.Mockito.when; import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; @@ -614,6 +615,77 @@ public void callbackContextData_preservedAcrossInvocation() { assertThat(contextCaptor.getValue().callbackContextData()).containsEntry(testKey, testValue); } + @Test + public void runAsync_duringMultiTurnExecution_emittedEventsAreVisibleInSubsequentTurn() { + // Setup LLM to return a function call, and then a final response + TestLlm testLlmForRace = + createTestLlm( + createLlmResponse( + Content.builder() + .role("model") + .parts( + Part.builder() + .functionCall( + FunctionCall.builder() + .name(echoTool.name()) + .args(ImmutableMap.of("args_name", "args_value")) + .build()) + .build()) + .build()), + createLlmResponse(createContent("done"))); + + LlmAgent agentForRace = + createTestAgentBuilder(testLlmForRace).tools(ImmutableList.of(echoTool)).build(); + + Runner runnerForRace = + Runner.builder() + .app( + App.builder() + .name("test") + .rootAgent(agentForRace) + .plugins(ImmutableList.of(plugin)) + .build()) + .build(); + + Session sessionForRace = + runnerForRace.sessionService().createSession("test", "user").blockingGet(); + + // Use a mock plugin to check session events in beforeModelCallback + // It should be called for the second turn (after the function call) + AtomicInteger callCount = new AtomicInteger(0); + when(plugin.beforeModelCallback(any(), any())) + .thenAnswer( + invocation -> { + CallbackContext context = invocation.getArgument(0); + int count = callCount.incrementAndGet(); + if (count == 2) { + // This is the second turn, after the function call + // Check if the session contains the function call event + List events = context.events(); + boolean hasFunctionCall = + events.stream() + .flatMap( + e -> + e + .content() + .flatMap(Content::parts) + .orElse(ImmutableList.of()) + .stream()) + .anyMatch(p -> p.functionCall().isPresent()); + assertThat(hasFunctionCall).isTrue(); + } + return Maybe.empty(); + }); + + var unused = + runnerForRace + .runAsync("user", sessionForRace.id(), createContent("start")) + .toList() + .blockingGet(); + + assertThat(callCount.get()).isEqualTo(2); + } + @Test public void runAsync_withSessionKey_success() { var events =