Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.disposables.Disposable;
import io.reactivex.rxjava3.functions.Function;
import io.reactivex.rxjava3.schedulers.Schedulers;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -591,8 +592,9 @@ private static Maybe<Map<String, Object>> maybeInvokeAfterToolCall(

private static Maybe<Map<String, Object>> callTool(
BaseTool tool, Map<String, Object> args, ToolContext toolContext) {
return tool.runAsync(args, toolContext)
return Single.defer(() -> tool.runAsync(args, toolContext))
.toMaybe()
.subscribeOn(Schedulers.io())
.doOnError(t -> Span.current().recordException(t))
.onErrorResumeNext(
e ->
Expand Down
139 changes: 139 additions & 0 deletions core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,18 @@
import com.google.adk.agents.RunConfig.ToolExecutionMode;
import com.google.adk.events.Event;
import com.google.adk.testing.TestUtils;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.ToolContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Single;
import java.util.Map;
import java.util.Optional;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -389,4 +395,137 @@ public void getAskUserConfirmationFunctionCalls_eventWithConfirmationFunctionCal
ImmutableList<FunctionCall> result = Functions.getAskUserConfirmationFunctionCalls(event);
assertThat(result).containsExactly(confirmationCall1, confirmationCall2);
}

/**
* A tool that blocks for a specified duration, simulating a slow I/O operation. Uses
* Single.fromCallable to ensure the sleep is deferred until subscription time.
*/
private static class SlowTool extends BaseTool {
private final String toolName;
private final long sleepMillis;

SlowTool(String name, long sleepMillis) {
super(name, "A slow tool for testing parallel execution");
this.toolName = name;
this.sleepMillis = sleepMillis;
}

@Override
public Optional<FunctionDeclaration> declaration() {
return Optional.of(FunctionDeclaration.builder().name(toolName).build());
}

@Override
public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContext toolContext) {
return Single.fromCallable(
() -> {
Thread.sleep(sleepMillis);
return ImmutableMap.<String, Object>of("tool", toolName, "status", "done");
});
}
}

@Test
public void handleFunctionCalls_parallelMode_shouldExecuteConcurrently() {
long sleepTime = 1000;
SlowTool slowTool1 = new SlowTool("slow_tool_1", sleepTime);
SlowTool slowTool2 = new SlowTool("slow_tool_2", sleepTime);

InvocationContext invocationContext =
createInvocationContext(
createRootAgent(),
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());

Event event =
createEvent("event").toBuilder()
.content(
Content.fromParts(
Part.builder()
.functionCall(
FunctionCall.builder()
.id("call_1")
.name("slow_tool_1")
.args(ImmutableMap.of())
.build())
.build(),
Part.builder()
.functionCall(
FunctionCall.builder()
.id("call_2")
.name("slow_tool_2")
.args(ImmutableMap.of())
.build())
.build()))
.build();

long startTime = System.currentTimeMillis();
Event result =
Functions.handleFunctionCalls(
invocationContext,
event,
ImmutableMap.of("slow_tool_1", slowTool1, "slow_tool_2", slowTool2))
.blockingGet();
long duration = System.currentTimeMillis() - startTime;

// If parallel, duration should be ~1000ms, not ~2000ms.
assertThat(duration).isAtLeast(sleepTime);
assertThat(duration).isLessThan((long) (1.5 * sleepTime));

// Verify results are returned in correct order (concatMapEager preserves order).
assertThat(result).isNotNull();
assertThat(result.content().get().parts().get()).hasSize(2);
assertThat(result.content().get().parts().get().get(0).functionResponse().get().name())
.hasValue("slow_tool_1");
assertThat(result.content().get().parts().get().get(1).functionResponse().get().name())
.hasValue("slow_tool_2");
}

@Test
public void handleFunctionCalls_sequentialMode_shouldExecuteSerially() {
long sleepTime = 1000;
SlowTool slowTool1 = new SlowTool("slow_tool_1", sleepTime);
SlowTool slowTool2 = new SlowTool("slow_tool_2", sleepTime);

InvocationContext invocationContext =
createInvocationContext(
createRootAgent(),
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.SEQUENTIAL).build());

Event event =
createEvent("event").toBuilder()
.content(
Content.fromParts(
Part.builder()
.functionCall(
FunctionCall.builder()
.id("call_1")
.name("slow_tool_1")
.args(ImmutableMap.of())
.build())
.build(),
Part.builder()
.functionCall(
FunctionCall.builder()
.id("call_2")
.name("slow_tool_2")
.args(ImmutableMap.of())
.build())
.build()))
.build();

long startTime = System.currentTimeMillis();
Event result =
Functions.handleFunctionCalls(
invocationContext,
event,
ImmutableMap.of("slow_tool_1", slowTool1, "slow_tool_2", slowTool2))
.blockingGet();
long duration = System.currentTimeMillis() - startTime;

// Sequential: duration should be >= 2 * sleepTime.
assertThat(duration).isAtLeast(2 * sleepTime);

assertThat(result).isNotNull();
assertThat(result.content().get().parts().get()).hasSize(2);
}
}