diff --git a/src/model/providers/openai/stream.ts b/src/model/providers/openai/stream.ts index 1d0cb14f..b599a7b2 100644 --- a/src/model/providers/openai/stream.ts +++ b/src/model/providers/openai/stream.ts @@ -186,8 +186,9 @@ function toolCallEvents( if (typeof record.id === "string") { current.id = record.id; } - if (typeof fn.name === "string") { - current.name = fn.name; + const name = readToolCallName(record, fn); + if (name) { + current.name = name; } if (!state.toolCalls.has(index)) { @@ -201,12 +202,13 @@ function toolCallEvents( }); } - if (typeof fn.arguments === "string") { - current.argumentsBuffer = `${current.argumentsBuffer ?? ""}${fn.arguments}`; + const argumentsDelta = readToolCallArgumentsDelta(record, fn); + if (argumentsDelta !== undefined) { + current.argumentsBuffer = `${current.argumentsBuffer ?? ""}${argumentsDelta}`; events.push({ type: "tool_call_delta", id: current.id ?? generateStreamToolCallId(index), - delta: fn.arguments, + delta: argumentsDelta, raw, }); } @@ -221,6 +223,18 @@ function finishToolCalls(state: OpenAIStreamState, raw: unknown): CanonicalModel const events: CanonicalModelEvent[] = []; for (const [index, toolCall] of state.toolCalls.entries()) { + const name = readNonEmptyString(toolCall.name); + if (!name) { + throw new ModelProviderError({ + provider: "openai", + protocol: "openai", + code: "missing_tool_name", + message: "OpenAI stream emitted a tool call without a function name.", + retryable: true, + raw, + }); + } + const rawArguments = toolCall.argumentsBuffer ?? "{}"; let input: unknown; try { @@ -255,7 +269,7 @@ function finishToolCalls(state: OpenAIStreamState, raw: unknown): CanonicalModel type: "tool_call_end", toolCall: { id: readNonEmptyString(toolCall.id) ?? generateStreamToolCallId(index), - name: toolCall.name ?? "", + name, input, raw, }, @@ -277,6 +291,27 @@ function readNonEmptyString(value: unknown): string | undefined { return typeof value === "string" && value.trim().length > 0 ? value : undefined; } +function readToolCallName( + record: Record, + fn: Record, +): string | undefined { + return readNonEmptyString(fn.name) + ?? readNonEmptyString(record.name) + ?? readNonEmptyString(record.function_name) + ?? readNonEmptyString(record.tool_name); +} + +function readToolCallArgumentsDelta( + record: Record, + fn: Record, +): string | undefined { + if (typeof fn.arguments === "string") return fn.arguments; + if (typeof record.arguments === "string") return record.arguments; + if (typeof record.input === "string") return record.input; + if (record.input !== undefined) return JSON.stringify(record.input); + return undefined; +} + function generateStreamToolCallId(index: number): string { return `call_${index}`; } diff --git a/tests/model/providers/openai/stream.test.ts b/tests/model/providers/openai/stream.test.ts new file mode 100644 index 00000000..29977ef5 --- /dev/null +++ b/tests/model/providers/openai/stream.test.ts @@ -0,0 +1,143 @@ +import test from "node:test"; +import assert from "node:assert/strict"; + +import { + createOpenAIStreamState, + normalizeOpenAIStreamEvent, +} from "../../../../src/model/providers/openai/stream.js"; +import { ModelProviderError } from "../../../../src/model/protocol/errors.js"; + +test("normalizeOpenAIStreamEvent parses standard OpenAI streaming tool calls", () => { + const state = createOpenAIStreamState(); + + normalizeOpenAIStreamEvent({ + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_1", + type: "function", + function: { + name: "read_file", + arguments: "{\"file_path\":\"README.md\"", + }, + }, + ], + }, + }, + ], + }, state); + + const events = normalizeOpenAIStreamEvent({ + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + function: { + arguments: "}", + }, + }, + ], + }, + finish_reason: "tool_calls", + }, + ], + }, state); + + const toolCallEnd = events.find((event) => event.type === "tool_call_end"); + assert.ok(toolCallEnd); + if (toolCallEnd.type !== "tool_call_end") { + throw new Error(`Expected tool_call_end, got ${toolCallEnd.type}`); + } + assert.equal(toolCallEnd.toolCall.id, "call_1"); + assert.equal(toolCallEnd.toolCall.name, "read_file"); + assert.deepEqual(toolCallEnd.toolCall.input, { file_path: "README.md" }); +}); + +test("normalizeOpenAIStreamEvent accepts top-level tool call name and arguments variants", () => { + const state = createOpenAIStreamState(); + + normalizeOpenAIStreamEvent({ + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_2", + name: "read_file", + arguments: "{\"file_path\":\"README.md\"", + }, + ], + }, + }, + ], + }, state); + + const events = normalizeOpenAIStreamEvent({ + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + input: "}", + }, + ], + }, + finish_reason: "tool_calls", + }, + ], + }, state); + + const toolCallEnd = events.find((event) => event.type === "tool_call_end"); + assert.ok(toolCallEnd); + if (toolCallEnd.type !== "tool_call_end") { + throw new Error(`Expected tool_call_end, got ${toolCallEnd.type}`); + } + assert.equal(toolCallEnd.toolCall.id, "call_2"); + assert.equal(toolCallEnd.toolCall.name, "read_file"); + assert.deepEqual(toolCallEnd.toolCall.input, { file_path: "README.md" }); +}); + +test("normalizeOpenAIStreamEvent rejects tool calls without a function name", () => { + const state = createOpenAIStreamState(); + + normalizeOpenAIStreamEvent({ + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_3", + function: { + arguments: "{\"file_path\":\"README.md\"}", + }, + }, + ], + }, + }, + ], + }, state); + + assert.throws( + () => normalizeOpenAIStreamEvent({ + choices: [ + { + delta: {}, + finish_reason: "tool_calls", + }, + ], + }, state), + (error: unknown) => { + assert.ok(error instanceof ModelProviderError); + assert.equal(error.error.code, "missing_tool_name"); + return true; + }, + ); +});