Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fix-toolcallid-tracking.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

Use toolCallId instead of generateId for parallel tool execution tracking to prevent premature stream closure
335 changes: 334 additions & 1 deletion packages/ai/src/generate-text/run-tools-transformation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,23 @@ import { describe, expect, it } from 'vitest';
import { z } from 'zod/v4';
import { NoSuchToolError } from '../error/no-such-tool-error';
import { MockTracer } from '../test/mock-tracer';
import { runToolsTransformation } from './run-tools-transformation';
import {
runToolsTransformation,
SingleRequestTextStreamPart,
} from './run-tools-transformation';
import { ToolSet } from './tool-set';

function isToolResult<T extends ToolSet>(
part: SingleRequestTextStreamPart<T>,
): part is SingleRequestTextStreamPart<T> & { type: 'tool-result' } {
return part.type === 'tool-result';
}

function isToolCall<T extends ToolSet>(
part: SingleRequestTextStreamPart<T>,
): part is SingleRequestTextStreamPart<T> & { type: 'tool-call' } {
return part.type === 'tool-call';
}

const testUsage: LanguageModelV3Usage = {
inputTokens: {
Expand Down Expand Up @@ -1140,4 +1156,321 @@ describe('runToolsTransformation', () => {
});
});
});

describe('parallel tool execution', () => {
it('should use toolCallId for tracking (not generateId) to handle parallel tools correctly', async () => {
// Frameworks can override _internal.generateId for message grouping, returning
// a constant pendingMessageId for all calls within a request. Tool execution
// tracking must use toolCallId (unique per LLM tool call) instead.
const pendingMessageId = 'msg-abc123';
const frameworkGenerateId = () => pendingMessageId;

const inputStream: ReadableStream<LanguageModelV3StreamPart> =
convertArrayToReadableStream([
{
type: 'tool-call',
toolCallId: 'unique-call-1',
toolName: 'toolA',
input: `{ "value": "a" }`,
},
{
type: 'tool-call',
toolCallId: 'unique-call-2',
toolName: 'toolB',
input: `{ "value": "b" }`,
},
{
type: 'tool-call',
toolCallId: 'unique-call-3',
toolName: 'toolC',
input: `{ "value": "c" }`,
},
{
type: 'finish',
finishReason: { unified: 'tool-calls', raw: 'tool_calls' },
usage: testUsage,
},
]);

const transformedStream = runToolsTransformation({
generateId: frameworkGenerateId,
tools: {
toolA: {
title: 'Tool A',
inputSchema: z.object({ value: z.string() }),
execute: async ({ value }) => {
await delay(30);
return `${value}-result`;
},
},
toolB: {
title: 'Tool B',
inputSchema: z.object({ value: z.string() }),
execute: async ({ value }) => {
await delay(10);
return `${value}-result`;
},
},
toolC: {
title: 'Tool C',
inputSchema: z.object({ value: z.string() }),
execute: async ({ value }) => {
await delay(20);
return `${value}-result`;
},
},
},
generatorStream: inputStream,
tracer: new MockTracer(),
telemetry: undefined,
messages: [],
system: undefined,
abortSignal: undefined,
repairToolCall: undefined,
experimental_context: undefined,
});

const result = await convertReadableStreamToArray(transformedStream);

// All three tool results should be captured
// (Bug: without the fix, only 1 result would be captured because
// outstandingToolResults Set would use the same ID for all tools)
const toolResults = result.filter(isToolResult);
expect(toolResults).toHaveLength(3);
expect(toolResults.map(r => r.toolCallId).sort()).toEqual([
'unique-call-1',
'unique-call-2',
'unique-call-3',
]);

// Finish should be last
expect(result[result.length - 1]).toMatchObject({
type: 'finish',
});
});

it('should capture all results when multiple tools execute in parallel with different delays', async () => {
const inputStream: ReadableStream<LanguageModelV3StreamPart> =
convertArrayToReadableStream([
{
type: 'tool-call',
toolCallId: 'call-1',
toolName: 'slowTool',
input: `{ "value": "slow" }`,
},
{
type: 'tool-call',
toolCallId: 'call-2',
toolName: 'fastTool',
input: `{ "value": "fast" }`,
},
{
type: 'tool-call',
toolCallId: 'call-3',
toolName: 'mediumTool',
input: `{ "value": "medium" }`,
},
{
type: 'finish',
finishReason: { unified: 'tool-calls', raw: 'tool_calls' },
usage: testUsage,
},
]);

const transformedStream = runToolsTransformation({
generateId: mockId({ prefix: 'id' }),
tools: {
slowTool: {
title: 'Slow Tool',
inputSchema: z.object({ value: z.string() }),
execute: async ({ value }) => {
await delay(50); // Slowest
return `${value}-result`;
},
},
fastTool: {
title: 'Fast Tool',
inputSchema: z.object({ value: z.string() }),
execute: async ({ value }) => {
await delay(10); // Fastest
return `${value}-result`;
},
},
mediumTool: {
title: 'Medium Tool',
inputSchema: z.object({ value: z.string() }),
execute: async ({ value }) => {
await delay(30); // Medium
return `${value}-result`;
},
},
},
generatorStream: inputStream,
tracer: new MockTracer(),
telemetry: undefined,
messages: [],
system: undefined,
abortSignal: undefined,
repairToolCall: undefined,
experimental_context: undefined,
});

const result = await convertReadableStreamToArray(transformedStream);

// All three tool calls should be present
const toolCalls = result.filter(isToolCall);
expect(toolCalls).toHaveLength(3);

// All three tool results should be present
const toolResults = result.filter(isToolResult);
expect(toolResults).toHaveLength(3);
expect(toolResults.map(r => r.toolCallId).sort()).toEqual([
'call-1',
'call-2',
'call-3',
]);

// Finish should be last
expect(result[result.length - 1]).toMatchObject({
type: 'finish',
});
});

it('should not close stream prematurely when fast tool completes before slow tool', async () => {
const executionOrder: string[] = [];

const inputStream: ReadableStream<LanguageModelV3StreamPart> =
convertArrayToReadableStream([
{
type: 'tool-call',
toolCallId: 'slow-call',
toolName: 'slowTool',
input: `{ "value": "slow" }`,
},
{
type: 'tool-call',
toolCallId: 'fast-call',
toolName: 'fastTool',
input: `{ "value": "fast" }`,
},
{
type: 'finish',
finishReason: { unified: 'tool-calls', raw: 'tool_calls' },
usage: testUsage,
},
]);

const transformedStream = runToolsTransformation({
generateId: mockId({ prefix: 'id' }),
tools: {
slowTool: {
title: 'Slow Tool',
inputSchema: z.object({ value: z.string() }),
execute: async ({ value }) => {
await delay(50);
executionOrder.push('slow-completed');
return `${value}-slow-result`;
},
},
fastTool: {
title: 'Fast Tool',
inputSchema: z.object({ value: z.string() }),
execute: async ({ value }) => {
await delay(5);
executionOrder.push('fast-completed');
return `${value}-fast-result`;
},
},
},
generatorStream: inputStream,
tracer: new MockTracer(),
telemetry: undefined,
messages: [],
system: undefined,
abortSignal: undefined,
repairToolCall: undefined,
experimental_context: undefined,
});

const result = await convertReadableStreamToArray(transformedStream);

// Fast tool should complete first
expect(executionOrder).toEqual(['fast-completed', 'slow-completed']);

// Both results should be captured
const toolResults = result.filter(isToolResult);
expect(toolResults).toHaveLength(2);
expect(toolResults.map(r => r.output).sort()).toEqual([
'fast-fast-result',
'slow-slow-result',
]);

// Stream should close properly after all tools complete
expect(result[result.length - 1]).toMatchObject({
type: 'finish',
});
});

it('should handle many parallel tool calls without losing results', async () => {
const toolCount = 10;
const toolCalls = Array.from({ length: toolCount }, (_, i) => ({
type: 'tool-call' as const,
toolCallId: `call-${i}`,
toolName: 'parallelTool',
input: `{ "index": ${i} }`,
}));

const inputStream: ReadableStream<LanguageModelV3StreamPart> =
convertArrayToReadableStream([
...toolCalls,
{
type: 'finish',
finishReason: { unified: 'tool-calls', raw: 'tool_calls' },
usage: testUsage,
},
]);

const transformedStream = runToolsTransformation({
generateId: mockId({ prefix: 'id' }),
tools: {
parallelTool: {
title: 'Parallel Tool',
inputSchema: z.object({ index: z.number() }),
execute: async ({ index }) => {
// Random delay to simulate real-world variance
await delay(Math.random() * 20);
return `result-${index}`;
},
},
},
generatorStream: inputStream,
tracer: new MockTracer(),
telemetry: undefined,
messages: [],
system: undefined,
abortSignal: undefined,
repairToolCall: undefined,
experimental_context: undefined,
});

const result = await convertReadableStreamToArray(transformedStream);

// All tool results should be captured
const toolResults = result.filter(isToolResult);
expect(toolResults).toHaveLength(toolCount);

// Verify all results are present (order may vary)
const resultOutputs = toolResults.map(r => r.output).sort();
const expectedOutputs = Array.from(
{ length: toolCount },
(_, i) => `result-${i}`,
).sort();
expect(resultOutputs).toEqual(expectedOutputs);

// Finish should be last
expect(result[result.length - 1]).toMatchObject({
type: 'finish',
});
});
});
});
6 changes: 5 additions & 1 deletion packages/ai/src/generate-text/run-tools-transformation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,11 @@ export function runToolsTransformation<TOOLS extends ToolSet>({

// Only execute tools that are not provider-executed:
if (tool.execute != null && toolCall.providerExecuted !== true) {
const toolExecutionId = generateId(); // use our own id to guarantee uniqueness
// Use toolCallId for tracking - it's unique per tool call from the LLM.
// Don't use generateId() here because frameworks can override it for
// message grouping (returning the same ID for all tools in a request),
// which would cause the Set to track only one tool instead of all.
const toolExecutionId = toolCall.toolCallId;
outstandingToolResults.add(toolExecutionId);

// Note: we don't await the tool execution here (by leaving out 'await' on recordSpan),
Expand Down