diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index d6ad38561..ff5067c67 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -448,6 +448,12 @@ protected Flowable runAsyncImpl( BaseAgent rootAgent = this.agent; String invocationId = InvocationContext.newInvocationContextId(); + // Pre-merge stateDelta so onUserMessageCallback can access it. + // Safe: session is a copy; persistence still happens via appendNewMessageToSession. + if (stateDelta != null && !stateDelta.isEmpty()) { + stateDelta.forEach((key, value) -> session.state().put(key, value)); + } + // Create initial context InvocationContext initialContext = newInvocationContextBuilder(session) diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 36530faf2..2f217c3c8 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -877,6 +877,38 @@ public void beforeRunCallback_withStateDelta_seesMergedState() { assertThat(sessionInCallback.state()).containsEntry("number", 123); } + @Test + public void onUserMessageCallback_withStateDelta_seesMergedState() { + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(InvocationContext.class); + when(plugin.onUserMessageCallback(contextCaptor.capture(), any())).thenReturn(Maybe.empty()); + + ImmutableMap stateDelta = + ImmutableMap.of("callback_key", "callback_value", "number", 123); + + var unused = + runner + .runAsync( + "user", + session.id(), + createContent("test with state"), + RunConfig.builder().build(), + stateDelta) + .toList() + .blockingGet(); + + // Verify onUserMessageCallback was called + verify(plugin).onUserMessageCallback(any(), any()); + + // Verify the context passed to onUserMessageCallback has the merged state + InvocationContext capturedContext = contextCaptor.getValue(); + Session sessionInCallback = capturedContext.session(); + + // Verify state delta was merged before onUserMessageCallback was invoked + assertThat(sessionInCallback.state()).containsEntry("callback_key", "callback_value"); + assertThat(sessionInCallback.state()).containsEntry("number", 123); + } + @Test public void runAsync_ensureEventsAreAppendedInOrder() throws Exception { Event event1 = TestUtils.createEvent("1");