From ea6e397909d531e83a361b88fd3f5a13b116e309 Mon Sep 17 00:00:00 2001 From: YuqiGuo105 Date: Wed, 8 Apr 2026 21:49:24 -0600 Subject: [PATCH] fix: parallel tool execution blocked by synchronous FunctionTool.runAsync Fixes #735 When ToolExecutionMode.PARALLEL is set and the LLM returns multiple function calls, tools were still executing sequentially because callTool() invoked tool.runAsync() as a plain method call, causing FunctionTool to execute func.invoke() synchronously on the subscribing thread before returning a Single. Since concatMapEager eagerly subscribes to all inner Observables, it could not dispatch work to IO threads if the subscription itself was blocking. Fix: wrap tool.runAsync() in Single.defer() so the call is deferred until subscription time, then apply subscribeOn(Schedulers.io()) to move the actual subscription (and thus the synchronous func.invoke() call) onto an IO thread. concatMapEager then subscribes to all tool Singles eagerly, each on its own IO thread, achieving true parallelism while preserving result order. Added two new timing-based tests that verify: - PARALLEL mode: two 1000ms tools complete in <1500ms (not ~2000ms) - SEQUENTIAL mode: two 1000ms tools take >=2000ms (serial contract unchanged) --- .../google/adk/flows/llmflows/Functions.java | 4 +- .../adk/flows/llmflows/FunctionsTest.java | 139 ++++++++++++++++++ 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 0b0e5b4d5..a30b909f3 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -48,6 +48,7 @@ import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.disposables.Disposable; import io.reactivex.rxjava3.functions.Function; +import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -591,8 +592,9 @@ private static Maybe> maybeInvokeAfterToolCall( private static Maybe> callTool( BaseTool tool, Map args, ToolContext toolContext) { - return tool.runAsync(args, toolContext) + return Single.defer(() -> tool.runAsync(args, toolContext)) .toMaybe() + .subscribeOn(Schedulers.io()) .doOnError(t -> Span.current().recordException(t)) .onErrorResumeNext( e -> diff --git a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java index d5db4d4b3..3438b7fee 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java @@ -27,12 +27,18 @@ import com.google.adk.agents.RunConfig.ToolExecutionMode; import com.google.adk.events.Event; import com.google.adk.testing.TestUtils; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -389,4 +395,137 @@ public void getAskUserConfirmationFunctionCalls_eventWithConfirmationFunctionCal ImmutableList result = Functions.getAskUserConfirmationFunctionCalls(event); assertThat(result).containsExactly(confirmationCall1, confirmationCall2); } + + /** + * A tool that blocks for a specified duration, simulating a slow I/O operation. Uses + * Single.fromCallable to ensure the sleep is deferred until subscription time. + */ + private static class SlowTool extends BaseTool { + private final String toolName; + private final long sleepMillis; + + SlowTool(String name, long sleepMillis) { + super(name, "A slow tool for testing parallel execution"); + this.toolName = name; + this.sleepMillis = sleepMillis; + } + + @Override + public Optional declaration() { + return Optional.of(FunctionDeclaration.builder().name(toolName).build()); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + return Single.fromCallable( + () -> { + Thread.sleep(sleepMillis); + return ImmutableMap.of("tool", toolName, "status", "done"); + }); + } + } + + @Test + public void handleFunctionCalls_parallelMode_shouldExecuteConcurrently() { + long sleepTime = 1000; + SlowTool slowTool1 = new SlowTool("slow_tool_1", sleepTime); + SlowTool slowTool2 = new SlowTool("slow_tool_2", sleepTime); + + InvocationContext invocationContext = + createInvocationContext( + createRootAgent(), + RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build()); + + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_1") + .name("slow_tool_1") + .args(ImmutableMap.of()) + .build()) + .build(), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_2") + .name("slow_tool_2") + .args(ImmutableMap.of()) + .build()) + .build())) + .build(); + + long startTime = System.currentTimeMillis(); + Event result = + Functions.handleFunctionCalls( + invocationContext, + event, + ImmutableMap.of("slow_tool_1", slowTool1, "slow_tool_2", slowTool2)) + .blockingGet(); + long duration = System.currentTimeMillis() - startTime; + + // If parallel, duration should be ~1000ms, not ~2000ms. + assertThat(duration).isAtLeast(sleepTime); + assertThat(duration).isLessThan((long) (1.5 * sleepTime)); + + // Verify results are returned in correct order (concatMapEager preserves order). + assertThat(result).isNotNull(); + assertThat(result.content().get().parts().get()).hasSize(2); + assertThat(result.content().get().parts().get().get(0).functionResponse().get().name()) + .hasValue("slow_tool_1"); + assertThat(result.content().get().parts().get().get(1).functionResponse().get().name()) + .hasValue("slow_tool_2"); + } + + @Test + public void handleFunctionCalls_sequentialMode_shouldExecuteSerially() { + long sleepTime = 1000; + SlowTool slowTool1 = new SlowTool("slow_tool_1", sleepTime); + SlowTool slowTool2 = new SlowTool("slow_tool_2", sleepTime); + + InvocationContext invocationContext = + createInvocationContext( + createRootAgent(), + RunConfig.builder().setToolExecutionMode(ToolExecutionMode.SEQUENTIAL).build()); + + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_1") + .name("slow_tool_1") + .args(ImmutableMap.of()) + .build()) + .build(), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_2") + .name("slow_tool_2") + .args(ImmutableMap.of()) + .build()) + .build())) + .build(); + + long startTime = System.currentTimeMillis(); + Event result = + Functions.handleFunctionCalls( + invocationContext, + event, + ImmutableMap.of("slow_tool_1", slowTool1, "slow_tool_2", slowTool2)) + .blockingGet(); + long duration = System.currentTimeMillis() - startTime; + + // Sequential: duration should be >= 2 * sleepTime. + assertThat(duration).isAtLeast(2 * sleepTime); + + assertThat(result).isNotNull(); + assertThat(result.content().get().parts().get()).hasSize(2); + } }