|
1 | 1 | import { describe, expect, it, vi } from 'vitest' |
| 2 | +import { toolDefinition } from '@tanstack/ai' |
| 3 | +import { z } from 'zod' |
2 | 4 | import { ChatClient } from '../src/chat-client' |
3 | 5 | import { |
4 | 6 | createMockConnectionAdapter, |
@@ -515,7 +517,7 @@ describe('ChatClient', () => { |
515 | 517 | // Should have at least one call for the assistant message |
516 | 518 | const assistantAppendedCall = messageAppendedCalls.find(([, data]) => { |
517 | 519 | const payload = data as Record<string, unknown> |
518 | | - return payload && payload.role === 'assistant' |
| 520 | + return payload.role === 'assistant' |
519 | 521 | }) |
520 | 522 | expect(assistantAppendedCall).toBeDefined() |
521 | 523 | }) |
@@ -585,4 +587,141 @@ describe('ChatClient', () => { |
585 | 587 | expect(thinkingCalls.length).toBeGreaterThan(0) |
586 | 588 | }) |
587 | 589 | }) |
| 590 | + |
| 591 | + describe('context support', () => { |
| 592 | + it('should pass context to client tool execute functions', async () => { |
| 593 | + interface TestContext { |
| 594 | + userId: string |
| 595 | + localStorage: { |
| 596 | + setItem: (key: string, value: string) => void |
| 597 | + getItem: (key: string) => string | null |
| 598 | + } |
| 599 | + } |
| 600 | + |
| 601 | + const mockStorage = { |
| 602 | + setItem: vi.fn(), |
| 603 | + getItem: vi.fn(() => null), |
| 604 | + } |
| 605 | + |
| 606 | + const testContext: TestContext = { |
| 607 | + userId: '123', |
| 608 | + localStorage: mockStorage, |
| 609 | + } |
| 610 | + |
| 611 | + const executeFn = vi.fn(async (_args: any, context?: unknown) => { |
| 612 | + const ctx = context as TestContext | undefined |
| 613 | + if (ctx) { |
| 614 | + ctx.localStorage.setItem( |
| 615 | + `pref_${ctx.userId}_${_args.key}`, |
| 616 | + _args.value, |
| 617 | + ) |
| 618 | + return { success: true } |
| 619 | + } |
| 620 | + return { success: false } |
| 621 | + }) |
| 622 | + |
| 623 | + const toolDef = toolDefinition({ |
| 624 | + name: 'savePreference', |
| 625 | + description: 'Save user preference', |
| 626 | + inputSchema: z.object({ |
| 627 | + key: z.string(), |
| 628 | + value: z.string(), |
| 629 | + }), |
| 630 | + outputSchema: z.object({ |
| 631 | + success: z.boolean(), |
| 632 | + }), |
| 633 | + }) |
| 634 | + |
| 635 | + const tool = toolDef.client<TestContext>(executeFn) |
| 636 | + |
| 637 | + const chunks = createToolCallChunks([ |
| 638 | + { |
| 639 | + id: 'tool-1', |
| 640 | + name: 'savePreference', |
| 641 | + arguments: '{"key":"theme","value":"dark"}', |
| 642 | + }, |
| 643 | + ]) |
| 644 | + const adapter = createMockConnectionAdapter({ chunks }) |
| 645 | + |
| 646 | + const client = new ChatClient({ |
| 647 | + connection: adapter, |
| 648 | + tools: [tool], |
| 649 | + context: testContext, |
| 650 | + }) |
| 651 | + |
| 652 | + await client.sendMessage('Save my preference') |
| 653 | + |
| 654 | + // Wait a bit for async tool execution |
| 655 | + await new Promise((resolve) => setTimeout(resolve, 10)) |
| 656 | + |
| 657 | + // Tool should have been called with context |
| 658 | + expect(executeFn).toHaveBeenCalled() |
| 659 | + const lastCall = executeFn.mock.calls[0] |
| 660 | + expect(lastCall?.[0]).toEqual({ key: 'theme', value: 'dark' }) |
| 661 | + expect(lastCall?.[1]).toEqual(testContext) |
| 662 | + |
| 663 | + // localStorage should have been called |
| 664 | + expect(mockStorage.setItem).toHaveBeenCalledWith('pref_123_theme', 'dark') |
| 665 | + }) |
| 666 | + |
| 667 | + it('should send context to server in request body', async () => { |
| 668 | + const testContext = { |
| 669 | + userId: '123', |
| 670 | + sessionId: 'session-456', |
| 671 | + } |
| 672 | + |
| 673 | + let capturedBody: any = null |
| 674 | + const adapter = createMockConnectionAdapter({ |
| 675 | + chunks: createTextChunks('Response'), |
| 676 | + onConnect: (_messages, body) => { |
| 677 | + capturedBody = body |
| 678 | + }, |
| 679 | + }) |
| 680 | + |
| 681 | + const client = new ChatClient({ |
| 682 | + connection: adapter, |
| 683 | + context: testContext, |
| 684 | + }) |
| 685 | + |
| 686 | + await client.sendMessage('Hello') |
| 687 | + |
| 688 | + // Context should be in the request body |
| 689 | + expect(capturedBody).toBeDefined() |
| 690 | + expect(capturedBody.context).toEqual(testContext) |
| 691 | + }) |
| 692 | + |
| 693 | + it('should work without context (context is optional)', async () => { |
| 694 | + const executeFn = vi.fn(async (args: any) => { |
| 695 | + return { result: args.value } |
| 696 | + }) |
| 697 | + |
| 698 | + const toolDef = toolDefinition({ |
| 699 | + name: 'simpleTool', |
| 700 | + description: 'Simple tool', |
| 701 | + inputSchema: z.object({ |
| 702 | + value: z.string(), |
| 703 | + }), |
| 704 | + outputSchema: z.object({ |
| 705 | + result: z.string(), |
| 706 | + }), |
| 707 | + }) |
| 708 | + |
| 709 | + const tool = toolDef.client(executeFn) |
| 710 | + |
| 711 | + const chunks = createToolCallChunks([ |
| 712 | + { id: 'tool-1', name: 'simpleTool', arguments: '{"value":"test"}' }, |
| 713 | + ]) |
| 714 | + const adapter = createMockConnectionAdapter({ chunks }) |
| 715 | + |
| 716 | + const client = new ChatClient({ |
| 717 | + connection: adapter, |
| 718 | + tools: [tool], |
| 719 | + }) |
| 720 | + |
| 721 | + await client.sendMessage('Test') |
| 722 | + |
| 723 | + // Tool should have been called without context |
| 724 | + expect(executeFn).toHaveBeenCalledWith({ value: 'test' }, undefined) |
| 725 | + }) |
| 726 | + }) |
588 | 727 | }) |
0 commit comments