Skip to content

Commit 3c431e1

Browse files
Merge pull request #612 from jinnigu:feature/subagent-escalation
PiperOrigin-RevId: 892863356
2 parents 1855584 + 88c8b0e commit 3c431e1

File tree

2 files changed

+139
-1
lines changed

2 files changed

+139
-1
lines changed

core/src/main/java/com/google/adk/agents/ParallelAgent.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
148148
for (BaseAgent subAgent : currentSubAgents) {
149149
agentFlowables.add(subAgent.runAsync(updatedInvocationContext).subscribeOn(scheduler));
150150
}
151-
return Flowable.merge(agentFlowables);
151+
return Flowable.merge(agentFlowables)
152+
.takeUntil((Event event) -> event.actions().escalate().orElse(false));
152153
}
153154

154155
/**
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.agents;
18+
19+
import static com.google.adk.testing.TestUtils.createInvocationContext;
20+
import static com.google.common.truth.Truth.assertThat;
21+
import static java.util.concurrent.TimeUnit.MILLISECONDS;
22+
23+
import com.google.adk.events.Event;
24+
import com.google.adk.events.EventActions;
25+
import com.google.common.collect.ImmutableList;
26+
import com.google.genai.types.Content;
27+
import com.google.genai.types.Part;
28+
import io.reactivex.rxjava3.core.Flowable;
29+
import io.reactivex.rxjava3.core.Scheduler;
30+
import io.reactivex.rxjava3.schedulers.TestScheduler;
31+
import org.junit.Test;
32+
import org.junit.runner.RunWith;
33+
import org.junit.runners.JUnit4;
34+
35+
@RunWith(JUnit4.class)
36+
public final class ParallelAgentEscalationTest {
37+
38+
static class TestAgent extends BaseAgent {
39+
private final long delayMillis;
40+
private final Scheduler scheduler;
41+
private final String content;
42+
private final EventActions actions;
43+
44+
private TestAgent(String name, long delayMillis, Scheduler scheduler, String content) {
45+
this(name, delayMillis, scheduler, content, null);
46+
}
47+
48+
private TestAgent(
49+
String name, long delayMillis, Scheduler scheduler, String content, EventActions actions) {
50+
super(name, "Test Agent", ImmutableList.of(), null, null);
51+
this.delayMillis = delayMillis;
52+
this.scheduler = scheduler;
53+
this.content = content;
54+
this.actions = actions;
55+
}
56+
57+
@Override
58+
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
59+
Flowable<Event> event =
60+
Flowable.fromCallable(
61+
() -> {
62+
Event.Builder builder =
63+
Event.builder()
64+
.author(name())
65+
.branch(invocationContext.branch().orElse(null))
66+
.invocationId(invocationContext.invocationId())
67+
.content(Content.fromParts(Part.fromText(content)));
68+
69+
if (actions != null) {
70+
builder.actions(actions);
71+
}
72+
return builder.build();
73+
});
74+
75+
if (delayMillis > 0) {
76+
return event.delay(delayMillis, MILLISECONDS, scheduler);
77+
}
78+
return event;
79+
}
80+
81+
@Override
82+
protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
83+
throw new UnsupportedOperationException("Not implemented");
84+
}
85+
}
86+
87+
@Test
88+
public void runAsync_escalationEvent_shortCircuitsOtherAgents() {
89+
TestScheduler testScheduler = new TestScheduler();
90+
91+
TestAgent escalatingAgent =
92+
new TestAgent(
93+
"escalating_agent",
94+
100,
95+
testScheduler,
96+
"Escalating!",
97+
EventActions.builder().escalate(true).build());
98+
TestAgent slowAgent = new TestAgent("slow_agent", 500, testScheduler, "Finished");
99+
TestAgent fastAgent = new TestAgent("fast_agent", 50, testScheduler, "Finished");
100+
101+
ParallelAgent parallelAgent =
102+
ParallelAgent.builder()
103+
.name("parallel_agent")
104+
.subAgents(fastAgent, escalatingAgent, slowAgent)
105+
.scheduler(testScheduler)
106+
.build();
107+
108+
InvocationContext invocationContext = createInvocationContext(parallelAgent);
109+
110+
var subscriber = parallelAgent.runAsync(invocationContext).test();
111+
112+
// Fast agent completes at 50ms (before the escalation)
113+
testScheduler.advanceTimeBy(50, MILLISECONDS);
114+
subscriber.assertValueCount(1);
115+
assertThat(subscriber.values().get(0).author()).isEqualTo("fast_agent");
116+
117+
// Escalating agent completes at 100ms
118+
testScheduler.advanceTimeBy(50, MILLISECONDS);
119+
subscriber.assertValueCount(2);
120+
121+
Event event1 = subscriber.values().get(0);
122+
assertThat(event1.author()).isEqualTo("fast_agent");
123+
124+
Event event2 = subscriber.values().get(1);
125+
assertThat(event2.author()).isEqualTo("escalating_agent");
126+
assertThat(event2.actions().escalate()).hasValue(true);
127+
128+
subscriber.assertComplete();
129+
130+
// Slow agent would complete at 500ms, but test scheduler advances time to prove
131+
// sequence was forcibly terminated!
132+
testScheduler.advanceTimeBy(400, MILLISECONDS);
133+
134+
// Test RxJava Disposal behavior: SlowAgent won't emit anything
135+
subscriber.assertValueCount(2);
136+
}
137+
}

0 commit comments

Comments
 (0)