Skip to content

Commit 1dd8ca9

Browse files
committed
feat: tools context
1 parent a3295cd commit 1dd8ca9

File tree

11 files changed

+547
-31
lines changed

11 files changed

+547
-31
lines changed

examples/ts-react-chat/src/routes/api.tanchat.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ export const Route = createFileRoute('/api/tanchat')({
5959
const abortController = new AbortController()
6060

6161
const body = await request.json()
62-
const { messages, data } = body
62+
const { messages, data, context } = body
6363

6464
// Extract provider, model, and conversationId from data
6565
const provider: Provider = data?.provider || 'openai'
@@ -112,6 +112,7 @@ export const Route = createFileRoute('/api/tanchat')({
112112
messages,
113113
abortController,
114114
conversationId,
115+
...(context !== undefined && { context }),
115116
})
116117
return toStreamResponse(stream, { abortController })
117118
} catch (error: any) {

packages/typescript/ai-client/src/chat-client.ts

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ import type { AnyClientTool, ModelMessage, StreamChunk } from '@tanstack/ai'
1414
import type { ConnectionAdapter } from './connection-adapters'
1515
import type { ChatClientEventEmitter } from './events'
1616

17-
export class ChatClient {
17+
export class ChatClient<
18+
TTools extends ReadonlyArray<AnyClientTool> = any,
19+
TContext = unknown,
20+
> {
1821
private processor: StreamProcessor
1922
private connection: ConnectionAdapter
2023
private uniqueId: string
@@ -26,6 +29,7 @@ export class ChatClient {
2629
private clientToolsRef: { current: Map<string, AnyClientTool> }
2730
private currentStreamId: string | null = null
2831
private currentMessageId: string | null = null
32+
private context?: TContext
2933

3034
private callbacksRef: {
3135
current: {
@@ -39,9 +43,10 @@ export class ChatClient {
3943
}
4044
}
4145

42-
constructor(options: ChatClientOptions) {
46+
constructor(options: ChatClientOptions<TTools, TContext>) {
4347
this.uniqueId = options.id || this.generateUniqueId('chat')
4448
this.body = options.body || {}
49+
this.context = options.context
4550
this.connection = options.connection
4651
this.events = new DefaultChatClientEventEmitter(this.uniqueId)
4752

@@ -135,7 +140,7 @@ export class ChatClient {
135140
const clientTool = this.clientToolsRef.current.get(args.toolName)
136141
if (clientTool?.execute) {
137142
try {
138-
const output = await clientTool.execute(args.input)
143+
const output = await clientTool.execute(args.input, this.context)
139144
await this.addToolResult({
140145
toolCallId: args.toolCallId,
141146
tool: args.toolName,
@@ -298,10 +303,11 @@ export class ChatClient {
298303
// Call onResponse callback
299304
await this.callbacksRef.current.onResponse()
300305

301-
// Include conversationId in the body for server-side event correlation
306+
// Include conversationId and context in the body for server-side event correlation
302307
const bodyWithConversationId = {
303308
...this.body,
304309
conversationId: this.uniqueId,
310+
...(this.context !== undefined && { context: this.context }),
305311
}
306312

307313
// Connect and stream

packages/typescript/ai-client/src/types.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ export interface UIMessage<TTools extends ReadonlyArray<AnyClientTool> = any> {
133133

134134
export interface ChatClientOptions<
135135
TTools extends ReadonlyArray<AnyClientTool> = any,
136+
TContext = unknown,
136137
> {
137138
/**
138139
* Connection adapter for streaming
@@ -208,6 +209,25 @@ export interface ChatClientOptions<
208209
*/
209210
chunkStrategy?: ChunkStrategy
210211
}
212+
213+
/**
214+
* Context object that is automatically passed to all tool execute functions.
215+
*
216+
* This allows tools to access shared context (like user ID, local storage,
217+
* browser APIs, etc.) without needing to capture them via closures.
218+
* Works for both client and server tools.
219+
*
220+
* The context is also sent to the server in the request body, so server tools
221+
* can access the same context.
222+
*
223+
* @example
224+
* const client = new ChatClient({
225+
* connection: fetchServerSentEvents('/api/chat'),
226+
* context: { userId: '123', localStorage },
227+
* tools: [getUserData],
228+
* });
229+
*/
230+
context?: TContext
211231
}
212232

213233
export interface ChatRequestBody {

packages/typescript/ai-client/tests/chat-client.test.ts

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import { describe, expect, it, vi } from 'vitest'
2+
import { toolDefinition } from '@tanstack/ai'
3+
import { z } from 'zod'
24
import { ChatClient } from '../src/chat-client'
35
import {
46
createMockConnectionAdapter,
@@ -515,7 +517,7 @@ describe('ChatClient', () => {
515517
// Should have at least one call for the assistant message
516518
const assistantAppendedCall = messageAppendedCalls.find(([, data]) => {
517519
const payload = data as Record<string, unknown>
518-
return payload && payload.role === 'assistant'
520+
return payload.role === 'assistant'
519521
})
520522
expect(assistantAppendedCall).toBeDefined()
521523
})
@@ -585,4 +587,141 @@ describe('ChatClient', () => {
585587
expect(thinkingCalls.length).toBeGreaterThan(0)
586588
})
587589
})
590+
591+
describe('context support', () => {
592+
it('should pass context to client tool execute functions', async () => {
593+
interface TestContext {
594+
userId: string
595+
localStorage: {
596+
setItem: (key: string, value: string) => void
597+
getItem: (key: string) => string | null
598+
}
599+
}
600+
601+
const mockStorage = {
602+
setItem: vi.fn(),
603+
getItem: vi.fn(() => null),
604+
}
605+
606+
const testContext: TestContext = {
607+
userId: '123',
608+
localStorage: mockStorage,
609+
}
610+
611+
const executeFn = vi.fn(async (_args: any, context?: unknown) => {
612+
const ctx = context as TestContext | undefined
613+
if (ctx) {
614+
ctx.localStorage.setItem(
615+
`pref_${ctx.userId}_${_args.key}`,
616+
_args.value,
617+
)
618+
return { success: true }
619+
}
620+
return { success: false }
621+
})
622+
623+
const toolDef = toolDefinition({
624+
name: 'savePreference',
625+
description: 'Save user preference',
626+
inputSchema: z.object({
627+
key: z.string(),
628+
value: z.string(),
629+
}),
630+
outputSchema: z.object({
631+
success: z.boolean(),
632+
}),
633+
})
634+
635+
const tool = toolDef.client<TestContext>(executeFn)
636+
637+
const chunks = createToolCallChunks([
638+
{
639+
id: 'tool-1',
640+
name: 'savePreference',
641+
arguments: '{"key":"theme","value":"dark"}',
642+
},
643+
])
644+
const adapter = createMockConnectionAdapter({ chunks })
645+
646+
const client = new ChatClient({
647+
connection: adapter,
648+
tools: [tool],
649+
context: testContext,
650+
})
651+
652+
await client.sendMessage('Save my preference')
653+
654+
// Wait a bit for async tool execution
655+
await new Promise((resolve) => setTimeout(resolve, 10))
656+
657+
// Tool should have been called with context
658+
expect(executeFn).toHaveBeenCalled()
659+
const lastCall = executeFn.mock.calls[0]
660+
expect(lastCall?.[0]).toEqual({ key: 'theme', value: 'dark' })
661+
expect(lastCall?.[1]).toEqual(testContext)
662+
663+
// localStorage should have been called
664+
expect(mockStorage.setItem).toHaveBeenCalledWith('pref_123_theme', 'dark')
665+
})
666+
667+
it('should send context to server in request body', async () => {
668+
const testContext = {
669+
userId: '123',
670+
sessionId: 'session-456',
671+
}
672+
673+
let capturedBody: any = null
674+
const adapter = createMockConnectionAdapter({
675+
chunks: createTextChunks('Response'),
676+
onConnect: (_messages, body) => {
677+
capturedBody = body
678+
},
679+
})
680+
681+
const client = new ChatClient({
682+
connection: adapter,
683+
context: testContext,
684+
})
685+
686+
await client.sendMessage('Hello')
687+
688+
// Context should be in the request body
689+
expect(capturedBody).toBeDefined()
690+
expect(capturedBody.context).toEqual(testContext)
691+
})
692+
693+
it('should work without context (context is optional)', async () => {
694+
const executeFn = vi.fn(async (args: any) => {
695+
return { result: args.value }
696+
})
697+
698+
const toolDef = toolDefinition({
699+
name: 'simpleTool',
700+
description: 'Simple tool',
701+
inputSchema: z.object({
702+
value: z.string(),
703+
}),
704+
outputSchema: z.object({
705+
result: z.string(),
706+
}),
707+
})
708+
709+
const tool = toolDef.client(executeFn)
710+
711+
const chunks = createToolCallChunks([
712+
{ id: 'tool-1', name: 'simpleTool', arguments: '{"value":"test"}' },
713+
])
714+
const adapter = createMockConnectionAdapter({ chunks })
715+
716+
const client = new ChatClient({
717+
connection: adapter,
718+
tools: [tool],
719+
})
720+
721+
await client.sendMessage('Test')
722+
723+
// Tool should have been called without context
724+
expect(executeFn).toHaveBeenCalledWith({ value: 'test' }, undefined)
725+
})
726+
})
588727
})

packages/typescript/ai/src/core/chat.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import type {
2020

2121
interface ChatEngineConfig<
2222
TAdapter extends AIAdapter<any, any, any, any>,
23-
TParams extends ChatOptions<any, any> = ChatOptions<any>,
23+
TParams extends ChatOptions<any, any, any, any, any> = ChatOptions<any>,
2424
> {
2525
adapter: TAdapter
2626
systemPrompts?: Array<string>
@@ -32,7 +32,7 @@ type CyclePhase = 'processChat' | 'executeToolCalls'
3232

3333
class ChatEngine<
3434
TAdapter extends AIAdapter<any, any, any, any>,
35-
TParams extends ChatOptions<any, any> = ChatOptions<any>,
35+
TParams extends ChatOptions<any, any, any, any, any> = ChatOptions<any>,
3636
> {
3737
private readonly adapter: TAdapter
3838
private readonly params: TParams
@@ -45,6 +45,7 @@ class ChatEngine<
4545
private readonly streamId: string
4646
private readonly effectiveRequest?: Request | RequestInit
4747
private readonly effectiveSignal?: AbortSignal
48+
private readonly context?: TParams['context']
4849

4950
private messages: Array<ModelMessage>
5051
private iterationCount = 0
@@ -75,6 +76,7 @@ class ChatEngine<
7576
? { signal: config.params.abortController.signal }
7677
: undefined
7778
this.effectiveSignal = config.params.abortController?.signal
79+
this.context = config.params.context
7880
}
7981

8082
async *chat(): AsyncGenerator<StreamChunk> {
@@ -449,6 +451,7 @@ class ChatEngine<
449451
this.tools,
450452
approvals,
451453
clientToolResults,
454+
this.context,
452455
)
453456

454457
if (

packages/typescript/ai/src/tools/tool-calls.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ export class ToolCallManager {
110110
*/
111111
async *executeTools(
112112
doneChunk: DoneStreamChunk,
113+
context?: unknown,
113114
): AsyncGenerator<ToolResultStreamChunk, Array<ModelMessage>, void> {
114115
const toolCallsArray = this.getToolCalls()
115116
const toolResults: Array<ModelMessage> = []
@@ -141,8 +142,8 @@ export class ToolCallManager {
141142
}
142143
}
143144

144-
// Execute the tool
145-
let result = await tool.execute(args)
145+
// Execute the tool with context if available
146+
let result = await tool.execute(args, context)
146147

147148
// Validate output against outputSchema if provided
148149
if (tool.outputSchema && result !== undefined && result !== null) {
@@ -238,12 +239,14 @@ interface ExecuteToolCallsResult {
238239
* @param tools - Available tools with their configurations
239240
* @param approvals - Map of approval decisions (approval.id -> approved boolean)
240241
* @param clientResults - Map of client-side execution results (toolCallId -> result)
242+
* @param context - Optional context object to pass to tool execute functions
241243
*/
242244
export async function executeToolCalls(
243245
toolCalls: Array<ToolCall>,
244246
tools: ReadonlyArray<Tool>,
245247
approvals: Map<string, boolean> = new Map(),
246248
clientResults: Map<string, any> = new Map(),
249+
context?: unknown,
247250
): Promise<ExecuteToolCallsResult> {
248251
const results: Array<ToolResult> = []
249252
const needsApproval: Array<ApprovalRequest> = []
@@ -375,7 +378,7 @@ export async function executeToolCalls(
375378
// Execute after approval
376379
const startTime = Date.now()
377380
try {
378-
let result = await tool.execute(input)
381+
let result = await tool.execute(input, context)
379382
const duration = Date.now() - startTime
380383

381384
// Validate output against outputSchema if provided
@@ -433,7 +436,7 @@ export async function executeToolCalls(
433436
// CASE 3: Normal server tool - execute immediately
434437
const startTime = Date.now()
435438
try {
436-
let result = await tool.execute(input)
439+
let result = await tool.execute(input, context)
437440
const duration = Date.now() - startTime
438441

439442
// Validate output against outputSchema if provided

0 commit comments

Comments
 (0)