Skip to content

Commit d0cd310

Browse files
author
lucas
committed
feat: enhance AgentTool to support custom plugins for observability
1 parent e3ea378 commit d0cd310

File tree

2 files changed

+74
-4
lines changed

2 files changed

+74
-4
lines changed

core/src/main/java/com/google/adk/tools/AgentTool.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
2727
import com.google.adk.agents.LlmAgent;
2828
import com.google.adk.events.Event;
29+
import com.google.adk.plugins.Plugin;
2930
import com.google.adk.runner.InMemoryRunner;
3031
import com.google.adk.runner.Runner;
3132
import com.google.adk.sessions.State;
@@ -46,6 +47,7 @@ public class AgentTool extends BaseTool {
4647

4748
private final BaseAgent agent;
4849
private final boolean skipSummarization;
50+
private final List<Plugin> plugins;
4951

5052
public static BaseTool fromConfig(ToolArgsConfig args, String configAbsPath)
5153
throws ConfigurationException {
@@ -62,21 +64,32 @@ public static BaseTool fromConfig(ToolArgsConfig args, String configAbsPath)
6264
}
6365

6466
BaseAgent agent = resolvedAgents.get(0);
65-
return AgentTool.create(agent, args.getOrDefault("skipSummarization", false).booleanValue());
67+
return AgentTool.create(
68+
agent, args.getOrDefault("skipSummarization", false).booleanValue(), ImmutableList.of());
69+
}
70+
71+
public static AgentTool create(
72+
BaseAgent agent, boolean skipSummarization, List<? extends Plugin> plugins) {
73+
return new AgentTool(agent, skipSummarization, plugins);
6674
}
6775

6876
public static AgentTool create(BaseAgent agent, boolean skipSummarization) {
69-
return new AgentTool(agent, skipSummarization);
77+
return new AgentTool(agent, skipSummarization, ImmutableList.of());
7078
}
7179

7280
public static AgentTool create(BaseAgent agent) {
73-
return new AgentTool(agent, false);
81+
return new AgentTool(agent, false, ImmutableList.of());
7482
}
7583

7684
protected AgentTool(BaseAgent agent, boolean skipSummarization) {
85+
this(agent, skipSummarization, ImmutableList.of());
86+
}
87+
88+
protected AgentTool(BaseAgent agent, boolean skipSummarization, List<? extends Plugin> plugins) {
7789
super(agent.name(), agent.description());
7890
this.agent = agent;
7991
this.skipSummarization = skipSummarization;
92+
this.plugins = ImmutableList.copyOf(plugins != null ? plugins : ImmutableList.of());
8093
}
8194

8295
@VisibleForTesting
@@ -159,7 +172,7 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
159172
content = Content.fromParts(Part.fromText(input.toString()));
160173
}
161174

162-
Runner runner = new InMemoryRunner(this.agent, toolContext.agentName());
175+
Runner runner = new InMemoryRunner(this.agent, toolContext.agentName(), this.plugins);
163176
// Session state is final, can't update to toolContext state
164177
// session.toBuilder().setState(toolContext.getState());
165178
return runner
@@ -219,3 +232,4 @@ private void updateState(Map<String, Object> stateDelta, Map<String, Object> sta
219232
});
220233
}
221234
}
235+

core/src/test/java/com/google/adk/tools/AgentToolTest.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,62 @@ public void declaration_emptySequentialAgent_fallsBackToRequest() {
664664
.build());
665665
}
666666

667+
@Test
668+
public void create_withPlugins_initializesCorrectly() {
669+
LlmAgent testAgent =
670+
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
671+
.name("agent name")
672+
.description("agent description")
673+
.build();
674+
675+
AgentTool agentTool = AgentTool.create(testAgent, false, ImmutableList.of());
676+
677+
assertThat(agentTool).isNotNull();
678+
assertThat(agentTool.declaration()).isPresent();
679+
}
680+
681+
@Test
682+
public void runAsync_withPlugins_usesThem() {
683+
LlmAgent testAgent =
684+
createTestAgentBuilder(
685+
createTestLlm(
686+
LlmResponse.builder()
687+
.content(Content.fromParts(Part.fromText("Sub-agent executed")))
688+
.build()))
689+
.name("sub-agent")
690+
.description("sub-agent description")
691+
.build();
692+
693+
TestPlugin testPlugin = new TestPlugin();
694+
695+
AgentTool agentTool = AgentTool.create(testAgent, false, ImmutableList.of(testPlugin));
696+
697+
ToolContext toolContext = createToolContext(testAgent);
698+
699+
Map<String, Object> result =
700+
agentTool.runAsync(ImmutableMap.of("request", "start"), toolContext).blockingGet();
701+
702+
assertThat(result).containsEntry("result", "Sub-agent executed");
703+
704+
assertThat(testPlugin.wasCalled.get()).isTrue();
705+
}
706+
707+
private static class TestPlugin extends com.google.adk.plugins.BasePlugin {
708+
final java.util.concurrent.atomic.AtomicBoolean wasCalled =
709+
new java.util.concurrent.atomic.AtomicBoolean(false);
710+
711+
TestPlugin() {
712+
super("test-plugin");
713+
}
714+
715+
@Override
716+
public Maybe<Content> onUserMessageCallback(
717+
InvocationContext invocationContext, Content userMessage) {
718+
wasCalled.set(true);
719+
return Maybe.empty();
720+
}
721+
}
722+
667723
private ToolContext createToolContext(BaseAgent agent) {
668724
Session session =
669725
sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();

0 commit comments

Comments
 (0)