diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 48d141819..255d85084 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -461,14 +461,31 @@ public Flowable run(InvocationContext invocationContext) { private Flowable run( Context spanContext, InvocationContext invocationContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(spanContext, invocationContext).cache(); + Flowable currentStepEvents = runOneStep(spanContext, invocationContext); + + Flowable processedEvents = + currentStepEvents + .concatMap( + event -> + invocationContext + .sessionService() + .appendEvent(invocationContext.session(), event) + .flatMap( + registeredEvent -> + invocationContext + .pluginManager() + .onEventCallback(invocationContext, registeredEvent) + .defaultIfEmpty(registeredEvent)) + .toFlowable()) + .cache(); + if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); - return currentStepEvents; + return processedEvents; } - return currentStepEvents.concatWith( - currentStepEvents + return processedEvents.concatWith( + processedEvents .toList() .flatMapPublisher( eventList -> { diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 44a281f72..d72330e7b 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -565,24 +565,29 @@ private Flowable runAgentWithUpdatedSession( .build()); // Agent execution - Flowable agentEvents = - contextWithUpdatedSession - .agent() - .runAsync(contextWithUpdatedSession) - .concatMap( - agentEvent -> - this.sessionService - .appendEvent(updatedSession, agentEvent) - .flatMap( - registeredEvent -> { - // TODO: remove this hack after deprecating runAsync with Session. - copySessionStates(updatedSession, initialContext.session()); - return contextWithUpdatedSession - .pluginManager() - .onEventCallback(contextWithUpdatedSession, registeredEvent) - .defaultIfEmpty(registeredEvent); - }) - .toFlowable()); + Flowable agentEvents; + if (contextWithUpdatedSession.agent() instanceof LlmAgent) { + agentEvents = contextWithUpdatedSession.agent().runAsync(contextWithUpdatedSession); + } else { + agentEvents = + contextWithUpdatedSession + .agent() + .runAsync(contextWithUpdatedSession) + .concatMap( + agentEvent -> + this.sessionService + .appendEvent(updatedSession, agentEvent) + .flatMap( + registeredEvent -> { + // TODO: remove this hack after deprecating runAsync with Session. + copySessionStates(updatedSession, initialContext.session()); + return contextWithUpdatedSession + .pluginManager() + .onEventCallback(contextWithUpdatedSession, registeredEvent) + .defaultIfEmpty(registeredEvent); + }) + .toFlowable()); + } // If beforeRunCallback returns content, emit it and skip agent Context capturedContext = Context.current(); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index ff75c97b0..d847494e0 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -46,6 +46,7 @@ import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; +import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; import com.google.adk.sessions.BaseSessionService; @@ -1686,4 +1687,67 @@ public void runner_executesSaveArtifactFlow() { // agent was run assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); } + + @Test + public void runAsync_ensuresSequentialConsistencyForTools() { + // Arrange + TestLlm testLlm = + createTestLlm( + createFunctionCallLlmResponse("call_1", "tool1", ImmutableMap.of("arg", "value1")), + createTextLlmResponse("Final response")); + + LlmAgent agent = + createTestAgentBuilder(testLlm) + .tools( + ImmutableList.of( + FunctionTool.create(RaceConditionTools.class, "tool1"), + FunctionTool.create(RaceConditionTools.class, "tool2"))) + .build(); + + Runner runner = + Runner.builder().app(App.builder().name("test").rootAgent(agent).build()).build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + + // Act + var unused = + runner + .runAsync("user", session.id(), Content.fromParts(Part.fromText("start"))) + .toList() + .blockingGet(); + + // Assert + List requests = testLlm.getRequests(); + assertThat(requests).hasSize(2); + + // Second request should contain the result of tool1 + LlmRequest secondRequest = requests.get(1); + List history = secondRequest.contents(); + + boolean foundToolResponse = false; + for (Content content : history) { + for (Part part : content.parts().get()) { + if (part.functionResponse().isPresent() + && part.functionResponse().get().name().isPresent() + && "tool1".equals(part.functionResponse().get().name().get())) { + foundToolResponse = true; + assertThat(part.functionResponse().get().response().isPresent()).isTrue(); + assertThat(part.functionResponse().get().response().get()) + .isEqualTo(ImmutableMap.of("result", "result_value1")); + } + } + } + assertThat(foundToolResponse).isTrue(); + } + + public static class RaceConditionTools { + private RaceConditionTools() {} + + public static ImmutableMap tool1(String arg) { + return ImmutableMap.of("result", "result_" + arg); + } + + public static ImmutableMap tool2(String input) { + return ImmutableMap.of("status", "received_" + input); + } + } }