Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions examples/ts-react-chat/src/routes/api.tanchat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ export const Route = createFileRoute('/api/tanchat')({
`[API Route] Using provider: ${provider}, model: ${selectedModel}`,
)

// Server-side context (e.g., database connections, user session)
// This is separate from client context and only used for server tools
const serverContext = {
// Add server-side context here if needed
// e.g., db, userId from session, etc.
}

const stream = chat({
adapter: adapter as any,
model: selectedModel as any,
Expand All @@ -112,6 +119,7 @@ export const Route = createFileRoute('/api/tanchat')({
messages,
abortController,
conversationId,
context: serverContext,
})
return toStreamResponse(stream, { abortController })
} catch (error: any) {
Expand Down
20 changes: 16 additions & 4 deletions packages/typescript/ai-client/src/chat-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,19 @@ import type {
ToolCallPart,
UIMessage,
} from './types'
import type { AnyClientTool, ModelMessage, StreamChunk } from '@tanstack/ai'
import type {
AnyClientTool,
ModelMessage,
StreamChunk,
ToolOptions,
} from '@tanstack/ai'
import type { ConnectionAdapter } from './connection-adapters'
import type { ChatClientEventEmitter } from './events'

export class ChatClient {
export class ChatClient<
TTools extends ReadonlyArray<AnyClientTool> = any,
TContext = unknown,
> {
private processor: StreamProcessor
private connection: ConnectionAdapter
private uniqueId: string
Expand All @@ -26,6 +34,7 @@ export class ChatClient {
private clientToolsRef: { current: Map<string, AnyClientTool> }
private currentStreamId: string | null = null
private currentMessageId: string | null = null
private options: Partial<ToolOptions<TContext>>

private callbacksRef: {
current: {
Expand All @@ -39,9 +48,10 @@ export class ChatClient {
}
}

constructor(options: ChatClientOptions) {
constructor(options: ChatClientOptions<TTools, TContext>) {
this.uniqueId = options.id || this.generateUniqueId('chat')
this.body = options.body || {}
this.options = { context: options.context }
this.connection = options.connection
this.events = new DefaultChatClientEventEmitter(this.uniqueId)

Expand Down Expand Up @@ -135,7 +145,9 @@ export class ChatClient {
const clientTool = this.clientToolsRef.current.get(args.toolName)
if (clientTool?.execute) {
try {
const output = await clientTool.execute(args.input)
const output = await clientTool.execute(args.input, {
context: this.options.context,
})
await this.addToolResult({
toolCallId: args.toolCallId,
tool: args.toolName,
Expand Down
19 changes: 19 additions & 0 deletions packages/typescript/ai-client/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ export interface UIMessage<TTools extends ReadonlyArray<AnyClientTool> = any> {

export interface ChatClientOptions<
TTools extends ReadonlyArray<AnyClientTool> = any,
TContext = unknown,
> {
/**
* Connection adapter for streaming
Expand Down Expand Up @@ -208,6 +209,24 @@ export interface ChatClientOptions<
*/
chunkStrategy?: ChunkStrategy
}

/**
* Context object that is automatically passed to client-side tool execute functions.
*
* This allows client tools to access shared context (like user ID, local storage,
* browser APIs, etc.) without needing to capture them via closures.
*
* Note: This context is only used for client-side tools. Server tools should receive
* their own context from the server-side chat() function.
*
* @example
* const client = new ChatClient({
* connection: fetchServerSentEvents('/api/chat'),
* context: { userId: '123', localStorage },
* tools: [clientTool],
* });
*/
context?: TContext
}

export interface ChatRequestBody {
Expand Down
147 changes: 146 additions & 1 deletion packages/typescript/ai-client/tests/chat-client.test.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import { describe, expect, it, vi } from 'vitest'
import { toolDefinition } from '@tanstack/ai'
import { z } from 'zod'
import { ChatClient } from '../src/chat-client'
import {
createMockConnectionAdapter,
createTextChunks,
createThinkingChunks,
createToolCallChunks,
} from './test-utils'
import type { ToolOptions } from '@tanstack/ai'
import type { UIMessage } from '../src/types'

describe('ChatClient', () => {
Expand Down Expand Up @@ -515,7 +518,7 @@ describe('ChatClient', () => {
// Should have at least one call for the assistant message
const assistantAppendedCall = messageAppendedCalls.find(([, data]) => {
const payload = data as Record<string, unknown>
return payload && payload.role === 'assistant'
return payload.role === 'assistant'
})
expect(assistantAppendedCall).toBeDefined()
})
Expand Down Expand Up @@ -585,4 +588,146 @@ describe('ChatClient', () => {
expect(thinkingCalls.length).toBeGreaterThan(0)
})
})

describe('context support', () => {
it('should pass context to client tool execute functions', async () => {
interface TestContext {
userId: string
localStorage: {
setItem: (key: string, value: string) => void
getItem: (key: string) => string | null
}
}

const mockStorage = {
setItem: vi.fn(),
getItem: vi.fn(() => null),
}

const testContext: TestContext = {
userId: '123',
localStorage: mockStorage,
}

const executeFn = vi.fn(
async <TContext = unknown>(
_args: any,
options: ToolOptions<TContext>,
) => {
const ctx = options.context as TestContext
ctx.localStorage.setItem(
`pref_${ctx.userId}_${_args.key}`,
_args.value,
)
return { success: true }
},
)

const toolDef = toolDefinition({
name: 'savePreference',
description: 'Save user preference',
inputSchema: z.object({
key: z.string(),
value: z.string(),
}),
outputSchema: z.object({
success: z.boolean(),
}),
})

const tool = toolDef.client<TestContext>(executeFn)

const chunks = createToolCallChunks([
{
id: 'tool-1',
name: 'savePreference',
arguments: '{"key":"theme","value":"dark"}',
},
])
const adapter = createMockConnectionAdapter({ chunks })

const client = new ChatClient({
connection: adapter,
tools: [tool],
context: testContext,
})

await client.sendMessage('Save my preference')

// Wait a bit for async tool execution
await new Promise((resolve) => setTimeout(resolve, 10))

// Tool should have been called with context
expect(executeFn).toHaveBeenCalled()
const lastCall = executeFn.mock.calls[0]
expect(lastCall?.[0]).toEqual({ key: 'theme', value: 'dark' })
expect(lastCall?.[1]).toEqual({ context: testContext })

// localStorage should have been called
expect(mockStorage.setItem).toHaveBeenCalledWith('pref_123_theme', 'dark')
})

it('should not send context to server (context is only for client tools)', async () => {
const testContext = {
userId: '123',
sessionId: 'session-456',
}

let capturedBody: any = null
const adapter = createMockConnectionAdapter({
chunks: createTextChunks('Response'),
onConnect: (_messages, body) => {
capturedBody = body
},
})

const client = new ChatClient({
connection: adapter,
context: testContext,
})

await client.sendMessage('Hello')

// Context should NOT be in the request body (only used for client tools)
expect(capturedBody).toBeDefined()
expect(capturedBody.context).toBeUndefined()
})

it('should work without context (context is optional)', async () => {
const executeFn = vi.fn(async (args: any) => {
return { result: args.value }
})

const toolDef = toolDefinition({
name: 'simpleTool',
description: 'Simple tool',
inputSchema: z.object({
value: z.string(),
}),
outputSchema: z.object({
result: z.string(),
}),
})

const tool = toolDef.client(executeFn)

const chunks = createToolCallChunks([
{ id: 'tool-1', name: 'simpleTool', arguments: '{"value":"test"}' },
])
const adapter = createMockConnectionAdapter({ chunks })

const client = new ChatClient({
connection: adapter,
tools: [tool],
})

await client.sendMessage('Test')

// Tool should have been called without context
expect(executeFn).toHaveBeenCalledWith(
{ value: 'test' },
{ context: undefined },
)
})
})
})
9 changes: 7 additions & 2 deletions packages/typescript/ai/src/core/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ import type {
StreamChunk,
Tool,
ToolCall,
ToolOptions,
} from '../types'

interface ChatEngineConfig<
TAdapter extends AIAdapter<any, any, any, any>,
TParams extends ChatOptions<any, any> = ChatOptions<any>,
TParams extends ChatOptions<any, any, any, any, any> = ChatOptions<any>,
> {
adapter: TAdapter
systemPrompts?: Array<string>
Expand All @@ -32,7 +33,7 @@ type CyclePhase = 'processChat' | 'executeToolCalls'

class ChatEngine<
TAdapter extends AIAdapter<any, any, any, any>,
TParams extends ChatOptions<any, any> = ChatOptions<any>,
TParams extends ChatOptions<any, any, any, any, any> = ChatOptions<any>,
> {
private readonly adapter: TAdapter
private readonly params: TParams
Expand All @@ -45,6 +46,7 @@ class ChatEngine<
private readonly streamId: string
private readonly effectiveRequest?: Request | RequestInit
private readonly effectiveSignal?: AbortSignal
private readonly options: Partial<ToolOptions<TParams['context']>>

private messages: Array<ModelMessage>
private iterationCount = 0
Expand Down Expand Up @@ -75,6 +77,7 @@ class ChatEngine<
? { signal: config.params.abortController.signal }
: undefined
this.effectiveSignal = config.params.abortController?.signal
this.options = { context: config.params.context }
}

async *chat(): AsyncGenerator<StreamChunk> {
Expand Down Expand Up @@ -381,6 +384,7 @@ class ChatEngine<
this.tools,
approvals,
clientToolResults,
this.options,
)

if (
Expand Down Expand Up @@ -449,6 +453,7 @@ class ChatEngine<
this.tools,
approvals,
clientToolResults,
this.options,
)

if (
Expand Down
Loading