diff --git a/.changeset/list-syntax-toolcontext.md b/.changeset/list-syntax-toolcontext.md index 0a4e0bfa0..f50fa8216 100644 --- a/.changeset/list-syntax-toolcontext.md +++ b/.changeset/list-syntax-toolcontext.md @@ -1,5 +1,5 @@ --- -"@livekit/agents": minor +'@livekit/agents': minor --- -**BREAKING**: `Agent({ tools })` and `agent.updateTools()` now accept a flat list `(FunctionTool | ProviderDefinedTool)[]` instead of a `Record` map, and `llm.tool({ ... })` requires a `name` field. `ToolContext` is now a Python-parity class with `functionTools` / `providerTools` / `toolsets` accessors, plus `flatten()`, `hasTool(name)`, `getFunctionTool(name)`, `updateTools()`, `copy()`, and `equals()`. To match the Python reference, registering two **different** function-tool instances under the same `name` now throws `duplicate function name: ` instead of silently overriding the earlier entry; passing the **same instance** twice is a no-op. `agent.toolCtx` returns a defensive copy so callers can no longer mutate the agent's internal state. `LLM.chat({ toolCtx })` accepts either a `ToolContext` instance or a raw `(FunctionTool | ProviderDefinedTool)[]` array (`ToolCtxInput`) and normalizes it internally, so callers don't have to construct a `ToolContext` themselves. Stateful `Toolset` containers are not part of this release — the `toolsets` accessor currently returns an empty list and `TODO`s in `tool_context.ts` mark every site where Python's Toolset support will plug in later. +**BREAKING**: `Agent({ tools })` and `agent.updateTools()` now accept a flat list `(FunctionTool | ProviderDefinedTool | Toolset)[]` instead of a `Record` map, and `llm.tool({ ... })` requires a `name` field. `ToolContext` is now a Python-parity class with `functionTools` / `providerTools` / `toolsets` accessors, plus `flatten()`, `hasTool(name)`, `getFunctionTool(name)`, `updateTools()`, `copy()`, and `equals()`. 
To match the Python reference, registering two **different** function-tool instances under the same `name` now throws `duplicate function name: <name>` instead of silently overriding the earlier entry; passing the **same instance** twice is a no-op. `agent.toolCtx` returns a defensive copy so callers can no longer mutate the agent's internal state. `LLM.chat({ toolCtx })` accepts either a `ToolContext` instance or a raw `(FunctionTool | ProviderDefinedTool | Toolset)[]` array (`ToolCtxInput`) and normalizes it internally, so callers don't have to construct a `ToolContext` themselves. diff --git a/.changeset/quick-meals-breathe.md b/.changeset/quick-meals-breathe.md new file mode 100644 index 000000000..d32233b93 --- /dev/null +++ b/.changeset/quick-meals-breathe.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +Adds base `Toolset` support: a stateful container for a group of tools with `setup()` / `aclose()` lifecycle hooks. Toolsets can be passed directly into `Agent({ tools: [...] })` alongside individual function tools; their tools are flattened into the agent's `ToolContext`. The runtime calls `setup()` on activity start and `aclose()` on close; when `agent.updateTools()` adds or removes Toolsets mid-session, only the added Toolsets are set up and only the removed ones are closed. Per-toolset `setup()` and `aclose()` errors are logged but do not abort the activity. The `IGNORE_ON_ENTER` flag is also respected for function tools nested inside a Toolset. Every LLM and realtime plugin tool builder iterates `ToolContext.flatten()` so toolset-contributed tools are correctly advertised. Also exports `ToolCalledEvent` / `ToolCompletedEvent` payload types. 
diff --git a/agents/src/llm/index.ts b/agents/src/llm/index.ts index 4837f2cb9..bbb77a2fa 100644 --- a/agents/src/llm/index.ts +++ b/agents/src/llm/index.ts @@ -10,12 +10,15 @@ export { ToolContext, ToolError, ToolFlag, + Toolset, toToolContext, type AgentHandoff, type FunctionTool, type ProviderDefinedTool, type Tool, + type ToolCalledEvent, type ToolChoice, + type ToolCompletedEvent, type ToolContextEntry, type ToolCtxInput, type ToolOptions, diff --git a/agents/src/llm/tool_context.test.ts b/agents/src/llm/tool_context.test.ts index 4d9cdb83d..7a38b8b5c 100644 --- a/agents/src/llm/tool_context.test.ts +++ b/agents/src/llm/tool_context.test.ts @@ -5,7 +5,7 @@ import { describe, expect, it } from 'vitest'; import { z } from 'zod'; import * as z3 from 'zod/v3'; import * as z4 from 'zod/v4'; -import { ToolContext, type ToolOptions, tool } from './tool_context.js'; +import { ToolContext, type ToolOptions, Toolset, tool } from './tool_context.js'; import { createToolOptions, oaiParams } from './utils.js'; describe('Tool Context', () => { @@ -580,3 +580,78 @@ describe('ToolContext', () => { expect(ctx.flatten()).toEqual([b, a, provider]); }); }); + +describe('Toolset', () => { + const makeFn = (name: string) => + tool({ + name, + description: `${name} tool`, + execute: async () => name, + }); + + it('exposes its id and the tools it was constructed with', () => { + const a = makeFn('a'); + const b = makeFn('b'); + const ts = new Toolset({ id: 'set1', tools: [a, b] }); + + expect(ts.id).toBe('set1'); + expect(ts.tools).toEqual([a, b]); + }); + + it('default setup and aclose are no-ops', async () => { + const ts = new Toolset({ id: 'noop', tools: [] }); + await expect(ts.setup()).resolves.toBeUndefined(); + await expect(ts.aclose()).resolves.toBeUndefined(); + }); + + it('lets subclasses override lifecycle hooks', async () => { + const events: string[] = []; + class Recording extends Toolset { + override async setup(): Promise { + events.push(`setup:${this.id}`); + } + 
override async aclose(): Promise { + events.push(`close:${this.id}`); + } + } + + const ts = new Recording({ id: 'rec', tools: [] }); + await ts.setup(); + await ts.aclose(); + expect(events).toEqual(['setup:rec', 'close:rec']); + }); + + it('is flattened into a ToolContext: function tools merged, toolset tracked', () => { + const a = makeFn('a'); + const b = makeFn('b'); + const ts = new Toolset({ id: 'set', tools: [a, b] }); + const direct = makeFn('direct'); + + const ctx = new ToolContext([direct, ts]); + + expect(Object.keys(ctx.functionTools).sort()).toEqual(['a', 'b', 'direct']); + expect(ctx.toolsets).toEqual([ts]); + }); + + it('throws when a Toolset contributes a duplicate function name', () => { + // Mirrors Python's `add_tool`: a name collision between top-level and toolset-contributed + // tools is an error, not silent overwrite. + const a1 = makeFn('a'); + const a2 = makeFn('a'); + const ts = new Toolset({ id: 'collides', tools: [a2] }); + + expect(() => new ToolContext([a1, ts])).toThrow(/duplicate function name: a/); + }); + + it('equals() compares toolsets as identity sets, not by order', () => { + // Matches Python's `{id(ts) for ts in self._tool_sets}` semantics. 
+ const ts1 = new Toolset({ id: 'one', tools: [] }); + const ts2 = new Toolset({ id: 'two', tools: [] }); + + expect(new ToolContext([ts1, ts2]).equals(new ToolContext([ts2, ts1]))).toBe(true); + + const ts3 = new Toolset({ id: 'three', tools: [] }); + expect(new ToolContext([ts1, ts2]).equals(new ToolContext([ts1, ts3]))).toBe(false); + expect(new ToolContext([ts1]).equals(new ToolContext([ts1, ts2]))).toBe(false); + }); +}); diff --git a/agents/src/llm/tool_context.ts b/agents/src/llm/tool_context.ts index df714d57a..f2153a487 100644 --- a/agents/src/llm/tool_context.ts +++ b/agents/src/llm/tool_context.ts @@ -196,6 +196,43 @@ export interface FunctionTool< [FUNCTION_TOOL_SYMBOL]: true; } +export interface ToolCalledEvent { + ctx: RunContext; + arguments: Record; +} + +export interface ToolCompletedEvent { + ctx: RunContext; + output?: { type: 'output'; value: unknown } | { type: 'error'; value: Error }; +} + +/** + * A stateful collection of tools sharing a lifecycle. Tools registered through a `Toolset` are + * flattened into the surrounding `ToolContext`, while the `Toolset` itself is tracked so its + * `setup()` / `aclose()` hooks can be invoked by the agent runtime. + */ +export class Toolset { + readonly #id: string; + readonly #tools: Tool[]; + + constructor({ id, tools }: { id: string; tools: readonly Tool[] }) { + this.#id = id; + this.#tools = [...tools]; + } + + get id(): string { + return this.#id; + } + + get tools(): readonly Tool[] { + return this.#tools; + } + + async setup(): Promise {} + + async aclose(): Promise {} +} + /** * Convenience input shape accepted by APIs that want to take a list of tools directly without * forcing callers to wrap them in `new ToolContext(...)`. @@ -217,24 +254,18 @@ export function toToolContext( return input instanceof ToolContext ? 
input : new ToolContext(input); } -//TODO: toolset - accept stateful `Toolset` containers alongside `FunctionTool` / // eslint-disable-next-line @typescript-eslint/no-explicit-any -- ToolContext entries accept any function-tool parameter/result types export type ToolContextEntry = // eslint-disable-next-line @typescript-eslint/no-explicit-any - FunctionTool | ProviderDefinedTool; + FunctionTool | ProviderDefinedTool | Toolset; export class ToolContext { - // TODO: toolset - widen entries to `FunctionTool | ProviderDefinedTool | Toolset` once Toolset - // lands so this stays heterogeneous like Python's `Sequence[Tool | Toolset]`. private _tools: ToolContextEntry[] = []; // eslint-disable-next-line @typescript-eslint/no-explicit-any -- ToolContext stores generic function tools private _functionToolsMap: Map> = new Map(); private _providerTools: ProviderDefinedTool[] = []; - // TODO: toolset - populate when Toolset support is supported. - // so the `toolsets` getter and `equals` toolset-identity check stay byte-compatible with the - private _toolSets: unknown[] = []; + private _toolsets: Toolset[] = []; - // TODO: toolset - widen `tools` to `Sequence` once Toolset lands. constructor(tools: readonly ToolContextEntry[] = []) { this.updateTools(tools); } @@ -254,13 +285,9 @@ export class ToolContext { return this._providerTools; } - /** - * A copy of all tool sets in the tool context. - * - * TODO: toolset - wire up once Toolset is ported. - */ - get toolsets(): unknown[] { - return this._toolSets; + /** A copy of all toolsets registered in the context. */ + get toolsets(): readonly Toolset[] { + return [...this._toolsets]; } /** @@ -287,16 +314,22 @@ export class ToolContext { return this._providerTools.some((tool) => tool.id === name); } - // TODO: toolset - widen `tools` to `Sequence` once Toolset lands. 
updateTools(tools: readonly ToolContextEntry[]): void { this._tools = [...tools]; this._functionToolsMap = new Map(); this._providerTools = []; - this._toolSets = []; + this._toolsets = []; - // Mirrors Python's recursive `add_tool` (minus Toolset flattening, which is TODO). // eslint-disable-next-line @typescript-eslint/no-explicit-any -- accepts any tool shape const addTool = (tool: any): void => { + if (tool instanceof Toolset) { + for (const inner of tool.tools) { + addTool(inner); + } + this._toolsets.push(tool); + return; + } + if (isProviderDefinedTool(tool)) { this._providerTools.push(tool); return; @@ -314,15 +347,9 @@ export class ToolContext { return; } - // TODO: toolset - if (tool instanceof Toolset) { for (const t of tool.tools) addTool(t); - // this._toolSets.push(tool); return; } - throw new Error(`unknown tool type: ${typeof tool}`); }; - // TODO: toolset - Python also chains `find_function_tools(self)` here so subclasses can - // declare tools as class members. JS doesn't use that decorator pattern, so we only walk - // the explicit input list. 
for (const tool of tools) { addTool(tool); } @@ -352,10 +379,15 @@ export class ToolContext { return false; } } - // TODO: toolset - once Toolset lands, also compare `_toolSets` as identity sets per Python - // self_tool_set_ids = {id(ts) for ts in self._tool_sets} - // other_tool_set_ids = {id(ts) for ts in other._tool_sets} - // if self_tool_set_ids != other_tool_set_ids: return False + if (this._toolsets.length !== other._toolsets.length) { + return false; + } + const otherToolsets = new Set(other._toolsets); + for (const ts of this._toolsets) { + if (!otherToolsets.has(ts)) { + return false; + } + } return true; } } diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index 3068b408d..012c026ee 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -31,14 +31,17 @@ import { type InputSpeechStartedEvent, type InputSpeechStoppedEvent, type InputTranscriptionCompleted, + isFunctionTool, LLM, RealtimeModel, type RealtimeModelError, type RealtimeSession, + type Tool, type ToolChoice, ToolContext, type ToolContextEntry, ToolFlag, + Toolset, } from '../llm/index.js'; import type { LLMError } from '../llm/llm.js'; import { isSameToolChoice } from '../llm/tool_context.js'; @@ -215,6 +218,7 @@ export class AgentActivity implements RecognitionHooks { private toolChoice: ToolChoice | null = null; private _preemptiveGeneration?: PreemptiveGeneration; private _preemptiveGenerationCount = 0; + private _toolsetsSetup = false; private interruptionDetector?: AdaptiveInterruptionDetector; private isInterruptionDetectionEnabled: boolean; private isInterruptionByAudioActivityEnabled: boolean; @@ -421,6 +425,8 @@ export class AgentActivity implements RecognitionHooks { this.agent._agentActivity = this; + await this.setupToolsets(); + if (this.llm instanceof RealtimeModel) { const rtReused = reuseResources?.rtSession !== undefined; @@ -767,13 +773,20 @@ export class AgentActivity implements RecognitionHooks { } 
async updateTools(tools: readonly ToolContextEntry[]): Promise { - const oldToolNames = new Set(Object.keys(this.agent._toolCtx.functionTools)); + const oldToolCtx = this.agent._toolCtx; + const oldToolNames = new Set(Object.keys(oldToolCtx.functionTools)); + const oldToolsets = oldToolCtx.toolsets; const newToolCtx = new ToolContext(tools); const newToolNames = new Set(Object.keys(newToolCtx.functionTools)); + const newToolsets = newToolCtx.toolsets; const toolsAdded = [...newToolNames].filter((name) => !oldToolNames.has(name)); const toolsRemoved = [...oldToolNames].filter((name) => !newToolNames.has(name)); + const addedToolsets = newToolsets.filter((ts) => !oldToolsets.includes(ts)); + const removedToolsets = oldToolsets.filter((ts) => !newToolsets.includes(ts)); + await this.setupToolsetList(addedToolsets); this.agent._toolCtx = newToolCtx; + await this.closeToolsetList(removedToolsets); if (toolsAdded.length > 0 || toolsRemoved.length > 0) { const configUpdate = new AgentConfigUpdate({ @@ -1735,11 +1748,13 @@ export class AgentActivity implements RecognitionHooks { const tools: ToolContext = shouldFilterTools ? new ToolContext( - this.agent.toolCtx.tools.filter((t) => { - if (t.type === 'function') { - return !(t.flags & ToolFlag.IGNORE_ON_ENTER); + this.agent.toolCtx.tools.flatMap((t): ToolContextEntry[] => { + const keepFn = (fn: Tool): boolean => + !isFunctionTool(fn) || !(fn.flags & ToolFlag.IGNORE_ON_ENTER); + if (t instanceof Toolset) { + return t.tools.filter(keepFn) as ToolContextEntry[]; } - return true; + return keepFn(t) ? 
[t] : []; }), ) : this.agent.toolCtx; @@ -3728,9 +3743,42 @@ export class AgentActivity implements RecognitionHooks { this.realtimeSpans?.clear(); await this.realtimeSession?.close(); await this.audioRecognition?.close(); + await this.closeToolsets(); this.realtimeSession = undefined; this.audioRecognition = undefined; } + + private async setupToolsets(): Promise { + // Guard against resume() re-entering _startSession on an activity whose toolsets are + // already initialized. + if (this._toolsetsSetup) return; + this._toolsetsSetup = true; + await this.setupToolsetList(this.agent.toolCtx.toolsets); + } + + private async closeToolsets(): Promise { + if (!this._toolsetsSetup) return; + this._toolsetsSetup = false; + await this.closeToolsetList(this.agent.toolCtx.toolsets); + } + + private async setupToolsetList(toolsets: readonly Toolset[]): Promise { + const outputs = await Promise.allSettled(toolsets.map((ts) => ts.setup())); + for (const output of outputs) { + if (output.status === 'rejected') { + this.logger.error({ error: output.reason }, 'error setting up toolset'); + } + } + } + + private async closeToolsetList(toolsets: readonly Toolset[]): Promise { + const outputs = await Promise.allSettled(toolsets.map((ts) => ts.aclose())); + for (const output of outputs) { + if (output.status === 'rejected') { + this.logger.error({ error: output.reason }, 'error closing toolset'); + } + } + } } function toOaiToolChoice(toolChoice: ToolChoice | null): ToolChoice | undefined { diff --git a/examples/src/basic_toolsets.ts b/examples/src/basic_toolsets.ts new file mode 100644 index 000000000..cb0cf2b6e --- /dev/null +++ b/examples/src/basic_toolsets.ts @@ -0,0 +1,180 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { + type JobContext, + type JobProcess, + ServerOptions, + cli, + defineAgent, + inference, + llm, + voice, +} from '@livekit/agents'; +import * as livekit from '@livekit/agents-plugin-livekit'; +import * as silero from '@livekit/agents-plugin-silero'; +import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; +import { fileURLToPath } from 'node:url'; +import { z } from 'zod'; + +class InfoTask extends voice.AgentTask { + private key: string; + + constructor(key: string, sharedToolset: llm.Toolset) { + super({ + instructions: `Collect the user's ${key}. Once you have it, call saveUserInfo IMMEDIATELY. No chitchat.`, + tools: [ + sharedToolset, + llm.tool({ + name: 'saveUserInfo', + description: `Save the user's ${key} to the database`, + parameters: z.object({ + [key]: z.string(), + }), + execute: async (args) => { + this.complete(args[key] as string); + return `Thanks, collected ${key} successfully: ${args[key]}`; + }, + }), + ], + }); + this.key = key; + } + + async onEnter() { + this.session.generateReply({ userInput: `Ask the user for their ${this.key}` }); + } +} + +function makeWeatherAgent(returnHome: () => voice.Agent) { + const weatherToolset = new llm.Toolset({ + id: 'weather_tools', + tools: [ + llm.tool({ + name: 'getWeather', + description: 'Get the weather for a given location', + parameters: z.object({ location: z.string() }), + execute: async ({ location }) => `The weather in ${location} is sunny today.`, + }), + ], + }); + + return new voice.Agent({ + instructions: 'You are a weather agent. Provide weather information then hand back when done.', + tools: [ + weatherToolset, + llm.tool({ + name: 'finishWeatherConversation', + description: 'Call this when you want to finish the weather conversation', + execute: async () => { + return llm.handoff({ agent: returnHome(), returns: 'Transfer back to main agent.' 
}); + }, + }), + ], + }); +} + +class MainAgent extends voice.Agent { + private locationToolset: llm.Toolset; + + constructor(locationToolset: llm.Toolset) { + super({ + instructions: + 'You are a helpful assistant. Use the location toolset for weather/timezone. Use transferToWeather when the user asks about weather. Use swapToolset / reapplyTools to exercise updateTools.', + tools: [ + locationToolset, + llm.tool({ + name: 'transferToWeather', + description: 'Call this when the user wants to know the weather', + execute: async () => { + return llm.handoff({ + agent: makeWeatherAgent(() => new MainAgent(locationToolset)), + returns: "Let's switch to the weather agent.", + }); + }, + }), + llm.tool({ + name: 'swapToolset', + description: 'Replace the active toolset with a brand-new toolset (tests updateTools).', + execute: async () => { + const replacement = new llm.Toolset({ + id: 'location_tools_v2', + tools: [ + llm.tool({ + name: 'getWeather', + description: 'v2 weather', + parameters: z.object({ location: z.string() }), + execute: async ({ location }) => `v2: ${location} -> sunny`, + }), + ], + }); + await this.updateTools([replacement]); + return 'Swapped toolset.'; + }, + }), + llm.tool({ + name: 'reapplyTools', + description: 'Re-apply the current tool list unchanged (idempotent updateTools).', + execute: async () => { + await this.updateTools([...this.toolCtx.tools]); + return 'Re-applied the same tool list.'; + }, + }), + ], + }); + this.locationToolset = locationToolset; + } + + async onEnter() { + const name = await new InfoTask('name', this.locationToolset).run(); + await this.session.say( + `Got it, ${name}. 
Ask me about weather, or say "swap" / "reapply" to exercise updateTools.`, + ); + } +} + +export default defineAgent({ + prewarm: async (proc: JobProcess) => { + proc.userData.vad = await silero.VAD.load(); + }, + entry: async (ctx: JobContext) => { + const locationToolset = new llm.Toolset({ + id: 'location_tools', + tools: [ + llm.tool({ + name: 'getWeather', + description: 'Get the weather for a given location.', + parameters: z.object({ location: z.string() }), + execute: async ({ location }) => `The weather in ${location} is sunny.`, + }), + llm.tool({ + name: 'lookupTimezone', + description: 'Look up the timezone for a city or region.', + parameters: z.object({ location: z.string() }), + execute: async ({ location }) => `${location} is in the America/Los_Angeles timezone.`, + }), + ], + }); + + const session = new voice.AgentSession({ + vad: ctx.proc.userData.vad! as silero.VAD, + stt: new inference.STT({ model: 'deepgram/nova-3', language: 'en' }), + llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), + tts: new inference.TTS({ + model: 'cartesia/sonic-3', + voice: '9626c31c-bec5-4cca-baa8-f8ba9e84c8bc', + }), + turnDetection: new livekit.turnDetector.MultilingualModel(), + }); + + await session.start({ + agent: new MainAgent(locationToolset), + room: ctx.room, + inputOptions: { noiseCancellation: BackgroundVoiceCancellation() }, + }); + + session.say('Hello! 
I will ask you a quick question, then we can chat.'); + }, +}); + +cli.runApp(new ServerOptions({ agent: fileURLToPath(import.meta.url) })); diff --git a/plugins/google/src/utils.ts b/plugins/google/src/utils.ts index 5548c076e..64a52a6c1 100644 --- a/plugins/google/src/utils.ts +++ b/plugins/google/src/utils.ts @@ -139,8 +139,10 @@ function isEmptyObjectSchema(jsonSchema: JSONSchema7Definition): boolean { export function toFunctionDeclarations(toolCtx: llm.ToolContext): FunctionDeclaration[] { const functionDeclarations: FunctionDeclaration[] = []; - for (const [name, tool] of Object.entries(toolCtx.functionTools)) { - const { description, parameters } = tool; + for (const tool of toolCtx.flatten()) { + // TODO: support provider-defined tools in the Gemini schema. + if (!llm.isFunctionTool(tool)) continue; + const { name, description, parameters } = tool; const jsonSchema = llm.toJsonSchema(parameters, false); // Create a deep copy to prevent the Google GenAI library from mutating the schema diff --git a/plugins/mistralai/src/llm.ts b/plugins/mistralai/src/llm.ts index f6685b042..f80bc8bcc 100644 --- a/plugins/mistralai/src/llm.ts +++ b/plugins/mistralai/src/llm.ts @@ -211,14 +211,16 @@ export class LLMStream extends llm.LLMStream { // eslint-disable-next-line @typescript-eslint/no-explicit-any const toolsList: any[] = []; - if (this.toolCtx && Object.keys(this.toolCtx.functionTools).length > 0) { - for (const [name, func] of Object.entries(this.toolCtx.functionTools)) { + if (this.toolCtx) { + for (const t of this.toolCtx.flatten()) { + // TODO: support provider-defined tools in the Mistral schema. 
+ if (!llm.isFunctionTool(t)) continue; toolsList.push({ type: 'function' as const, function: { - name, - description: func.description, - parameters: llm.toJsonSchema(func.parameters, true, false), + name: t.name, + description: t.description, + parameters: llm.toJsonSchema(t.parameters, true, false), }, }); } diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts index 8481e0f47..94f1e2988 100644 --- a/plugins/openai/src/realtime/realtime_model.ts +++ b/plugins/openai/src/realtime/realtime_model.ts @@ -698,11 +698,12 @@ export class RealtimeSession extends llm.RealtimeSession { // TODO(brian): these logics below are noops I think, leaving it here to keep // parity with the python but we should remove them later const retainedToolNames = new Set(ev.session.tools.map((tool) => tool.name)); - const retainedTools = Object.entries(_tools.functionTools) - .filter(([name]) => retainedToolNames.has(name)) - .map(([, tool]) => tool); + // Keep provider tools and Toolsets as-is; only drop function tools the server didn't accept. + const retainedEntries = _tools.tools.filter( + (entry) => !llm.isFunctionTool(entry) || retainedToolNames.has(entry.name), + ); - this._tools = new llm.ToolContext(retainedTools); + this._tools = new llm.ToolContext(retainedEntries); unlock(); } @@ -710,21 +711,25 @@ export class RealtimeSession extends llm.RealtimeSession { private createToolsUpdateEvent(_tools: llm.ToolContext): api_proto.SessionUpdateEvent { const oaiTools: api_proto.Tool[] = []; - for (const [name, tool] of Object.entries(_tools.functionTools)) { - const { parameters: toolParameters, description } = tool; + for (const t of _tools.flatten()) { + // TODO: support provider-defined tools in the Realtime session-update schema. 
+ if (!llm.isFunctionTool(t)) continue; try { const parameters = llm.toJsonSchema( - toolParameters, + t.parameters, ) as unknown as api_proto.Tool['parameters']; oaiTools.push({ - name, - description, + name: t.name, + description: t.description, parameters: parameters, type: 'function', }); } catch (e) { - this.#logger.error({ name, tool }, "OpenAI Realtime API doesn't support this tool type"); + this.#logger.error( + { name: t.name, tool: t }, + "OpenAI Realtime API doesn't support this tool type", + ); continue; } } diff --git a/plugins/openai/src/responses/llm.ts b/plugins/openai/src/responses/llm.ts index 9a255d046..20363a05f 100644 --- a/plugins/openai/src/responses/llm.ts +++ b/plugins/openai/src/responses/llm.ts @@ -186,25 +186,27 @@ class ResponsesHttpLLMStream extends llm.LLMStream { 'openai.responses', )) as OpenAI.Responses.ResponseInputItem[]; + // TODO: support provider-defined tools in the Responses schema. const tools = this.toolCtx - ? Object.entries(this.toolCtx.functionTools).map(([name, func]) => { - const oaiParams = { - type: 'function' as const, - name: name, - description: func.description, - parameters: llm.toJsonSchema( - func.parameters, - true, - this.strictToolSchema, - ) as unknown as OpenAI.Responses.FunctionTool['parameters'], - } as OpenAI.Responses.FunctionTool; - - if (this.strictToolSchema) { - oaiParams.strict = true; - } - - return oaiParams; - }) + ? 
this.toolCtx + .flatten() + .filter(llm.isFunctionTool) + .map((t) => { + const oaiParams = { + type: 'function' as const, + name: t.name, + description: t.description, + parameters: llm.toJsonSchema( + t.parameters, + true, + this.strictToolSchema, + ) as unknown as OpenAI.Responses.FunctionTool['parameters'], + } as OpenAI.Responses.FunctionTool; + if (this.strictToolSchema) { + oaiParams.strict = true; + } + return oaiParams; + }) : undefined; const requestOptions: Record = { ...this.modelOptions }; diff --git a/plugins/openai/src/ws/llm.ts b/plugins/openai/src/ws/llm.ts index d22d7a753..64f2d641f 100644 --- a/plugins/openai/src/ws/llm.ts +++ b/plugins/openai/src/ws/llm.ts @@ -429,25 +429,29 @@ export class WSLLMStream extends llm.LLMStream { 'openai.responses', )) as OpenAI.Responses.ResponseInputItem[]; + // TODO: support provider-defined tools in the Responses schema. const tools = this.toolCtx - ? Object.entries(this.toolCtx.functionTools).map(([name, func]) => { - const oaiParams = { - type: 'function' as const, - name, - description: func.description, - parameters: llm.toJsonSchema( - func.parameters, - true, - this.#strictToolSchema, - ) as unknown as OpenAI.Responses.FunctionTool['parameters'], - } as OpenAI.Responses.FunctionTool; - - if (this.#strictToolSchema) { - oaiParams.strict = true; - } - - return oaiParams; - }) + ? 
this.toolCtx + .flatten() + .filter(llm.isFunctionTool) + .map((t) => { + const oaiParams = { + type: 'function' as const, + name: t.name, + description: t.description, + parameters: llm.toJsonSchema( + t.parameters, + true, + this.#strictToolSchema, + ) as unknown as OpenAI.Responses.FunctionTool['parameters'], + } as OpenAI.Responses.FunctionTool; + + if (this.#strictToolSchema) { + oaiParams.strict = true; + } + + return oaiParams; + }) : undefined; const requestOptions: Record = { ...this.#modelOptions }; diff --git a/plugins/phonic/src/realtime/realtime_model.ts b/plugins/phonic/src/realtime/realtime_model.ts index 09933b580..fd0e9baf9 100644 --- a/plugins/phonic/src/realtime/realtime_model.ts +++ b/plugins/phonic/src/realtime/realtime_model.ts @@ -368,23 +368,25 @@ export class RealtimeSession extends llm.RealtimeSession { } this._tools = tools.copy(); - this.toolDefinitions = Object.entries(tools.functionTools).map(([name, tool]) => ({ - type: 'custom_websocket', - tool_schema: { - type: 'function', - function: { - name, - description: tool.description, - parameters: llm.toJsonSchema(tool.parameters), - strict: true, + // TODO: support provider-defined tools in the Phonic schema. 
+ this.toolDefinitions = tools + .flatten() + .filter(llm.isFunctionTool) + .map((t) => ({ + type: 'custom_websocket' as const, + tool_schema: { + type: 'function' as const, + function: { + name: t.name, + description: t.description, + parameters: llm.toJsonSchema(t.parameters), + strict: true, + }, }, - }, - tool_call_output_timeout_ms: TOOL_CALL_OUTPUT_TIMEOUT_MS, - // Tool chaining and tool calls during speech are not supported at this time - // for ease of implementation within the RealtimeSession generations framework - wait_for_speech_before_tool_call: true, - allow_tool_chaining: false, - })); + tool_call_output_timeout_ms: TOOL_CALL_OUTPUT_TIMEOUT_MS, + wait_for_speech_before_tool_call: true, + allow_tool_chaining: false, + })); this.toolsReady.resolve(); } @@ -404,21 +406,25 @@ export class RealtimeSession extends llm.RealtimeSession { } if (tools !== undefined) { this._tools = tools.copy(); - this.toolDefinitions = Object.entries(tools.functionTools).map(([name, tool]) => ({ - type: 'custom_websocket', - tool_schema: { - type: 'function', - function: { - name, - description: tool.description, - parameters: llm.toJsonSchema(tool.parameters), - strict: true, + // TODO: support provider-defined tools in the Phonic schema. 
+ this.toolDefinitions = tools + .flatten() + .filter(llm.isFunctionTool) + .map((t) => ({ + type: 'custom_websocket' as const, + tool_schema: { + type: 'function' as const, + function: { + name: t.name, + description: t.description, + parameters: llm.toJsonSchema(t.parameters), + strict: true, + }, }, - }, - tool_call_output_timeout_ms: TOOL_CALL_OUTPUT_TIMEOUT_MS, - wait_for_speech_before_tool_call: true, - allow_tool_chaining: false, - })); + tool_call_output_timeout_ms: TOOL_CALL_OUTPUT_TIMEOUT_MS, + wait_for_speech_before_tool_call: true, + allow_tool_chaining: false, + })); } if (chatCtx !== undefined) { this._chatCtx = chatCtx.copy(); diff --git a/plugins/test/src/llm.ts b/plugins/test/src/llm.ts index 534dce2df..1f912654a 100644 --- a/plugins/test/src/llm.ts +++ b/plugins/test/src/llm.ts @@ -200,6 +200,57 @@ export const llm = async (llm: llmlib.LLM, skipOptionalArgs: boolean) => { expect(JSON.parse(calls[0]!.args).address).toBeUndefined(); }); }); + + describe('toolset', async () => { + const buildToolsetContext = () => { + const weatherToolset = new llmlib.Toolset({ + id: 'weather_toolset', + tools: [ + llmlib.tool({ + name: 'getWeather', + description: 'Get the current weather in a given location', + parameters: z.object({ + location: z.string().describe('The city and state, e.g. 
San Francisco, CA'), + unit: z.enum(['celsius', 'fahrenheit']).describe('The temperature unit to use'), + }), + execute: async () => {}, + }), + ], + }); + + const directTool = llmlib.tool({ + name: 'playMusic', + description: 'Play music', + parameters: z.object({ + name: z.string().describe('The artist and name of the song'), + }), + execute: async () => {}, + }); + + return new llmlib.ToolContext([weatherToolset, directTool]); + }; + + it('should call a function tool that lives inside a Toolset', async () => { + const ctx = buildToolsetContext(); + const calls = await requestFncCall( + llm, + "What's the weather in San Francisco, in Celsius?", + ctx, + ); + + expect(calls.length).toStrictEqual(1); + expect(calls[0]!.name).toStrictEqual('getWeather'); + expect(JSON.parse(calls[0]!.args).unit).toStrictEqual('celsius'); + }); + + it('should expose direct tools alongside Toolset tools', async () => { + const ctx = buildToolsetContext(); + const calls = await requestFncCall(llm, 'Play the song "Bohemian Rhapsody" by Queen.', ctx); + + expect(calls.length).toStrictEqual(1); + expect(calls[0]!.name).toStrictEqual('playMusic'); + }); + }); }); };