diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 0b0e5b4d5..a30b909f3 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -48,6 +48,7 @@ import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.disposables.Disposable; import io.reactivex.rxjava3.functions.Function; +import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -591,8 +592,9 @@ private static Maybe> maybeInvokeAfterToolCall( private static Maybe> callTool( BaseTool tool, Map args, ToolContext toolContext) { - return tool.runAsync(args, toolContext) + return Single.defer(() -> tool.runAsync(args, toolContext)) .toMaybe() + .subscribeOn(Schedulers.io()) .doOnError(t -> Span.current().recordException(t)) .onErrorResumeNext( e -> diff --git a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java index d5db4d4b3..3438b7fee 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java @@ -27,12 +27,18 @@ import com.google.adk.agents.RunConfig.ToolExecutionMode; import com.google.adk.events.Event; import com.google.adk.testing.TestUtils; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -389,4 +395,137 @@ public void getAskUserConfirmationFunctionCalls_eventWithConfirmationFunctionCal ImmutableList result = Functions.getAskUserConfirmationFunctionCalls(event); assertThat(result).containsExactly(confirmationCall1, confirmationCall2); } + + /** + * A tool that blocks for a specified duration, simulating a slow I/O operation. Uses + * Single.fromCallable to ensure the sleep is deferred until subscription time. + */ + private static class SlowTool extends BaseTool { + private final String toolName; + private final long sleepMillis; + + SlowTool(String name, long sleepMillis) { + super(name, "A slow tool for testing parallel execution"); + this.toolName = name; + this.sleepMillis = sleepMillis; + } + + @Override + public Optional declaration() { + return Optional.of(FunctionDeclaration.builder().name(toolName).build()); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + return Single.fromCallable( + () -> { + Thread.sleep(sleepMillis); + return ImmutableMap.of("tool", toolName, "status", "done"); + }); + } + } + + @Test + public void handleFunctionCalls_parallelMode_shouldExecuteConcurrently() { + long sleepTime = 1000; + SlowTool slowTool1 = new SlowTool("slow_tool_1", sleepTime); + SlowTool slowTool2 = new SlowTool("slow_tool_2", sleepTime); + + InvocationContext invocationContext = + createInvocationContext( + createRootAgent(), + RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build()); + + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_1") + .name("slow_tool_1") + .args(ImmutableMap.of()) + .build()) + .build(), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_2") + .name("slow_tool_2") + .args(ImmutableMap.of()) + .build()) + .build())) + .build(); + + long startTime = System.currentTimeMillis(); + Event result = + Functions.handleFunctionCalls( + invocationContext, + event, + ImmutableMap.of("slow_tool_1", slowTool1, "slow_tool_2", slowTool2)) + .blockingGet(); + long duration = System.currentTimeMillis() - startTime; + + // If parallel, duration should be ~1000ms, not ~2000ms. + assertThat(duration).isAtLeast(sleepTime); + assertThat(duration).isLessThan((long) (1.5 * sleepTime)); + + // Verify results are returned in correct order (concatMapEager preserves order). + assertThat(result).isNotNull(); + assertThat(result.content().get().parts().get()).hasSize(2); + assertThat(result.content().get().parts().get().get(0).functionResponse().get().name()) + .hasValue("slow_tool_1"); + assertThat(result.content().get().parts().get().get(1).functionResponse().get().name()) + .hasValue("slow_tool_2"); + } + + @Test + public void handleFunctionCalls_sequentialMode_shouldExecuteSerially() { + long sleepTime = 1000; + SlowTool slowTool1 = new SlowTool("slow_tool_1", sleepTime); + SlowTool slowTool2 = new SlowTool("slow_tool_2", sleepTime); + + InvocationContext invocationContext = + createInvocationContext( + createRootAgent(), + RunConfig.builder().setToolExecutionMode(ToolExecutionMode.SEQUENTIAL).build()); + + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_1") + .name("slow_tool_1") + .args(ImmutableMap.of()) + .build()) + .build(), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_2") + .name("slow_tool_2") + .args(ImmutableMap.of()) + .build()) + .build())) + .build(); + + long startTime = System.currentTimeMillis(); + Event result = + Functions.handleFunctionCalls( + invocationContext, + event, + ImmutableMap.of("slow_tool_1", slowTool1, "slow_tool_2", slowTool2)) + .blockingGet(); + long duration = System.currentTimeMillis() - startTime; + + // Sequential: duration should be >= 2 * sleepTime. + assertThat(duration).isAtLeast(2 * sleepTime); + + assertThat(result).isNotNull(); + assertThat(result.content().get().parts().get()).hasSize(2); + } }