diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index d6ad38561..44a281f72 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -45,6 +45,7 @@ import com.google.adk.utils.CollectionUtils; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.MapMaker; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.AudioTranscriptionConfig; import com.google.genai.types.Content; @@ -57,6 +58,7 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.subjects.CompletableSubject; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -64,6 +66,7 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import org.jspecify.annotations.Nullable; /** The main class for the GenAI Agents runner. */ @@ -76,6 +79,8 @@ public class Runner { private final PluginManager pluginManager; @Nullable private final EventsCompactionConfig eventsCompactionConfig; @Nullable private final ContextCacheConfig contextCacheConfig; + private final ConcurrentMap activeSessionCompletables = + new MapMaker().weakValues().makeMap(); /** Builder for {@link Runner}. */ public static class Builder { @@ -380,25 +385,57 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { + Flowable result = + Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", + sessionId, userId))); + })) + .flatMapPublisher( + session -> + this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) + .compose(Tracing.trace("invocation")); + return Flowable.defer( - () -> - this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession( - appName, userId, (Map) null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format( - "Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher( - session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) - .compose(Tracing.trace("invocation")); + () -> { + if (sessionId == null) { + return result; + } + + CompletableSubject requestCompletion = CompletableSubject.create(); + + Completable[] previousHolder = new Completable[1]; + + activeSessionCompletables.compute( + sessionId, + (key, current) -> { + previousHolder[0] = current; + return requestCompletion; + }); + + Completable previous = previousHolder[0]; + + Flowable sequenced = + (previous == null) ? result : previous.onErrorComplete().andThen(result); + + return sequenced.doFinally( + () -> { + requestCompletion.onComplete(); + activeSessionCompletables.remove(sessionId, requestCompletion); + }); + }); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -740,6 +777,9 @@ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) { for (Event event : events) { String author = event.author(); + if (author == null) { + continue; + } if (author.equals("user")) { continue; } diff --git a/core/src/main/java/com/google/adk/sessions/Session.java b/core/src/main/java/com/google/adk/sessions/Session.java index 94504fd96..24251619d 100644 --- a/core/src/main/java/com/google/adk/sessions/Session.java +++ b/core/src/main/java/com/google/adk/sessions/Session.java @@ -123,7 +123,7 @@ public Builder userId(String userId) { @CanIgnoreReturnValue @JsonProperty("events") public Builder events(List events) { - this.events = Collections.synchronizedList(events); + this.events = Collections.synchronizedList(new ArrayList<>(events)); return this; } diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 36530faf2..ff75c97b0 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -27,10 +27,13 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Arrays.stream; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -72,13 +75,17 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.subjects.PublishSubject; import io.reactivex.rxjava3.subscribers.TestSubscriber; +import java.time.Instant; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -614,6 +621,244 @@ public void callbackContextData_preservedAcrossInvocation() { assertThat(contextCaptor.getValue().callbackContextData()).containsEntry(testKey, testValue); } + @Test + public void runAsync_passesSessionSnapshotToPersistenceService() { + BaseSessionService mockSessionService = mock(BaseSessionService.class); + Event agentEvent = Event.builder().id("agent-event").author("agent").build(); + + // Mock agent to return one event + BaseAgent mockAgent = mock(BaseAgent.class); + when(mockAgent.runAsync(any())).thenReturn(Flowable.just(agentEvent)); + + // Mock session service + Session testSession = Session.builder("session-id").appName("test").userId("user").build(); + when(mockSessionService.getSession(anyString(), anyString(), anyString(), any())) + .thenReturn(Maybe.just(testSession)); + when(mockSessionService.appendEvent(any(), any())).thenReturn(Single.just(agentEvent)); + + Runner runnerWithMockService = + Runner.builder() + .app(App.builder().name("test").rootAgent(mockAgent).build()) + .sessionService(mockSessionService) + .build(); + + var unused = + runnerWithMockService + .runAsync("user", "session-id", createContent("start")) + .toList() + .blockingGet(); + + ArgumentCaptor sessionCaptor = ArgumentCaptor.forClass(Session.class); + ArgumentCaptor eventCaptor = ArgumentCaptor.forClass(Event.class); + + // We expect 2 calls to appendEvent: one for user message, one for agent response. + verify(mockSessionService, times(2)) + .appendEvent(sessionCaptor.capture(), eventCaptor.capture()); + + List capturedSessions = sessionCaptor.getAllValues(); + + // The second call should be for the agent response + Session sessionForAgentEvent = capturedSessions.get(1); + + assertThat(sessionForAgentEvent.id()).isEqualTo("session-id"); + + // Verify it is a snapshot (does not contain the agent event itself) + assertThat(sessionForAgentEvent.events()).doesNotContain(agentEvent); + } + + @Test + public void runAsync_multiEventExecution_lastUpdateTimeProgresses() throws Exception { + BaseSessionService mockSessionService = mock(BaseSessionService.class); + + Event event1 = Event.builder().id("event-1").author("agent").timestamp(200).build(); + Event event2 = Event.builder().id("event-2").author("agent").timestamp(300).build(); + + BaseAgent mockAgent = mock(BaseAgent.class); + when(mockAgent.runAsync(any())).thenReturn(Flowable.just(event1, event2)); + + // Initial session with timestamp 100 + Session testSession = + Session.builder("session-id") + .appName("test") + .userId("user") + .lastUpdateTime(Instant.ofEpochMilli(100)) + .build(); + + when(mockSessionService.getSession(anyString(), anyString(), anyString(), any())) + .thenReturn(Maybe.just(testSession)); + + // Mock appendEvent to return the event passed to it and capture timestamps + List capturedTimestamps = new ArrayList<>(); + when(mockSessionService.appendEvent(any(), any())) + .thenAnswer( + invocation -> { + Session s = invocation.getArgument(0); + Event e = invocation.getArgument(1); + capturedTimestamps.add(s.lastUpdateTime()); + if (!Objects.equals(e.author(), "user")) { + s.lastUpdateTime(Instant.ofEpochMilli(e.timestamp())); + } + return Single.just(e); + }); + + Runner runnerWithMockService = + Runner.builder() + .app(App.builder().name("test").rootAgent(mockAgent).build()) + .sessionService(mockSessionService) + .build(); + + var unused = + runnerWithMockService + .runAsync("user", "session-id", createContent("start")) + .toList() + .blockingGet(); + + ArgumentCaptor sessionCaptor = ArgumentCaptor.forClass(Session.class); + ArgumentCaptor eventCaptor = ArgumentCaptor.forClass(Event.class); + + // We expect 3 calls to appendEvent: + // 1 for user message + // 2 for agent events (event1, event2) + verify(mockSessionService, times(3)) + .appendEvent(sessionCaptor.capture(), eventCaptor.capture()); + + // Verify timestamp for event1 call is the initial timestamp (100) + assertThat(capturedTimestamps.get(1)).isEqualTo(Instant.ofEpochMilli(100)); + + // Verify timestamp for event2 call is the timestamp of event1 (200) + assertThat(capturedTimestamps.get(2)).isEqualTo(Instant.ofEpochMilli(200)); + } + + @Test + public void runAsync_concurrentCalls_staleRead() throws Exception { + BaseSessionService mockSessionService = mock(BaseSessionService.class); + Event agentEvent = Event.builder().id("agent-event").author("agent").build(); + + BaseAgent mockAgent = mock(BaseAgent.class); + when(mockAgent.runAsync(any())).thenReturn(Flowable.just(agentEvent)); + + Session initialSession = Session.builder("session-id").appName("test").userId("user").build(); + AtomicReference dbSession = new AtomicReference<>(initialSession); + + when(mockSessionService.getSession(anyString(), anyString(), anyString(), any())) + .thenAnswer(invocation -> Maybe.just(dbSession.get())); + + PublishSubject appendSubject = PublishSubject.create(); + + when(mockSessionService.appendEvent(any(), any())) + .thenAnswer( + invocation -> { + Session s = invocation.getArgument(0); + Event e = invocation.getArgument(1); + return appendSubject + .firstOrError() + .doOnSuccess( + event -> { + List newEvents = new ArrayList<>(s.events()); + newEvents.add(e); + Session updated = + Session.builder(s.id()) + .appName(s.appName()) + .userId(s.userId()) + .state(s.state()) + .events(newEvents) + .build(); + dbSession.set(updated); + }); + }); + + Runner runnerWithMockService = + Runner.builder() + .app(App.builder().name("test").rootAgent(mockAgent).build()) + .sessionService(mockSessionService) + .build(); + + TestSubscriber subscriber1 = new TestSubscriber<>(); + runnerWithMockService + .runAsync("user", "session-id", createContent("message 1")) + .subscribe(subscriber1); + + TestSubscriber subscriber2 = new TestSubscriber<>(); + runnerWithMockService + .runAsync("user", "session-id", createContent("message 2")) + .subscribe(subscriber2); + + appendSubject.onNext(agentEvent); // Completes first appendEvent (user msg 1) + appendSubject.onNext(agentEvent); // Completes second appendEvent (agent event 1) + appendSubject.onNext(agentEvent); // Completes third appendEvent (user msg 2) + appendSubject.onNext(agentEvent); // Completes fourth appendEvent (agent event 2) + + subscriber1.awaitDone(5, SECONDS); + subscriber2.awaitDone(5, SECONDS); + + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(InvocationContext.class); + verify(mockAgent, times(2)).runAsync(contextCaptor.capture()); + + List capturedContexts = contextCaptor.getAllValues(); + InvocationContext context2 = capturedContexts.get(1); + + assertThat(simplifyEvents(context2.session().events())).contains("user: message 1"); + } + + @Test + public void runAsync_concurrentCalls_firstFails_secondSucceeds() throws Exception { + BaseSessionService mockSessionService = mock(BaseSessionService.class); + Event agentEvent = Event.builder().id("agent-event").author("agent").build(); + + BaseAgent mockAgent = mock(BaseAgent.class); + when(mockAgent.runAsync(any())) + .thenReturn(Flowable.error(new RuntimeException("Agent failed"))) + .thenReturn(Flowable.just(agentEvent)); + + Session initialSession = Session.builder("session-id").appName("test").userId("user").build(); + AtomicReference dbSession = new AtomicReference<>(initialSession); + + when(mockSessionService.getSession(anyString(), anyString(), anyString(), any())) + .thenAnswer(invocation -> Maybe.just(dbSession.get())); + + when(mockSessionService.appendEvent(any(), any())) + .thenAnswer( + invocation -> { + Session s = invocation.getArgument(0); + Event e = invocation.getArgument(1); + List newEvents = new ArrayList<>(s.events()); + newEvents.add(e); + Session updated = + Session.builder(s.id()) + .appName(s.appName()) + .userId(s.userId()) + .state(s.state()) + .events(newEvents) + .build(); + dbSession.set(updated); + return Single.just(e); + }); + + Runner runnerWithMockService = + Runner.builder() + .app(App.builder().name("test").rootAgent(mockAgent).build()) + .sessionService(mockSessionService) + .build(); + + TestSubscriber subscriber1 = new TestSubscriber<>(); + runnerWithMockService + .runAsync("user", "session-id", createContent("message 1")) + .subscribe(subscriber1); + + TestSubscriber subscriber2 = new TestSubscriber<>(); + runnerWithMockService + .runAsync("user", "session-id", createContent("message 2")) + .subscribe(subscriber2); + + subscriber1.awaitDone(5, SECONDS); + subscriber2.awaitDone(5, SECONDS); + + subscriber1.assertError(RuntimeException.class); + subscriber2.assertComplete(); + subscriber2.assertValue(agentEvent); + } + @Test public void runAsync_withSessionKey_success() { var events = diff --git a/core/src/test/java/com/google/adk/sessions/SessionTest.java b/core/src/test/java/com/google/adk/sessions/SessionTest.java new file mode 100644 index 000000000..a96b63acd --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/SessionTest.java @@ -0,0 +1,36 @@ +package com.google.adk.sessions; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.events.Event; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SessionTest { + + @Test + public void builder_events_createsMutableCopy() { + Event event1 = + Event.builder().author("user").content(Content.fromParts(Part.fromText("hi"))).build(); + Event event2 = + Event.builder().author("model").content(Content.fromParts(Part.fromText("hello"))).build(); + ImmutableList immutableList = ImmutableList.of(event1); + + Session session = + Session.builder("session-id") + .appName("test-app") + .userId("test-user") + .events(immutableList) + .build(); + + // Verify we can add to the list + session.events().add(event2); + + assertThat(session.events()).containsExactly(event1, event2).inOrder(); + } +}