diff --git a/.agents/__tests__/editor-best-of-n.integration.test.ts b/.agents/__tests__/editor-best-of-n.integration.test.ts new file mode 100644 index 000000000..b4f8fcce0 --- /dev/null +++ b/.agents/__tests__/editor-best-of-n.integration.test.ts @@ -0,0 +1,91 @@ +import { API_KEY_ENV_VAR } from '@codebuff/common/old-constants' +import { describe, expect, it } from 'bun:test' + +import { CodebuffClient } from '@codebuff/sdk' + +import type { PrintModeEvent } from '@codebuff/common/types/print-mode' + +/** + * Integration tests for the editor-best-of-n-max agent. + * These tests verify that the best-of-n editor workflow works correctly: + * 1. Spawns multiple implementor agents in parallel + * 2. Collects their implementation proposals + * 3. Uses a selector agent to choose the best implementation + * 4. Applies the chosen implementation + */ +describe('Editor Best-of-N Max Agent Integration', () => { + it( + 'should generate and select the best implementation for a simple edit', + async () => { + const apiKey = process.env[API_KEY_ENV_VAR] + if (!apiKey) { + throw new Error('API key not found') + } + + // Create mock project files with a simple TypeScript file to edit + const projectFiles: Record = { + 'src/utils/math.ts': ` +export function add(a: number, b: number): number { + return a + b +} + +export function subtract(a: number, b: number): number { + return a - b +} +`, + 'src/index.ts': ` +import { add, subtract } from './utils/math' + +console.log(add(1, 2)) +console.log(subtract(5, 3)) +`, + 'package.json': JSON.stringify({ + name: 'test-project', + version: '1.0.0', + dependencies: {}, + }), + } + + const client = new CodebuffClient({ + apiKey, + cwd: '/tmp/test-best-of-n-project', + projectFiles, + }) + + const events: PrintModeEvent[] = [] + + // Run the editor-best-of-n-max agent with a simple task + // Using n=2 to keep the test fast while still testing the best-of-n workflow + const run = await client.run({ + agent: 'editor-best-of-n-max', + prompt: + 'Add a multiply function to src/utils/math.ts that takes two numbers and returns their product', + params: { n: 2 }, + handleEvent: (event) => { + console.log(event) + events.push(event) + }, + }) + + // The output should not be an error + expect(run.output.type).not.toEqual('error') + + // Verify we got some output + expect(run.output).toBeDefined() + + // The output should contain the implementation response + const outputStr = + typeof run.output === 'string' ? run.output : JSON.stringify(run.output) + console.log('Output:', outputStr) + + // Should contain evidence of the multiply function being added + const relevantTerms = ['multiply', 'product', 'str_replace', 'write_file'] + const foundRelevantTerm = relevantTerms.some((term) => + outputStr.toLowerCase().includes(term.toLowerCase()), + ) + + expect(foundRelevantTerm).toBe(true) + }, + { timeout: 120_000 }, // 2 minute timeout for best-of-n workflow + ) +}) diff --git a/.agents/__tests__/file-explorer.integration.test.ts b/.agents/__tests__/file-explorer.integration.test.ts new file mode 100644 index 000000000..0aa3cc3f6 --- /dev/null +++ b/.agents/__tests__/file-explorer.integration.test.ts @@ -0,0 +1,348 @@ +import { API_KEY_ENV_VAR } from '@codebuff/common/old-constants' +import { describe, expect, it } from 'bun:test' + +import { CodebuffClient } from '@codebuff/sdk' +import filePickerDefinition from '../file-explorer/file-picker' +import fileListerDefinition from '../file-explorer/file-lister' + +import type { PrintModeEvent } from '@codebuff/common/types/print-mode' + +/** + * Integration tests for agents that use the read_subtree tool. + * These tests verify that the SDK properly initializes the session state + * with project files and that agents can access the file tree through + * the read_subtree tool. + * + * The file-lister agent is used directly instead of file-picker because: + * - file-lister directly uses the read_subtree tool + * - file-picker spawns file-lister as a subagent, adding complexity + * - Testing file-lister directly verifies the core functionality + */ +describe('File Lister Agent Integration - read_subtree tool', () => { + it( + 'should find relevant files using read_subtree tool', + async () => { + const apiKey = process.env[API_KEY_ENV_VAR] + if (!apiKey) { + throw new Error('API key not found') + } + + // Create mock project files that the file-lister should be able to find + const projectFiles: Record = { + 'src/index.ts': ` +import { UserService } from './services/user-service' +import { AuthService } from './services/auth-service' + +export function main() { + const userService = new UserService() + const authService = new AuthService() + console.log('Application started') +} +`, + 'src/services/user-service.ts': ` +export class UserService { + async getUser(id: string) { + return { id, name: 'John Doe' } + } + + async createUser(name: string) { + return { id: 'new-user-id', name } + } + + async deleteUser(id: string) { + console.log('User deleted:', id) + } +} +`, + 'src/services/auth-service.ts': ` +export class AuthService { + async login(email: string, password: string) { + return { token: 'mock-token' } + } + + async logout() { + console.log('Logged out') + } + + async validateToken(token: string) { + return token === 'mock-token' + } +} +`, + 'src/utils/logger.ts': ` +export function log(message: string) { + console.log('[LOG]', message) +} + +export function error(message: string) { + console.error('[ERROR]', message) +} +`, + 'src/types/user.ts': ` +export interface User { + id: string + name: string + email?: string +} +`, + 'package.json': JSON.stringify({ + name: 'test-project', + version: '1.0.0', + dependencies: {}, + }), + 'README.md': + '# Test Project\n\nA simple test project for integration testing.', + } + + const client = new CodebuffClient({ + apiKey, + cwd: '/tmp/test-project', + projectFiles, + }) + + const events: PrintModeEvent[] = [] + + // Run the file-lister agent to find files related to user service + // The file-lister agent uses the read_subtree tool directly + const run = await client.run({ + agent: 'file-lister', + prompt: 'Find files related to user authentication and user management', + handleEvent: (event) => { + events.push(event) + }, + }) + + // The output should not be an error + expect(run.output.type).not.toEqual('error') + + // Verify we got some output + expect(run.output).toBeDefined() + + // The file-lister should have found relevant files + const outputStr = + typeof run.output === 'string' ? run.output : JSON.stringify(run.output) + + // Verify that the file-lister found some relevant files + const relevantFiles = [ + 'user-service', + 'auth-service', + 'user', + 'auth', + 'services', + ] + const foundRelevantFile = relevantFiles.some((file) => + outputStr.toLowerCase().includes(file.toLowerCase()), + ) + + expect(foundRelevantFile).toBe(true) + }, + { timeout: 60_000 }, + ) + + it( + 'should use the file tree from session state', + async () => { + const apiKey = process.env[API_KEY_ENV_VAR] + if (!apiKey) { + throw new Error('API key not found') + } + + // Create a different set of project files with a specific structure + const projectFiles: Record = { + 'packages/core/src/index.ts': 'export const VERSION = "1.0.0"', + 'packages/core/src/api/server.ts': + 'export function startServer() { console.log("started") }', + 'packages/core/src/api/routes.ts': + 'export const routes = { health: "/health" }', + 'packages/utils/src/helpers.ts': + 'export function formatDate(d: Date) { return d.toISOString() }', + 'docs/api.md': '# API Documentation\n\nAPI docs here.', + 'package.json': JSON.stringify({ name: 'mono-repo', version: '2.0.0' }), + } + + const client = new CodebuffClient({ + apiKey, + cwd: '/tmp/test-project', + projectFiles, + }) + + const events: PrintModeEvent[] = [] + + // Run file-lister to find API-related files + const run = await client.run({ + agent: 'file-lister', + prompt: 'Find files related to the API server implementation', + handleEvent: (event) => { + events.push(event) + }, + }) + + expect(run.output.type).not.toEqual('error') + + const outputStr = + typeof run.output === 'string' ? run.output : JSON.stringify(run.output) + + // Should find API-related files + const apiRelatedTerms = ['server', 'routes', 'api', 'core'] + const foundApiFile = apiRelatedTerms.some((term) => + outputStr.toLowerCase().includes(term.toLowerCase()), + ) + + expect(foundApiFile).toBe(true) + }, + { timeout: 60_000 }, + ) + + it( + 'should respect directories parameter', + async () => { + const apiKey = process.env[API_KEY_ENV_VAR] + if (!apiKey) { + throw new Error('API key not found') + } + + // Create project with multiple top-level directories + const projectFiles: Record = { + 'frontend/src/App.tsx': + 'export function App() { return
App
}', + 'frontend/src/components/Button.tsx': + 'export function Button() { return }', + 'backend/src/server.ts': + 'export function start() { console.log("started") }', + 'backend/src/routes/users.ts': + 'export function getUsers() { return [] }', + 'shared/types/common.ts': 'export type ID = string', + 'package.json': JSON.stringify({ name: 'full-stack-app' }), + } + + const client = new CodebuffClient({ + apiKey, + cwd: '/tmp/test-project', + projectFiles, + }) + + // Run file-lister with directories parameter to limit to frontend only + const run = await client.run({ + agent: 'file-lister', + prompt: 'Find React component files', + params: { + directories: ['frontend'], + }, + handleEvent: () => {}, + }) + + expect(run.output.type).not.toEqual('error') + + const outputStr = + typeof run.output === 'string' ? run.output : JSON.stringify(run.output) + + // Should find frontend files + const frontendTerms = ['app', 'button', 'component', 'frontend'] + const foundFrontendFile = frontendTerms.some((term) => + outputStr.toLowerCase().includes(term.toLowerCase()), + ) + + expect(foundFrontendFile).toBe(true) + }, + { timeout: 60_000 }, + ) +}) + +/** + * Integration tests for the file-picker agent that spawns subagents. + * The file-picker spawns file-lister as a subagent to find files. + * This tests the spawn_agents tool functionality through the SDK. + */ +describe('File Picker Agent Integration - spawn_agents tool', () => { + // Note: This test requires the local agent definitions to be used for both + // file-picker AND its spawned file-lister subagent. Currently, the spawned + // agent may resolve to the server version which has the old parsing bug. + // Skip until we have a way to ensure spawned agents use local definitions. + it.skip( + 'should spawn file-lister subagent and find relevant files', + async () => { + const apiKey = process.env[API_KEY_ENV_VAR] + if (!apiKey) { + throw new Error('API key not found') + } + + // Create mock project files + const projectFiles: Record = { + 'src/index.ts': ` +import { UserService } from './services/user-service' +export function main() { + const userService = new UserService() + console.log('Application started') +} +`, + 'src/services/user-service.ts': ` +export class UserService { + async getUser(id: string) { + return { id, name: 'John Doe' } + } +} +`, + 'src/services/auth-service.ts': ` +export class AuthService { + async login(email: string, password: string) { + return { token: 'mock-token' } + } +} +`, + 'package.json': JSON.stringify({ + name: 'test-project', + version: '1.0.0', + }), + } + + // Use local agent definitions to test the updated handleSteps + const localFilePickerDef = filePickerDefinition as unknown as any + const localFileListerDef = fileListerDefinition as unknown as any + + const client = new CodebuffClient({ + apiKey, + cwd: '/tmp/test-project-picker', + projectFiles, + agentDefinitions: [localFilePickerDef, localFileListerDef], + }) + + const events: PrintModeEvent[] = [] + + // Run the file-picker agent which spawns file-lister as a subagent + const run = await client.run({ + agent: localFilePickerDef.id, + prompt: 'Find files related to user authentication', + handleEvent: (event) => { + events.push(event) + }, + }) + + // Check for errors in the output + if (run.output.type === 'error') { + console.error('File picker error:', run.output) + } + + console.log('File picker output type:', run.output.type) + console.log('File picker output:', JSON.stringify(run.output, null, 2)) + + // The output should not be an error + expect(run.output.type).not.toEqual('error') + + // Verify we got some output + expect(run.output).toBeDefined() + + // The file-picker should have found relevant files via its spawned file-lister + const outputStr = + typeof run.output === 'string' ? run.output : JSON.stringify(run.output) + + // Verify that the file-picker found some relevant files + const relevantFiles = ['user', 'auth', 'service'] + const foundRelevantFile = relevantFiles.some((file) => + outputStr.toLowerCase().includes(file.toLowerCase()), + ) + + expect(foundRelevantFile).toBe(true) + }, + { timeout: 90_000 }, + ) +}) diff --git a/.agents/editor/best-of-n/editor-best-of-n.ts b/.agents/editor/best-of-n/editor-best-of-n.ts index d9dd52634..7592a1785 100644 --- a/.agents/editor/best-of-n/editor-best-of-n.ts +++ b/.agents/editor/best-of-n/editor-best-of-n.ts @@ -39,11 +39,11 @@ export function createBestOfNEditor( spawnableAgents: buildArray( 'best-of-n-selector', 'best-of-n-selector-opus', - isDefault && 'best-of-n-selector-gemini', + 'best-of-n-selector-gemini', 'editor-implementor', 'editor-implementor-opus', - isDefault && 'editor-implementor-gemini', - isMax && 'editor-implementor-gpt-5', + 'editor-implementor-gemini', + 'editor-implementor-gpt-5', ), inputSchema: { @@ -229,7 +229,9 @@ function* handleStepsDefault({ } } function* handleStepsMax({ + agentState, params, + logger, }: AgentStepContext): ReturnType< NonNullable > { @@ -254,6 +256,28 @@ function* handleStepsMax({ 'editor-implementor-opus', ] as const + // Only keep messages up to just before the last spawn agent tool call. + const { messageHistory: initialMessageHistory } = agentState + const lastSpawnAgentMessageIndex = initialMessageHistory.findLastIndex( + (message) => + message.role === 'assistant' && + Array.isArray(message.content) && + message.content.length > 0 && + message.content[0].type === 'tool-call' && + message.content[0].toolName === 'spawn_agents', + ) + const updatedMessageHistory = initialMessageHistory.slice( + 0, + lastSpawnAgentMessageIndex, + ) + yield { + toolName: 'set_messages', + input: { + messages: updatedMessageHistory, + }, + includeToolCall: false, + } satisfies ToolCall<'set_messages'> + // Spawn implementor agents using the model pattern const implementorAgents = MAX_MODEL_PATTERN.slice(0, n).map((agent_type) => ({ agent_type, @@ -269,8 +293,9 @@ function* handleStepsMax({ } satisfies ToolCall<'spawn_agents'> // Extract spawn results - const spawnedImplementations = - extractSpawnResults<{ text: string }[]>(implementorResults) + const spawnedImplementations = extractSpawnResults( + implementorResults, + ) as any[] // Extract all the plans from the structured outputs const letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' @@ -280,11 +305,16 @@ function* handleStepsMax({ content: 'errorMessage' in result ? `Error: ${result.errorMessage}` - : result[0].text, + : extractLastMessageText(result), })) + logger.info( + { spawnedImplementations, implementations }, + 'spawnedImplementations', + ) + // Spawn selector with implementations as params - const { toolResult: selectorResult } = yield { + const { toolResult: selectorResult, agentState: selectorAgentState } = yield { toolName: 'spawn_agents', input: { agents: [ @@ -298,8 +328,10 @@ function* handleStepsMax({ } satisfies ToolCall<'spawn_agents'> const selectorOutput = extractSpawnResults<{ - implementationId: string - reasoning: string + value: { + implementationId: string + reasoning: string + } }>(selectorResult)[0] if ('errorMessage' in selectorOutput) { @@ -309,7 +341,7 @@ function* handleStepsMax({ } satisfies ToolCall<'set_output'> return } - const { implementationId } = selectorOutput + const { implementationId } = selectorOutput.value const chosenImplementation = implementations.find( (implementation) => implementation.id === implementationId, ) @@ -321,68 +353,77 @@ function* handleStepsMax({ return } - // Apply the chosen implementation using STEP_TEXT (only tool calls, no commentary) - const toolCallsOnly = extractToolCallsOnly( - typeof chosenImplementation.content === 'string' - ? chosenImplementation.content - : '', - ) + const numMessagesBeforeStepText = selectorAgentState.messageHistory.length + const { agentState: postEditsAgentState } = yield { type: 'STEP_TEXT', - text: toolCallsOnly, + text: chosenImplementation.content, } as StepText const { messageHistory } = postEditsAgentState - const lastAssistantMessageIndex = messageHistory.findLastIndex( - (message) => message.role === 'assistant', - ) - const editToolResults = messageHistory - .slice(lastAssistantMessageIndex) - .filter((message) => message.role === 'tool') - .flatMap((message) => message.content) - .filter((output) => output.type === 'json') - .map((output) => output.value) - // Set output with the chosen implementation and reasoning + // Set output with the messages from running the step text of the chosen implementation yield { toolName: 'set_output', input: { - response: chosenImplementation.content, - toolResults: editToolResults, + messages: messageHistory.slice(numMessagesBeforeStepText), }, includeToolCall: false, } satisfies ToolCall<'set_output'> - function extractSpawnResults( - results: any[] | undefined, - ): (T | { errorMessage: string })[] { - if (!results) return [] - const spawnedResults = results - .filter((result) => result.type === 'json') - .map((result) => result.value) - .flat() as { - agentType: string - value: { value?: T; errorMessage?: string } - }[] - return spawnedResults.map( - (result) => - result.value.value ?? { - errorMessage: - result.value.errorMessage ?? 'Error extracting spawn results', - }, - ) + /** + * Extracts the array of subagent results from spawn_agents tool output. + * + * The spawn_agents tool result structure is: + * [{ type: 'json', value: [{ agentName, agentType, value: AgentOutput }] }] + * + * Returns an array of agent outputs, one per spawned agent. + */ + function extractSpawnResults(results: any[] | undefined): T[] { + if (!results || results.length === 0) return [] + + // Find the json result containing spawn results + const jsonResult = results.find((r) => r.type === 'json') + if (!jsonResult?.value) return [] + + // Get the spawned agent results array + const spawnedResults = Array.isArray(jsonResult.value) + ? jsonResult.value + : [jsonResult.value] + + // Extract the value (AgentOutput) from each result + return spawnedResults.map((result: any) => result?.value).filter(Boolean) } - // Extract only tool calls from text, removing any commentary - function extractToolCallsOnly(text: string): string { - const toolExtractionPattern = - /\n(.*?)\n<\/codebuff_tool_call>/gs - const matches: string[] = [] - - for (const match of text.matchAll(toolExtractionPattern)) { - matches.push(match[0]) // Include the full tool call with tags + /** + * Extracts the text content from a 'lastMessage' AgentOutput. + * + * For agents with outputMode: 'last_message', the output structure is: + * { type: 'lastMessage', value: [{ role: 'assistant', content: [{ type: 'text', text: '...' }] }] } + * + * Returns the text from the last assistant message, or null if not found. + */ + function extractLastMessageText(agentOutput: any): string | null { + if (!agentOutput) return null + + // Handle 'lastMessage' output mode - the value contains an array of messages + if ( + agentOutput.type === 'lastMessage' && + Array.isArray(agentOutput.value) + ) { + // Find the last assistant message with text content + for (let i = agentOutput.value.length - 1; i >= 0; i--) { + const message = agentOutput.value[i] + if (message.role === 'assistant' && Array.isArray(message.content)) { + // Find text content in the message + for (const part of message.content) { + if (part.type === 'text' && typeof part.text === 'string') { + return part.text + } + } + } + } } - - return matches.join('\n') + return null } } diff --git a/.agents/editor/best-of-n/editor-implementor.ts b/.agents/editor/best-of-n/editor-implementor.ts index f159df2ce..c27af72a2 100644 --- a/.agents/editor/best-of-n/editor-implementor.ts +++ b/.agents/editor/best-of-n/editor-implementor.ts @@ -37,7 +37,7 @@ export const createBestOfNImplementor = (options: { Your task is to write out ALL the code changes needed to complete the user's request in a single comprehensive response. -Important: You can not make any other tool calls besides editing files. You cannot read more files, write todos, or spawn agents. +Important: You can not make any other tool calls besides editing files. You cannot read more files, write todos, spawn agents, or set output. Do not call any of these tools! Write out what changes you would make using the tool call format below. Use this exact format for each file change: diff --git a/.agents/file-explorer/file-picker.ts b/.agents/file-explorer/file-picker.ts index 25f7b6008..048d904d3 100644 --- a/.agents/file-explorer/file-picker.ts +++ b/.agents/file-explorer/file-picker.ts @@ -64,17 +64,22 @@ Do not use any further tools or spawn any further agents. }, } satisfies ToolCall - const filesResult = - extractSpawnResults<{ text: string }[]>(fileListerResults)[0] - if (!Array.isArray(filesResult)) { + const spawnResults = extractSpawnResults(fileListerResults) + const firstResult = spawnResults[0] + const fileListText = extractLastMessageText(firstResult) + + if (!fileListText) { + const errorMessage = extractErrorMessage(firstResult) yield { type: 'STEP_TEXT', - text: filesResult.errorMessage, + text: errorMessage + ? `Error from file-lister: ${errorMessage}` + : 'Error: Could not extract file list from spawned agent', } satisfies StepText return } - const paths = filesResult[0].text.split('\n').filter(Boolean) + const paths = fileListText.split('\n').filter(Boolean) yield { toolName: 'read_files', @@ -85,24 +90,71 @@ Do not use any further tools or spawn any further agents. yield 'STEP' - function extractSpawnResults( - results: any[] | undefined, - ): (T | { errorMessage: string })[] { - if (!results) return [] - const spawnedResults = results - .filter((result) => result.type === 'json') - .map((result) => result.value) - .flat() as { - agentType: string - value: { value?: T; errorMessage?: string } - }[] - return spawnedResults.map( - (result) => - result.value.value ?? { - errorMessage: - result.value.errorMessage ?? 'Error extracting spawn results', - }, - ) + /** + * Extracts the array of subagent results from spawn_agents tool output. + * + * The spawn_agents tool result structure is: + * [{ type: 'json', value: [{ agentName, agentType, value: AgentOutput }] }] + * + * Returns an array of agent outputs, one per spawned agent. + */ + function extractSpawnResults(results: any[] | undefined): any[] { + if (!results || results.length === 0) return [] + + // Find the json result containing spawn results + const jsonResult = results.find((r) => r.type === 'json') + if (!jsonResult?.value) return [] + + // Get the spawned agent results array + const spawnedResults = Array.isArray(jsonResult.value) ? jsonResult.value : [jsonResult.value] + + // Extract the value (AgentOutput) from each result + return spawnedResults.map((result: any) => result?.value).filter(Boolean) + } + + /** + * Extracts the text content from a 'lastMessage' AgentOutput. + * + * For agents with outputMode: 'last_message', the output structure is: + * { type: 'lastMessage', value: [{ role: 'assistant', content: [{ type: 'text', text: '...' }] }] } + * + * Returns the text from the last assistant message, or null if not found. + */ + function extractLastMessageText(agentOutput: any): string | null { + if (!agentOutput) return null + + // Handle 'lastMessage' output mode - the value contains an array of messages + if (agentOutput.type === 'lastMessage' && Array.isArray(agentOutput.value)) { + // Find the last assistant message with text content + for (let i = agentOutput.value.length - 1; i >= 0; i--) { + const message = agentOutput.value[i] + if (message.role === 'assistant' && Array.isArray(message.content)) { + // Find text content in the message + for (const part of message.content) { + if (part.type === 'text' && typeof part.text === 'string') { + return part.text + } + } + } + } + } + + return null + } + + /** + * Extracts the error message from an AgentOutput if it's an error type. + * + * Returns the error message string, or null if not an error output. + */ + function extractErrorMessage(agentOutput: any): string | null { + if (!agentOutput) return null + + if (agentOutput.type === 'error') { + return agentOutput.message ?? agentOutput.value ?? null + } + + return null } }, } diff --git a/.agents/tsconfig.json b/.agents/tsconfig.json index 4387f3d66..dbb372c16 100644 --- a/.agents/tsconfig.json +++ b/.agents/tsconfig.json @@ -5,6 +5,7 @@ "skipLibCheck": true, "types": ["bun", "node"], "paths": { + "@codebuff/sdk": ["../sdk/src/index.ts"], "@codebuff/common/*": ["../common/src/*"] } }, diff --git a/.agents/types/util-types.ts b/.agents/types/util-types.ts index c6bc95d73..79b4e81f5 100644 --- a/.agents/types/util-types.ts +++ b/.agents/types/util-types.ts @@ -74,16 +74,15 @@ export type ToolCallPart = { providerExecuted?: boolean } -export type ToolResultOutput = - | { - type: 'json' - value: JSONValue - } - | { - type: 'media' - data: string - mediaType: string - } +export type MediaToolResultOutputSchema = { + data: string + mediaType: string +} + +export type ToolResultOutput = { + value: JSONValue + media?: MediaToolResultOutputSchema[] +} // ===== Message Types ===== type AuxiliaryData = { diff --git a/backend/src/__tests__/cost-aggregation.integration.test.ts b/backend/src/__tests__/cost-aggregation.integration.test.ts index 3206f3e5d..5dd6a5cd8 100644 --- a/backend/src/__tests__/cost-aggregation.integration.test.ts +++ b/backend/src/__tests__/cost-aggregation.integration.test.ts @@ -4,6 +4,7 @@ import * as agentRegistry from '@codebuff/agent-runtime/templates/agent-registry import { TEST_USER_ID } from '@codebuff/common/old-constants' import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' +import { generateCompactId } from '@codebuff/common/util/string' import { spyOn, beforeEach, @@ -22,6 +23,7 @@ import type { AgentRuntimeScopedDeps, } from '@codebuff/common/types/contracts/agent-runtime' import type { SendActionFn } from '@codebuff/common/types/contracts/client' +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { Mock } from 'bun:test' @@ -149,15 +151,30 @@ describe('Cost Aggregation Integration Tests', () => { if (callCount === 1) { // Main agent spawns a subagent yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n', - } + type: 'tool-call', + toolName: 'spawn_agents', + toolCallId: generateCompactId('test-id-'), + input: { + agents: [ + { + agent_type: 'editor', + prompt: 'Write a simple hello world file', + }, + ], + }, + } satisfies StreamChunk } else { // Subagent writes a file yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n', - } + type: 'tool-call', + toolName: 'write_file', + toolCallId: generateCompactId('test-id-'), + input: { + path: 'hello.txt', + instructions: 'Create hello world file', + content: 'Hello, World!', + }, + } satisfies StreamChunk } return 'mock-message-id' }, @@ -252,8 +269,8 @@ describe('Cost Aggregation Integration Tests', () => { // Verify the total cost includes both main agent and subagent costs const finalCreditsUsed = result.sessionState.mainAgentState.creditsUsed - // The actual cost is higher than expected due to multiple steps in agent execution - expect(finalCreditsUsed).toEqual(73) + // 10 for the first call, 7 for the subagent, 7*9 for the next 9 calls + expect(finalCreditsUsed).toEqual(80) // Verify the cost breakdown makes sense expect(finalCreditsUsed).toBeGreaterThan(0) @@ -307,21 +324,35 @@ describe('Cost Aggregation Integration Tests', () => { if (callCount === 1) { // Main agent spawns first-level subagent yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n', - } + type: 'tool-call', + toolName: 'spawn_agents', + toolCallId: generateCompactId('test-id-'), + input: { + agents: [{ agent_type: 'editor', prompt: 'Create files' }], + }, + } satisfies StreamChunk } else if (callCount === 2) { // First-level subagent spawns second-level subagent yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n', - } + type: 'tool-call', + toolName: 'spawn_agents', + toolCallId: generateCompactId('test-id-'), + input: { + agents: [{ agent_type: 'editor', prompt: 'Write specific file' }], + }, + } satisfies StreamChunk } else { // Second-level subagent does actual work yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n', - } + type: 'tool-call', + toolName: 'write_file', + toolCallId: generateCompactId('test-id-'), + input: { + path: 'nested.txt', + instructions: 'Create nested file', + content: 'Nested content', + }, + } satisfies StreamChunk } return 'mock-message-id' @@ -348,8 +379,8 @@ describe('Cost Aggregation Integration Tests', () => { // Should aggregate costs from all levels: main + sub1 + sub2 const finalCreditsUsed = result.sessionState.mainAgentState.creditsUsed - // Multi-level agents should have higher costs than simple ones - expect(finalCreditsUsed).toEqual(50) + // 10 calls from base agent, 1 from first subagent, 1 from second subagent: 12 calls total + expect(finalCreditsUsed).toEqual(60) }) it('should maintain cost integrity when subagents fail', async () => { @@ -365,12 +396,19 @@ describe('Cost Aggregation Integration Tests', () => { if (callCount === 1) { // Main agent spawns subagent yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n', - } + type: 'tool-call', + toolName: 'spawn_agents', + toolCallId: generateCompactId('test-id-'), + input: { + agents: [{ agent_type: 'editor', prompt: 'This will fail' }], + }, + } satisfies StreamChunk } else { // Subagent fails after incurring cost - yield { type: 'text' as const, text: 'Some response' } + yield { + type: 'text', + text: 'Some response', + } satisfies StreamChunk throw new Error('Subagent execution failed') } diff --git a/cli/src/hooks/use-send-message.ts b/cli/src/hooks/use-send-message.ts index 8112a568d..05ad85e7d 100644 --- a/cli/src/hooks/use-send-message.ts +++ b/cli/src/hooks/use-send-message.ts @@ -30,7 +30,12 @@ import { import type { ElapsedTimeTracker } from './use-elapsed-time' import type { StreamStatus } from './use-message-queue' -import type { ChatMessage, ContentBlock, ToolContentBlock, AskUserContentBlock } from '../types/chat' +import type { + ChatMessage, + ContentBlock, + ToolContentBlock, + AskUserContentBlock, +} from '../types/chat' import type { SendMessageFn } from '../types/contracts/send-message' import type { ParamsOf } from '../types/function-params' import type { SetElement } from '../types/utils' @@ -1128,7 +1133,7 @@ export const useSendMessage = ({ ] of spawnAgentsMapRef.current.entries()) { const eventType = event.agentType || '' const storedType = info.agentType || '' - + // Extract base names without version or scope // e.g., 'codebuff/file-picker@0.0.2' -> 'file-picker' // 'file-picker' -> 'file-picker' @@ -1140,10 +1145,10 @@ export const useSendMessage = ({ // Handle simple names, possibly with version return type.split('@')[0] } - + const eventBaseName = getBaseName(eventType) const storedBaseName = getBaseName(storedType) - + // Match if base names are the same const isMatch = eventBaseName === storedBaseName if (isMatch) { @@ -1421,6 +1426,7 @@ export const useSendMessage = ({ input, agentId, includeToolCall, + parentAgentId, } = event if (toolName === 'spawn_agents' && input?.agents) { @@ -1492,7 +1498,7 @@ export const useSendMessage = ({ } // If this tool call belongs to a subagent, add it to that agent's blocks - if (agentId) { + if (parentAgentId && agentId) { applyMessageUpdate((prev) => prev.map((msg) => { if (msg.id !== aiMessageId || !msg.blocks) { @@ -1562,18 +1568,24 @@ export const useSendMessage = ({ } setStreamingAgents((prev) => new Set(prev).add(toolCallId)) - } else if (event.type === 'tool_result' && event.toolCallId) { + } else if (event.type === 'tool_result' && event.toolCallId) { const { toolCallId } = event // Handle ask_user result transformation - applyMessageUpdate((prev) => + applyMessageUpdate((prev) => prev.map((msg) => { if (msg.id !== aiMessageId || !msg.blocks) return msg // Recursively check for tool blocks to transform - const transformAskUser = (blocks: ContentBlock[]): ContentBlock[] => { + const transformAskUser = ( + blocks: ContentBlock[], + ): ContentBlock[] => { return blocks.map((block) => { - if (block.type === 'tool' && block.toolCallId === toolCallId && block.toolName === 'ask_user') { + if ( + block.type === 'tool' && + block.toolCallId === toolCallId && + block.toolName === 'ask_user' + ) { const resultValue = (event.output?.[0] as any)?.value const skipped = resultValue?.skipped const answers = resultValue?.answers @@ -1592,7 +1604,7 @@ export const useSendMessage = ({ skipped, } as AskUserContentBlock } - + if (block.type === 'agent' && block.blocks) { const updatedBlocks = transformAskUser(block.blocks) if (updatedBlocks !== block.blocks) { @@ -1605,10 +1617,10 @@ export const useSendMessage = ({ const newBlocks = transformAskUser(msg.blocks) if (newBlocks !== msg.blocks) { - return { ...msg, blocks: newBlocks } + return { ...msg, blocks: newBlocks } } return msg - }) + }), ) // Check if this is a spawn_agents result diff --git a/common/src/templates/initial-agents-dir/types/util-types.ts b/common/src/templates/initial-agents-dir/types/util-types.ts index c6bc95d73..79b4e81f5 100644 --- a/common/src/templates/initial-agents-dir/types/util-types.ts +++ b/common/src/templates/initial-agents-dir/types/util-types.ts @@ -74,16 +74,15 @@ export type ToolCallPart = { providerExecuted?: boolean } -export type ToolResultOutput = - | { - type: 'json' - value: JSONValue - } - | { - type: 'media' - data: string - mediaType: string - } +export type MediaToolResultOutputSchema = { + data: string + mediaType: string +} + +export type ToolResultOutput = { + value: JSONValue + media?: MediaToolResultOutputSchema[] +} // ===== Message Types ===== type AuxiliaryData = { diff --git a/common/src/tools/params/tool/add-message.ts b/common/src/tools/params/tool/add-message.ts index 2866cc2d3..7312fdbd5 100644 --- a/common/src/tools/params/tool/add-message.ts +++ b/common/src/tools/params/tool/add-message.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + emptyToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -16,7 +19,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/add-subgoal.ts b/common/src/tools/params/tool/add-subgoal.ts index ed592797b..0630e76de 100644 --- a/common/src/tools/params/tool/add-subgoal.ts +++ b/common/src/tools/params/tool/add-subgoal.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -32,7 +32,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/ask-user.ts b/common/src/tools/params/tool/ask-user.ts index 8a228de46..dc83e1618 100644 --- a/common/src/tools/params/tool/ask-user.ts +++ b/common/src/tools/params/tool/ask-user.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -10,11 +10,16 @@ export const questionSchema = z.object({ .string() .max(12) .optional() - .describe('Short label (max 12 chars) displayed as a chip/tag. Example: "Auth method"'), + .describe( + 'Short label (max 12 chars) displayed as a chip/tag. Example: "Auth method"', + ), options: z .object({ label: z.string().describe('The display text for this option'), - description: z.string().optional().describe('Explanation shown when option is focused'), + description: z + .string() + .optional() + .describe('Explanation shown when option is focused'), }) .array() .refine((opts) => opts.length >= 2, { @@ -30,10 +35,22 @@ export const questionSchema = z.object({ ), validation: z .object({ - maxLength: z.number().optional().describe('Maximum length for "Other" text input'), - minLength: z.number().optional().describe('Minimum length for "Other" text input'), - pattern: z.string().optional().describe('Regex pattern for "Other" text input'), - patternError: z.string().optional().describe('Custom error message when pattern fails'), + maxLength: z + .number() + .optional() + .describe('Maximum length for "Other" text input'), + minLength: z + .number() + .optional() + .describe('Minimum length for "Other" text input'), + pattern: z + .string() + .optional() + .describe('Regex pattern for "Other" text input'), + patternError: z + .string() + .optional() + .describe('Custom error message when pattern fails'), }) .optional() .describe('Validation rules for "Other" text input'), @@ -67,14 +84,20 @@ const outputSchema = z.object({ .array(z.string()) .optional() .describe('Array of selected option texts (multi-select mode)'), - otherText: z.string().optional().describe('Custom text input (if user typed their own answer)'), + otherText: z + .string() + .optional() + .describe('Custom text input (if user typed their own answer)'), }), ) .optional() .describe( 'Array of user answers, one per question. Each answer has either selectedOption (single), selectedOptions (multi), or otherText.', ), - skipped: z.boolean().optional().describe('True if user skipped the questions'), + skipped: z + .boolean() + .optional() + .describe('True if user skipped the questions'), }) const description = ` @@ -87,7 +110,7 @@ The user can either: - Skip the questions to provide different instructions instead Single-select example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -96,9 +119,18 @@ ${$getToolCallString({ question: 'Which authentication method should we use?', header: 'Auth method', options: [ - { label: 'JWT tokens', description: 'Stateless tokens stored in localStorage' }, - { label: 'Session cookies', description: 'Server-side sessions with httpOnly cookies' }, - { label: 'OAuth2', description: 'Third-party authentication (Google, GitHub, etc.)' }, + { + label: 'JWT tokens', + description: 'Stateless tokens stored in localStorage', + }, + { + label: 'Session cookies', + description: 'Server-side sessions with httpOnly cookies', + }, + { + label: 'OAuth2', + description: 'Third-party authentication (Google, GitHub, etc.)', + }, ], }, ], @@ -107,7 +139,7 @@ ${$getToolCallString({ })} Multi-select example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/browser-logs.ts b/common/src/tools/params/tool/browser-logs.ts index acb4d51d9..742c2168c 100644 --- a/common/src/tools/params/tool/browser-logs.ts +++ b/common/src/tools/params/tool/browser-logs.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { BrowserResponseSchema } from '../../../browser-actions' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -64,7 +64,7 @@ Navigate: - \`waitUntil\`: (required) One of 'load', 'domcontentloaded', 'networkidle0' Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/code-search.ts b/common/src/tools/params/tool/code-search.ts index 876ea2934..2f5d82791 100644 --- a/common/src/tools/params/tool/code-search.ts +++ b/common/src/tools/params/tool/code-search.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -85,37 +85,37 @@ RESULT LIMITING: - If the global limit is reached, remaining files will be skipped Examples: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'foo' }, endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'foo\\.bar = 1\\.0' }, endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'import.*foo', cwd: 'src' }, endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'function.*authenticate', flags: '-i -t ts -t js' }, endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'TODO', flags: '-n --type-not py' }, endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'getUserData', maxResults: 10 }, diff --git a/common/src/tools/params/tool/create-plan.ts b/common/src/tools/params/tool/create-plan.ts index 1aca1d6ce..56c027da2 100644 --- a/common/src/tools/params/tool/create-plan.ts +++ b/common/src/tools/params/tool/create-plan.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { updateFileResultSchema } from './str-replace' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -52,7 +52,7 @@ After creating the plan, you should end turn to let the user review the plan. Important: Use this tool sparingly. Do not use this tool more than once in a conversation, unless in ask mode. Examples: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/end-turn.ts b/common/src/tools/params/tool/end-turn.ts index 16d21a672..fff966911 100644 --- a/common/src/tools/params/tool/end-turn.ts +++ b/common/src/tools/params/tool/end-turn.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + emptyToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -20,14 +23,14 @@ Only use this tool to hand control back to the user. - Effect: Signals the UI to wait for the user's reply; any pending tool results will be ignored. *INCORRECT USAGE*: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName: 'some_tool_that_produces_results', inputSchema: null, input: { query: 'some example search term' }, endsAgentStep: false, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: {}, @@ -37,7 +40,7 @@ ${$getToolCallString({ *CORRECT USAGE*: All done! Would you like some more help with xyz? -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: {}, diff --git a/common/src/tools/params/tool/find-files.ts b/common/src/tools/params/tool/find-files.ts index 4b46e15ec..3a931b342 100644 --- a/common/src/tools/params/tool/find-files.ts +++ b/common/src/tools/params/tool/find-files.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { fileContentsSchema } from './read-files' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -21,7 +21,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/glob.ts b/common/src/tools/params/tool/glob.ts index e98dc6798..b944dd73e 100644 --- a/common/src/tools/params/tool/glob.ts +++ b/common/src/tools/params/tool/glob.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -26,7 +26,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/list-directory.ts b/common/src/tools/params/tool/list-directory.ts index 403179981..d70590f37 100644 --- a/common/src/tools/params/tool/list-directory.ts +++ b/common/src/tools/params/tool/list-directory.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -19,7 +19,7 @@ const description = ` Lists all files and directories in the specified path. Useful for exploring directory structure and finding files. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -28,7 +28,7 @@ ${$getToolCallString({ endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/lookup-agent-info.ts b/common/src/tools/params/tool/lookup-agent-info.ts index 4f1ee5cc5..029668ec4 100644 --- a/common/src/tools/params/tool/lookup-agent-info.ts +++ b/common/src/tools/params/tool/lookup-agent-info.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { jsonValueSchema } from '../../../types/json' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -18,7 +18,7 @@ const description = ` Retrieve information about an agent by ID for proper spawning. Use this when you see a request with a full agent ID like "@publisher/agent-id@version" to validate the agent exists and get its metadata. Only agents that are published under a publisher and version are supported for this tool. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/read-docs.ts b/common/src/tools/params/tool/read-docs.ts index 235c3faee..25e5ee06b 100644 --- a/common/src/tools/params/tool/read-docs.ts +++ b/common/src/tools/params/tool/read-docs.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -50,7 +50,7 @@ Use cases: The tool will search for the library and return the most relevant documentation content. If a topic is specified, it will focus the results on that specific area. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -61,7 +61,7 @@ ${$getToolCallString({ endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/read-files.ts b/common/src/tools/params/tool/read-files.ts index 2c1877720..3f757aa9b 100644 --- a/common/src/tools/params/tool/read-files.ts +++ b/common/src/tools/params/tool/read-files.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -36,7 +36,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/read-subtree.ts b/common/src/tools/params/tool/read-subtree.ts index 3156d8ca7..09f0c1f58 100644 --- a/common/src/tools/params/tool/read-subtree.ts +++ b/common/src/tools/params/tool/read-subtree.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -28,7 +28,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/run-file-change-hooks.ts b/common/src/tools/params/tool/run-file-change-hooks.ts index 1b1379982..e69c211d6 100644 --- a/common/src/tools/params/tool/run-file-change-hooks.ts +++ b/common/src/tools/params/tool/run-file-change-hooks.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { terminalCommandOutputSchema } from './run-terminal-command' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -25,7 +25,7 @@ Use cases: The client will run only the hooks whose filePattern matches the provided files. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/run-terminal-command.ts b/common/src/tools/params/tool/run-terminal-command.ts index c89e16e57..4bd53f0c2 100644 --- a/common/src/tools/params/tool/run-terminal-command.ts +++ b/common/src/tools/params/tool/run-terminal-command.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -156,7 +156,7 @@ Notes: ${gitCommitGuidePrompt} Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -165,7 +165,7 @@ ${$getToolCallString({ endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/set-messages.ts b/common/src/tools/params/tool/set-messages.ts index bb062cadf..a381f4ca7 100644 --- a/common/src/tools/params/tool/set-messages.ts +++ b/common/src/tools/params/tool/set-messages.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + emptyToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -13,7 +16,7 @@ const inputSchema = z .describe(`Set the conversation history to the provided messages.`) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/set-output.ts b/common/src/tools/params/tool/set-output.ts index f86c94f80..6b976ce0d 100644 --- a/common/src/tools/params/tool/set-output.ts +++ b/common/src/tools/params/tool/set-output.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString } from '../utils' +import { $getNativeToolCallExampleString } from '../utils' import type { $ToolParams } from '../../constants' @@ -16,7 +16,7 @@ You must use this tool as it is the only way to report any findings to the user. Please set the output with all the information and analysis you want to pass on to the user. If you just want to send a simple message, use an object with the key "message" and value of the message you want to send. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/spawn-agent-inline.ts b/common/src/tools/params/tool/spawn-agent-inline.ts index 6ee9a9d44..8b3b682ad 100644 --- a/common/src/tools/params/tool/spawn-agent-inline.ts +++ b/common/src/tools/params/tool/spawn-agent-inline.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + emptyToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -31,7 +34,7 @@ This is useful for: - Managing message history (e.g., summarization) The agent will run until it calls end_turn, then control returns to you. There is no tool result for this tool. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/spawn-agents.ts b/common/src/tools/params/tool/spawn-agents.ts index f7da5e5d7..2c83c8b5b 100644 --- a/common/src/tools/params/tool/spawn-agents.ts +++ b/common/src/tools/params/tool/spawn-agents.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { jsonObjectSchema } from '../../../types/json' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -36,7 +36,7 @@ Use this tool to spawn agents to help you complete the user request. Each agent The prompt field is a simple string, while params is a JSON object that gets validated against the agent's schema. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/str-replace.ts b/common/src/tools/params/tool/str-replace.ts index 5aee745fe..b02ce1e81 100644 --- a/common/src/tools/params/tool/str-replace.ts +++ b/common/src/tools/params/tool/str-replace.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -61,7 +61,7 @@ Important: If you are making multiple edits in a row to a file, use only one str_replace call with multiple replacements instead of multiple str_replace tool calls. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/task-completed.ts b/common/src/tools/params/tool/task-completed.ts index a8c35d1c6..7ea2a4f85 100644 --- a/common/src/tools/params/tool/task-completed.ts +++ b/common/src/tools/params/tool/task-completed.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + emptyToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -34,19 +37,19 @@ Use this tool to signal that the task is complete. All changes have been implemented and tested successfully! -${$getToolCallString({ toolName, inputSchema, input: {}, endsAgentStep })} +${$getNativeToolCallExampleString({ toolName, inputSchema, input: {}, endsAgentStep })} OR I need more information to proceed. Which database schema should I use for this migration? -${$getToolCallString({ toolName, inputSchema, input: {}, endsAgentStep })} +${$getNativeToolCallExampleString({ toolName, inputSchema, input: {}, endsAgentStep })} OR I can't get the tests to pass after several different attempts. I need help from the user to proceed. -${$getToolCallString({ toolName, inputSchema, input: {}, endsAgentStep })} +${$getNativeToolCallExampleString({ toolName, inputSchema, input: {}, endsAgentStep })} `.trim() export const taskCompletedParams = { diff --git a/common/src/tools/params/tool/think-deeply.ts b/common/src/tools/params/tool/think-deeply.ts index 4292332fa..e84a07601 100644 --- a/common/src/tools/params/tool/think-deeply.ts +++ b/common/src/tools/params/tool/think-deeply.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + emptyToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -29,7 +32,7 @@ Avoid for simple changes (e.g., single functions, minor edits). This tool does not generate a tool result. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/update-subgoal.ts b/common/src/tools/params/tool/update-subgoal.ts index 299ca9eea..75e778c63 100644 --- a/common/src/tools/params/tool/update-subgoal.ts +++ b/common/src/tools/params/tool/update-subgoal.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -31,7 +31,7 @@ const description = ` Examples: Usage 1 (update status): -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -42,7 +42,7 @@ ${$getToolCallString({ })} Usage 2 (update plan): -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -53,7 +53,7 @@ ${$getToolCallString({ })} Usage 3 (add log): -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -64,7 +64,7 @@ ${$getToolCallString({ })} Usage 4 (update status and add log): -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/web-search.ts b/common/src/tools/params/tool/web-search.ts index 7a458cc01..e87c8f271 100644 --- a/common/src/tools/params/tool/web-search.ts +++ b/common/src/tools/params/tool/web-search.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -34,7 +34,7 @@ Use cases: The tool will return search results with titles, URLs, and content snippets. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -44,7 +44,7 @@ ${$getToolCallString({ endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/write-file.ts b/common/src/tools/params/tool/write-file.ts index 00ec71c6d..cf50fee05 100644 --- a/common/src/tools/params/tool/write-file.ts +++ b/common/src/tools/params/tool/write-file.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { updateFileResultSchema } from './str-replace' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -39,7 +39,7 @@ Do not use this tool to delete or rename a file. Instead run a terminal command Examples: Example 1 - Simple file creation: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -51,7 +51,7 @@ ${$getToolCallString({ })} Example 2 - Editing with placeholder comments: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/write-todos.ts b/common/src/tools/params/tool/write-todos.ts index 7b7489a6f..ae73e72a1 100644 --- a/common/src/tools/params/tool/write-todos.ts +++ b/common/src/tools/params/tool/write-todos.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString } from '../utils' +import { $getNativeToolCallExampleString } from '../utils' import type { $ToolParams } from '../../constants' @@ -30,7 +30,7 @@ After completing each todo step, call this tool again to update the list and mar Use this tool frequently as you work through tasks to update the list of todos with their current status. Doing this is extremely useful because it helps you stay on track and complete all the requirements of the user's request. It also helps inform the user of your plans and the current progress, which they want to know at all times. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/utils.ts b/common/src/tools/params/utils.ts index cbf79d327..951ee3a61 100644 --- a/common/src/tools/params/utils.ts +++ b/common/src/tools/params/utils.ts @@ -34,6 +34,20 @@ export function $getToolCallString(params: { return [startToolTag, JSON.stringify(obj, null, 2), endToolTag].join('') } +export function $getNativeToolCallExampleString(params: { + toolName: string + inputSchema: z.ZodType | null + input: Input + endsAgentStep?: boolean // unused +}): string { + const { toolName, input } = params + return [ + `<${toolName}_params_example>\n`, + JSON.stringify(input, null, 2), + `\n`, + ].join('') +} + /** Generates the zod schema for a single JSON tool result. */ export function jsonToolResultSchema( valueSchema: z.ZodType, diff --git a/common/src/types/agent-template.ts b/common/src/types/agent-template.ts index 77989fc6d..9cd57c24d 100644 --- a/common/src/types/agent-template.ts +++ b/common/src/types/agent-template.ts @@ -5,6 +5,8 @@ * It imports base types from the user-facing template to eliminate duplication. */ +import { z } from 'zod/v4' + import type { MCPConfig } from './mcp' import type { Model } from '../old-constants' import type { ToolResultOutput } from './messages/content-part' @@ -15,7 +17,6 @@ import type { } from '../templates/initial-agents-dir/types/agent-definition' import type { Logger } from '../templates/initial-agents-dir/types/util-types' import type { ToolName } from '../tools/constants' -import type { z } from 'zod/v4' export type AgentId = `${string}/${string}@${number}.${number}.${number}` @@ -141,6 +142,33 @@ export type AgentTemplate< export type StepText = { type: 'STEP_TEXT'; text: string } export type GenerateN = { type: 'GENERATE_N'; n: number } +// Zod schemas for handleSteps yield values +export const StepTextSchema = z.object({ + type: z.literal('STEP_TEXT'), + text: z.string(), +}) + +export const GenerateNSchema = z.object({ + type: z.literal('GENERATE_N'), + n: z.number().int().positive(), +}) + +export const HandleStepsToolCallSchema = z.object({ + toolName: z.string().min(1), + input: z.record(z.string(), z.any()), + includeToolCall: z.boolean().optional(), +}) + +export const HandleStepsYieldValueSchema = z.union([ + z.literal('STEP'), + z.literal('STEP_ALL'), + StepTextSchema, + GenerateNSchema, + HandleStepsToolCallSchema, +]) + +export type HandleStepsYieldValue = z.infer + export type StepGenerator = Generator< Omit | 'STEP' | 'STEP_ALL' | StepText | GenerateN, // Generic tool call type void, diff --git a/common/src/types/contracts/llm.ts b/common/src/types/contracts/llm.ts index 23fc5ede7..ac3235995 100644 --- a/common/src/types/contracts/llm.ts +++ b/common/src/types/contracts/llm.ts @@ -1,12 +1,12 @@ import type { TrackEventFn } from './analytics' import type { SendActionFn } from './client' import type { CheckLiveUserInputFn } from './live-user-input' +import type { OpenRouterProviderRoutingOptions } from '../agent-template' import type { ParamsExcluding } from '../function-params' import type { Logger } from './logger' import type { Model } from '../../old-constants' import type { Message } from '../messages/codebuff-message' -import type { OpenRouterProviderRoutingOptions } from '../agent-template' -import type { generateText, streamText } from 'ai' +import type { generateText, streamText, ToolCallPart } from 'ai' import type z from 'zod/v4' export type StreamChunk = @@ -19,6 +19,10 @@ export type StreamChunk = type: 'reasoning' text: string } + | Pick< + ToolCallPart, + 'type' | 'toolCallId' | 'toolName' | 'input' | 'providerOptions' + > | { type: 'error'; message: string } export type PromptAiSdkStreamFn = ( diff --git a/common/src/types/messages/codebuff-message.ts b/common/src/types/messages/codebuff-message.ts index 0ce1708b3..d222e2946 100644 --- a/common/src/types/messages/codebuff-message.ts +++ b/common/src/types/messages/codebuff-message.ts @@ -41,7 +41,7 @@ export type ToolMessage = { role: 'tool' toolCallId: string toolName: string - content: ToolResultOutput[] + content: ToolResultOutput } & AuxiliaryMessageData export type Message = diff --git a/common/src/types/messages/content-part.ts b/common/src/types/messages/content-part.ts index c4692e42a..e78a2cada 100644 --- a/common/src/types/messages/content-part.ts +++ b/common/src/types/messages/content-part.ts @@ -45,15 +45,16 @@ export const toolCallPartSchema = z.object({ }) export type ToolCallPart = z.infer -export const toolResultOutputSchema = z.discriminatedUnion('type', [ - z.object({ - type: z.literal('json'), - value: jsonValueSchema, - }), - z.object({ - type: z.literal('media'), - data: z.string(), - mediaType: z.string(), - }), -]) +export const mediaToolResultOutputSchema = z.object({ + data: z.string(), + mediaType: z.string(), +}) +export type MediaToolResultOutputSchema = z.infer< + typeof mediaToolResultOutputSchema +> + +export const toolResultOutputSchema = z.object({ + value: jsonValueSchema, + media: mediaToolResultOutputSchema.array().optional(), +}) export type ToolResultOutput = z.infer diff --git a/common/src/types/session-state.ts b/common/src/types/session-state.ts index 8a3abaec2..7fc5907a4 100644 --- a/common/src/types/session-state.ts +++ b/common/src/types/session-state.ts @@ -48,7 +48,7 @@ export const AgentOutputSchema = z.discriminatedUnion('type', [ }), z.object({ type: z.literal('lastMessage'), - value: z.any(), + value: z.array(z.any()), // Array of assistant and tool messages from the last turn, including tool results }), z.object({ type: z.literal('allMessages'), diff --git a/common/src/util/__tests__/messages.test.ts b/common/src/util/__tests__/messages.test.ts index 53e1cb722..72658d1a0 100644 --- a/common/src/util/__tests__/messages.test.ts +++ b/common/src/util/__tests__/messages.test.ts @@ -13,6 +13,7 @@ import { } from '../messages' import type { Message } from '../../types/messages/codebuff-message' +import type { AssistantModelMessage, ToolResultPart } from 'ai' describe('withCacheControl', () => { it('should add cache control to object without providerOptions', () => { @@ -189,12 +190,6 @@ describe('convertCbToModelMessages', () => { describe('tool message conversion', () => { it('should convert tool messages with JSON output', () => { - const toolResult = [ - { - type: 'json', - value: { result: 'success' }, - }, - ] const messages: Message[] = [ { role: 'tool', @@ -211,15 +206,17 @@ describe('convertCbToModelMessages', () => { expect(result).toEqual([ expect.objectContaining({ - role: 'user', + role: 'tool', content: [ expect.objectContaining({ - type: 'text', - }), + type: 'tool-result', + toolCallId: 'call_123', + toolName: 'test_tool', + output: { type: 'json', value: { result: 'success' } }, + } satisfies ToolResultPart), ], }), ]) - expect((result as any)[0].content[0].text).toContain('') }) it('should convert tool messages with media output', () => { @@ -270,14 +267,15 @@ describe('convertCbToModelMessages', () => { includeCacheControl: false, }) - console.dir({ result }, { depth: null }) // Multiple tool outputs are aggregated into one user message expect(result).toEqual([ expect.objectContaining({ - role: 'user', + role: 'tool', + }), + expect.objectContaining({ + role: 'tool', }), ]) - expect(result[0].content).toHaveLength(2) }) }) @@ -806,14 +804,19 @@ describe('convertCbToModelMessages', () => { includeCacheControl: false, }) - expect(result).toHaveLength(1) - expect(result[0].role).toBe('assistant') - if (typeof result[0].content !== 'string') { - expect(result[0].content[0].type).toBe('text') - if (result[0].content[0].type === 'text') { - expect(result[0].content[0].text).toContain('test_tool') - } - } + expect(result).toEqual([ + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'call_123', + toolName: 'test_tool', + input: { param: 'value' }, + }, + ], + } satisfies AssistantModelMessage, + ]) }) it('should preserve message metadata during conversion', () => { diff --git a/common/src/util/messages.ts b/common/src/util/messages.ts index b8387a96f..54f89fb92 100644 --- a/common/src/util/messages.ts +++ b/common/src/util/messages.ts @@ -1,8 +1,5 @@ import { cloneDeep, has, isEqual } from 'lodash' -import { getToolCallString } from '../tools/utils' - -import type { JSONValue } from '../types/json' import type { AssistantMessage, AuxiliaryMessageData, @@ -11,7 +8,6 @@ import type { ToolMessage, UserMessage, } from '../types/messages/codebuff-message' -import type { ToolResultOutput } from '../types/messages/content-part' import type { ProviderMetadata } from '../types/messages/provider-metadata' import type { AssistantModelMessage, @@ -100,56 +96,52 @@ function assistantToCodebuffMessage( content: Exclude[number] }, ): AssistantMessage { - if (message.content.type === 'tool-call') { - return cloneDeep({ - ...message, - content: [ - { - type: 'text', - text: getToolCallString( - message.content.toolName, - message.content.input, - false, - ), - }, - ], - }) - } + // if (message.content.type === 'tool-call') { + // return cloneDeep({ + // ...message, + // content: [ + // { + // type: 'text', + // text: getToolCallString( + // message.content.toolName, + // message.content.input, + // false, + // ), + // }, + // ], + // }) + // } return cloneDeep({ ...message, content: [message.content] }) } function convertToolResultMessage( message: ToolMessage, ): ModelMessageWithAuxiliaryData[] { - return message.content.map((c) => { - if (c.type === 'json') { - const toolResult = { - toolName: message.toolName, - toolCallId: message.toolCallId, - output: c.value, - } - return cloneDeep({ - ...message, - role: 'user', - content: [ - { - type: 'text', - text: `\n${JSON.stringify(toolResult, null, 2)}\n`, - }, - ], - }) - } - if (c.type === 'media') { - return cloneDeep({ + const messages: ModelMessageWithAuxiliaryData[] = [ + cloneDeep({ + ...message, + role: 'tool', + content: [ + { + ...message, + output: { type: 'json', value: message.content.value }, + type: 'tool-result', + }, + ], + }), + ] + + for (const c of message.content.media ?? []) { + messages.push( + cloneDeep({ ...message, role: 'user', content: [{ type: 'file', data: c.data, mediaType: c.mediaType }], - }) - } - c satisfies never - const oAny = c as any - throw new Error(`Invalid tool output type: ${oAny.type}`) - }) + }), + ) + } + + return messages } function convertToolMessage(message: Message): ModelMessageWithAuxiliaryData[] { @@ -423,32 +415,3 @@ export function assistantMessage( content: assistantContent(params), } } - -export function jsonToolResult( - value: T, -): [ - Extract & { - value: T - }, -] { - return [ - { - type: 'json', - value, - }, - ] -} - -export function mediaToolResult(params: { - data: string - mediaType: string -}): [Extract] { - const { data, mediaType } = params - return [ - { - type: 'media', - data, - mediaType, - }, - ] -} diff --git a/evals/scaffolding.ts b/evals/scaffolding.ts index 6250a2f0b..a86b7b4e3 100644 --- a/evals/scaffolding.ts +++ b/evals/scaffolding.ts @@ -206,13 +206,15 @@ export async function runAgentStepScaffolding( const result = await runAgentStep({ ...EVALS_AGENT_RUNTIME_IMPL, ...agentRuntimeScopedImpl, + additionalToolDefinitions: () => Promise.resolve({}), - textOverride: null, - runId: 'test-run-id', - userId: TEST_USER_ID, - userInputId: generateCompactId(), + agentState, + agentType, + ancestorRunIds: [], clientSessionId: sessionId, + fileContext, fingerprintId: 'test-fingerprint-id', + localAgentTemplates, onResponseChunk: (chunk: string | PrintModeEvent) => { if (typeof chunk !== 'string') { return @@ -222,17 +224,16 @@ export async function runAgentStepScaffolding( } fullResponse += chunk }, - agentType, - fileContext, - localAgentTemplates, - agentState, prompt, - ancestorRunIds: [], - spawnParams: undefined, - repoUrl: undefined, repoId: undefined, - system: 'Test system prompt', + repoUrl: undefined, + runId: 'test-run-id', signal: new AbortController().signal, + spawnParams: undefined, + system: 'Test system prompt', + tools: {}, + userId: TEST_USER_ID, + userInputId: generateCompactId(), }) return { diff --git a/packages/agent-runtime/src/__tests__/cost-aggregation.test.ts b/packages/agent-runtime/src/__tests__/cost-aggregation.test.ts index 5c73a01b3..b46fee77e 100644 --- a/packages/agent-runtime/src/__tests__/cost-aggregation.test.ts +++ b/packages/agent-runtime/src/__tests__/cost-aggregation.test.ts @@ -159,7 +159,7 @@ describe('Cost Aggregation System', () => { stepsRemaining: 10, creditsUsed: 75, // First subagent uses 75 credits }, - output: { type: 'lastMessage', value: 'Sub-agent 1 response' }, + output: { type: 'lastMessage', value: [assistantMessage('Sub-agent 1 response')] }, }) .mockResolvedValueOnce({ agentState: { @@ -169,7 +169,7 @@ describe('Cost Aggregation System', () => { stepsRemaining: 10, creditsUsed: 100, // Second subagent uses 100 credits }, - output: { type: 'lastMessage', value: 'Sub-agent 2 response' }, + output: { type: 'lastMessage', value: [assistantMessage('Sub-agent 2 response')] }, }) const mockToolCall = { @@ -223,7 +223,7 @@ describe('Cost Aggregation System', () => { stepsRemaining: 10, creditsUsed: 50, // Successful agent }, - output: { type: 'lastMessage', value: 'Successful response' }, + output: { type: 'lastMessage', value: [assistantMessage('Successful response')] }, }) .mockRejectedValueOnce( (() => { @@ -370,7 +370,7 @@ describe('Cost Aggregation System', () => { stepsRemaining: 10, creditsUsed: subAgent1Cost, } as AgentState, - output: { type: 'lastMessage', value: 'Sub-agent 1 response' }, + output: { type: 'lastMessage', value: [assistantMessage('Sub-agent 1 response')] }, }) .mockResolvedValueOnce({ agentState: { @@ -381,7 +381,7 @@ describe('Cost Aggregation System', () => { stepsRemaining: 10, creditsUsed: subAgent2Cost, } as AgentState, - output: { type: 'lastMessage', value: 'Sub-agent 2 response' }, + output: { type: 'lastMessage', value: [assistantMessage('Sub-agent 2 response')] }, }) const mockToolCall = { diff --git a/packages/agent-runtime/src/__tests__/malformed-tool-call.test.ts b/packages/agent-runtime/src/__tests__/malformed-tool-call.test.ts index 8b32ea54a..a7b947281 100644 --- a/packages/agent-runtime/src/__tests__/malformed-tool-call.test.ts +++ b/packages/agent-runtime/src/__tests__/malformed-tool-call.test.ts @@ -16,7 +16,7 @@ import { } from 'bun:test' import { mockFileContext } from './test-utils' -import { processStreamWithTools } from '../tools/stream-parser' +import { processStream } from '../tools/stream-parser' import type { AgentTemplate } from '../templates/types' import type { @@ -34,7 +34,7 @@ let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } describe('malformed tool call error handling', () => { let testAgent: AgentTemplate let agentRuntimeImpl: AgentRuntimeDeps & AgentRuntimeScopedDeps - let defaultParams: ParamsOf + let defaultParams: ParamsOf beforeEach(() => { agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } @@ -139,7 +139,7 @@ describe('malformed tool call error handling', () => { const stream = createMockStream(chunks) - await processStreamWithTools({ + await processStream({ ...defaultParams, stream, }) @@ -177,7 +177,7 @@ describe('malformed tool call error handling', () => { const stream = createMockStream(chunks) - await processStreamWithTools({ + await processStream({ ...defaultParams, stream, }) @@ -204,7 +204,7 @@ describe('malformed tool call error handling', () => { const stream = createMockStream(chunks) - const result = await processStreamWithTools({ + const result = await processStream({ ...defaultParams, stream, }) @@ -235,7 +235,7 @@ describe('malformed tool call error handling', () => { const stream = createMockStream(chunks) - await processStreamWithTools({ + await processStream({ ...defaultParams, stream, }) @@ -268,7 +268,7 @@ describe('malformed tool call error handling', () => { const stream = createMockStream(chunks) - await processStreamWithTools({ + await processStream({ ...defaultParams, requestFiles: async ({ filePaths }) => { return Object.fromEntries( @@ -307,7 +307,7 @@ describe('malformed tool call error handling', () => { const stream = createMockStream(chunks) - await processStreamWithTools({ + await processStream({ ...defaultParams, stream, }) diff --git a/packages/agent-runtime/src/__tests__/n-parameter.test.ts b/packages/agent-runtime/src/__tests__/n-parameter.test.ts index c30ef339f..6cecb22f5 100644 --- a/packages/agent-runtime/src/__tests__/n-parameter.test.ts +++ b/packages/agent-runtime/src/__tests__/n-parameter.test.ts @@ -104,7 +104,6 @@ describe('n parameter and GENERATE_N functionality', () => { runAgentStepBaseParams = { ...agentRuntimeImpl, additionalToolDefinitions: () => Promise.resolve({}), - textOverride: null, runId: 'test-run-id', ancestorRunIds: [], repoId: undefined, @@ -122,6 +121,7 @@ describe('n parameter and GENERATE_N functionality', () => { spawnParams: undefined, system: 'Test system', signal: new AbortController().signal, + tools: {} } }) diff --git a/packages/agent-runtime/src/__tests__/read-docs-tool.test.ts b/packages/agent-runtime/src/__tests__/read-docs-tool.test.ts index 4b62cb588..65660004c 100644 --- a/packages/agent-runtime/src/__tests__/read-docs-tool.test.ts +++ b/packages/agent-runtime/src/__tests__/read-docs-tool.test.ts @@ -75,7 +75,6 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { runAgentStepBaseParams = { ...agentRuntimeImpl, additionalToolDefinitions: () => Promise.resolve({}), - textOverride: null, runId: 'test-run-id', ancestorRunIds: [], repoId: undefined, @@ -89,6 +88,7 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { agentType: 'researcher', spawnParams: undefined, signal: new AbortController().signal, + tools: {}, } }) @@ -214,7 +214,6 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { const { agentState: newAgentState } = await runAgentStep({ ...runAgentStepBaseParams, - textOverride: null, fileContext: mockFileContextWithAgents, localAgentTemplates: agentTemplates, agentState, diff --git a/packages/agent-runtime/src/__tests__/run-agent-step-tools.test.ts b/packages/agent-runtime/src/__tests__/run-agent-step-tools.test.ts index 62e026c0f..f8da2e23b 100644 --- a/packages/agent-runtime/src/__tests__/run-agent-step-tools.test.ts +++ b/packages/agent-runtime/src/__tests__/run-agent-step-tools.test.ts @@ -116,22 +116,21 @@ describe('runAgentStep - set_output tool', () => { runAgentStepBaseParams = { ...agentRuntimeImpl, + additionalToolDefinitions: () => Promise.resolve({}), ancestorRunIds: [], clientSessionId: 'test-session', fileContext: mockFileContext, fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, repoId: undefined, repoUrl: undefined, runId: 'test-run-id', signal: new AbortController().signal, spawnParams: undefined, system: 'Test system prompt', - textOverride: null, + tools: {}, userId: TEST_USER_ID, userInputId: 'test-input', - - additionalToolDefinitions: () => Promise.resolve({}), - onResponseChunk: () => {}, } }) diff --git a/packages/agent-runtime/src/__tests__/run-programmatic-step.test.ts b/packages/agent-runtime/src/__tests__/run-programmatic-step.test.ts index df7ded81d..c69886945 100644 --- a/packages/agent-runtime/src/__tests__/run-programmatic-step.test.ts +++ b/packages/agent-runtime/src/__tests__/run-programmatic-step.test.ts @@ -1433,6 +1433,240 @@ describe('runProgrammaticStep', () => { }) }) + describe('yield value validation', () => { + it('should reject invalid yield values', async () => { + const mockGenerator = (function* () { + yield { invalid: 'value' } as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const responseChunks: any[] = [] + mockParams.onResponseChunk = (chunk) => responseChunks.push(chunk) + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject yield values with wrong types', async () => { + const mockGenerator = (function* () { + yield { type: 'STEP_TEXT', text: 123 } as any // text should be string + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const responseChunks: any[] = [] + mockParams.onResponseChunk = (chunk) => responseChunks.push(chunk) + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject GENERATE_N with non-positive n', async () => { + const mockGenerator = (function* () { + yield { type: 'GENERATE_N', n: 0 } as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const responseChunks: any[] = [] + mockParams.onResponseChunk = (chunk) => responseChunks.push(chunk) + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject GENERATE_N with negative n', async () => { + const mockGenerator = (function* () { + yield { type: 'GENERATE_N', n: -5 } as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const responseChunks: any[] = [] + mockParams.onResponseChunk = (chunk) => responseChunks.push(chunk) + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should accept valid STEP literal', async () => { + const mockGenerator = (function* () { + yield 'STEP' + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(false) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should accept valid STEP_ALL literal', async () => { + const mockGenerator = (function* () { + yield 'STEP_ALL' + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(false) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should accept valid STEP_TEXT object', async () => { + const mockGenerator = (function* () { + yield { type: 'STEP_TEXT', text: 'Custom response text' } + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(false) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should accept valid GENERATE_N object', async () => { + const mockGenerator = (function* () { + yield { type: 'GENERATE_N', n: 3 } + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(false) + expect(result.generateN).toBe(3) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should accept valid tool call object', async () => { + const mockGenerator = (function* () { + yield { toolName: 'read_files', input: { paths: ['test.txt'] } } + yield { toolName: 'end_turn', input: {} } + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should accept tool call with includeToolCall option', async () => { + const mockGenerator = (function* () { + yield { + toolName: 'read_files', + input: { paths: ['test.txt'] }, + includeToolCall: false, + } + yield { toolName: 'end_turn', input: {} } + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should reject random string values', async () => { + const mockGenerator = (function* () { + yield 'INVALID_STEP' as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject null yield values', async () => { + const mockGenerator = (function* () { + yield null as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject undefined yield values', async () => { + const mockGenerator = (function* () { + yield undefined as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject tool call without toolName', async () => { + const mockGenerator = (function* () { + yield { input: { paths: ['test.txt'] } } as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject tool call without input', async () => { + const mockGenerator = (function* () { + yield { toolName: 'read_files' } as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + }) + describe('logging and context', () => { it('should log agent execution start', async () => { const mockGenerator = (function* () { diff --git a/packages/agent-runtime/src/__tests__/spawn-agents-message-history.test.ts b/packages/agent-runtime/src/__tests__/spawn-agents-message-history.test.ts index b47223129..90715389b 100644 --- a/packages/agent-runtime/src/__tests__/spawn-agents-message-history.test.ts +++ b/packages/agent-runtime/src/__tests__/spawn-agents-message-history.test.ts @@ -52,7 +52,7 @@ describe('Spawn Agents Message History', () => { assistantMessage('Mock agent response'), ], }, - output: { type: 'lastMessage', value: 'Mock agent response' }, + output: { type: 'lastMessage', value: [assistantMessage('Mock agent response')] }, } }) diff --git a/packages/agent-runtime/src/__tests__/spawn-agents-permissions.test.ts b/packages/agent-runtime/src/__tests__/spawn-agents-permissions.test.ts index 3f827e2a6..dc8d32252 100644 --- a/packages/agent-runtime/src/__tests__/spawn-agents-permissions.test.ts +++ b/packages/agent-runtime/src/__tests__/spawn-agents-permissions.test.ts @@ -85,7 +85,7 @@ describe('Spawn Agents Permissions', () => { ...options.agentState, messageHistory: [assistantMessage('Mock agent response')], }, - output: { type: 'lastMessage', value: 'Mock agent response' }, + output: { type: 'lastMessage', value: [assistantMessage('Mock agent response')] }, } }) }) diff --git a/packages/agent-runtime/src/__tests__/subagent-streaming.test.ts b/packages/agent-runtime/src/__tests__/subagent-streaming.test.ts index 1bd1a6970..134cadd8b 100644 --- a/packages/agent-runtime/src/__tests__/subagent-streaming.test.ts +++ b/packages/agent-runtime/src/__tests__/subagent-streaming.test.ts @@ -96,7 +96,7 @@ describe('Subagent Streaming', () => { ...options.agentState, messageHistory: [assistantMessage('Test response from subagent')], }, - output: { type: 'lastMessage', value: 'Test response from subagent' }, + output: { type: 'lastMessage', value: [assistantMessage('Test response from subagent')] }, } }) diff --git a/packages/agent-runtime/src/__tests__/tool-stream-parser.test.ts b/packages/agent-runtime/src/__tests__/tool-stream-parser.test.ts index b5c5dfb23..0cff771da 100644 --- a/packages/agent-runtime/src/__tests__/tool-stream-parser.test.ts +++ b/packages/agent-runtime/src/__tests__/tool-stream-parser.test.ts @@ -4,7 +4,7 @@ import { getToolCallString } from '@codebuff/common/tools/utils' import { beforeEach, describe, expect, it } from 'bun:test' import { globalStopSequence } from '../constants' -import { processStreamWithTags } from '../tool-stream-parser' +import { processStreamWithTools } from '../tool-stream-parser' import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' @@ -61,7 +61,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -129,7 +129,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -206,7 +206,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -282,7 +282,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -349,7 +349,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -415,7 +415,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -488,7 +488,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -555,7 +555,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -612,7 +612,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -655,7 +655,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -716,7 +716,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -787,7 +787,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, diff --git a/packages/agent-runtime/src/__tests__/web-search-tool.test.ts b/packages/agent-runtime/src/__tests__/web-search-tool.test.ts index 417f57819..badc3e126 100644 --- a/packages/agent-runtime/src/__tests__/web-search-tool.test.ts +++ b/packages/agent-runtime/src/__tests__/web-search-tool.test.ts @@ -61,23 +61,22 @@ describe('web_search tool with researcher agent (via web API facade)', () => { runAgentStepBaseParams = { ...agentRuntimeImpl, + additionalToolDefinitions: () => Promise.resolve({}), agentType: 'researcher', ancestorRunIds: [], clientSessionId: 'test-session', fileContext: mockFileContext, fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, repoId: undefined, repoUrl: undefined, runId: 'test-run-id', signal: new AbortController().signal, spawnParams: undefined, system: 'Test system prompt', - textOverride: null, + tools: {}, userId: TEST_USER_ID, userInputId: 'test-input', - - additionalToolDefinitions: () => Promise.resolve({}), - onResponseChunk: () => {}, } // Mock analytics and tracing diff --git a/packages/agent-runtime/src/prompt-agent-stream.ts b/packages/agent-runtime/src/prompt-agent-stream.ts index 3447b5948..ecf4a691c 100644 --- a/packages/agent-runtime/src/prompt-agent-stream.ts +++ b/packages/agent-runtime/src/prompt-agent-stream.ts @@ -12,6 +12,7 @@ import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ParamsOf } from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { OpenRouterProviderOptions } from '@codebuff/internal/openrouter-ai-sdk' +import type { ToolSet } from 'ai' export const getAgentStreamFromTemplate = (params: { agentId?: string @@ -25,7 +26,7 @@ export const getAgentStreamFromTemplate = (params: { runId: string sessionConnections: SessionRecord template: AgentTemplate - textOverride: string | null + tools: ToolSet userId: string | undefined userInputId: string @@ -46,7 +47,7 @@ export const getAgentStreamFromTemplate = (params: { runId, sessionConnections, template, - textOverride, + tools, userId, userInputId, @@ -56,14 +57,6 @@ export const getAgentStreamFromTemplate = (params: { trackEvent, } = params - if (textOverride !== null) { - async function* stream(): ReturnType { - yield { type: 'text', text: textOverride!, agentId } - return crypto.randomUUID() - } - return stream() - } - if (!template) { throw new Error('Agent template is null/undefined') } @@ -71,24 +64,26 @@ export const getAgentStreamFromTemplate = (params: { const { model } = template const aiSdkStreamParams: ParamsOf = { + agentId, apiKey, - runId, + clientSessionId, + fingerprintId, + includeCacheControl, + logger, + liveUserInputRecord, + maxOutputTokens: 32_000, + maxRetries: 3, messages, model, + runId, + sessionConnections, stopSequences: [globalStopSequence], - clientSessionId, - fingerprintId, - userInputId, + tools, userId, - maxOutputTokens: 32_000, + userInputId, + onCostCalculated, - includeCacheControl, - agentId, - maxRetries: 3, sendAction, - liveUserInputRecord, - sessionConnections, - logger, trackEvent, } diff --git a/packages/agent-runtime/src/run-agent-step.ts b/packages/agent-runtime/src/run-agent-step.ts index 19720c705..fa0040ef8 100644 --- a/packages/agent-runtime/src/run-agent-step.ts +++ b/packages/agent-runtime/src/run-agent-step.ts @@ -14,7 +14,8 @@ import { runProgrammaticStep } from './run-programmatic-step' import { additionalSystemPrompts } from './system-prompt/prompts' import { getAgentTemplate } from './templates/agent-registry' import { getAgentPrompt } from './templates/strings' -import { processStreamWithTools } from './tools/stream-parser' +import { getToolSet } from './tools/prompts' +import { processStream } from './tools/stream-parser' import { getAgentOutput } from './util/agent-output' import { withSystemInstructionTags, @@ -106,7 +107,7 @@ export const runAgentStep = async ( trackEvent: TrackEventFn promptAiSdk: PromptAiSdkFn } & ParamsExcluding< - typeof processStreamWithTools, + typeof processStream, | 'agentContext' | 'agentState' | 'agentStepId' @@ -337,6 +338,7 @@ export const runAgentStep = async ( let fullResponse = '' const toolResults: ToolMessage[] = [] + // Raw stream from AI SDK const stream = getAgentStreamFromTemplate({ ...params, agentId: agentState.parentId ? agentState.agentId : undefined, @@ -352,7 +354,7 @@ export const runAgentStep = async ( messageId, toolCalls, toolResults: newToolResults, - } = await processStreamWithTools({ + } = await processStream({ ...params, agentContext, agentState, @@ -468,37 +470,35 @@ export const runAgentStep = async ( export async function loopAgentSteps( params: { - userInputId: string - agentType: AgentTemplateType + addAgentStep: AddAgentStepFn agentState: AgentState - prompt: string | undefined + agentType: AgentTemplateType + clearUserPromptMessagesAfterResponse?: boolean + clientSessionId: string content?: Array - spawnParams: Record | undefined fileContext: ProjectFileContext + finishAgentRun: FinishAgentRunFn localAgentTemplates: Record - clearUserPromptMessagesAfterResponse?: boolean + logger: Logger parentSystemPrompt?: string + prompt: string | undefined signal: AbortSignal - - userId: string | undefined - clientSessionId: string - + spawnParams: Record | undefined startAgentRun: StartAgentRunFn - finishAgentRun: FinishAgentRunFn - addAgentStep: AddAgentStepFn - logger: Logger + userId: string | undefined + userInputId: string } & ParamsExcluding & ParamsExcluding< typeof runProgrammaticStep, - | 'runId' | 'agentState' - | 'template' + | 'onCostCalculated' | 'prompt' - | 'toolCallParams' - | 'stepsComplete' + | 'runId' | 'stepNumber' + | 'stepsComplete' | 'system' - | 'onCostCalculated' + | 'template' + | 'toolCallParams' > & ParamsExcluding & ParamsExcluding< @@ -526,7 +526,7 @@ export async function loopAgentSteps( | 'runId' | 'spawnParams' | 'system' - | 'textOverride' + | 'tools' > & ParamsExcluding< AddAgentStepFn, @@ -543,23 +543,23 @@ export async function loopAgentSteps( output: AgentOutput }> { const { - userInputId, - agentType, + addAgentStep, agentState, - prompt, + agentType, + clearUserPromptMessagesAfterResponse = true, + clientSessionId, content, - spawnParams, fileContext, + finishAgentRun, localAgentTemplates, - userId, - clientSessionId, - clearUserPromptMessagesAfterResponse = true, + logger, parentSystemPrompt, + prompt, signal, + spawnParams, startAgentRun, - finishAgentRun, - addAgentStep, - logger, + userId, + userInputId, } = params const agentTemplate = await getAgentTemplate({ @@ -631,6 +631,19 @@ export async function loopAgentSteps( }, })) ?? '' + const tools = await getToolSet({ + toolNames: agentTemplate.toolNames, + additionalToolDefinitions: async () => { + if (!cachedAdditionalToolDefinitions) { + cachedAdditionalToolDefinitions = await additionalToolDefinitions({ + ...params, + agentTemplate, + }) + } + return cachedAdditionalToolDefinitions + }, + }) + const hasUserMessage = Boolean( prompt || (spawnParams && Object.keys(spawnParams).length > 0), ) @@ -702,26 +715,26 @@ export async function loopAgentSteps( const startTime = new Date() // 1. Run programmatic step first if it exists - let textOverride = null let n: number | undefined = undefined if (agentTemplate.handleSteps) { const programmaticResult = await runProgrammaticStep({ ...params, - runId, + agentState: currentAgentState, - template: agentTemplate, localAgentTemplates, - prompt: currentPrompt, - toolCallParams: currentParams, - system, - stepsComplete: shouldEndTurn, - stepNumber: totalSteps, nResponses, onCostCalculated: async (credits: number) => { agentState.creditsUsed += credits agentState.directCreditsUsed += credits }, + prompt: currentPrompt, + runId, + stepNumber: totalSteps, + stepsComplete: shouldEndTurn, + system, + template: agentTemplate, + toolCallParams: currentParams, }) const { agentState: programmaticAgentState, @@ -729,7 +742,6 @@ export async function loopAgentSteps( stepNumber, generateN, } = programmaticResult - textOverride = programmaticResult.textOverride n = generateN currentAgentState = programmaticAgentState @@ -786,6 +798,15 @@ export async function loopAgentSteps( nResponses: generatedResponses, } = await runAgentStep({ ...params, + + agentState: currentAgentState, + n, + prompt: currentPrompt, + runId, + spawnParams: currentParams, + system, + tools, + additionalToolDefinitions: async () => { if (!cachedAdditionalToolDefinitions) { cachedAdditionalToolDefinitions = await additionalToolDefinitions({ @@ -795,13 +816,6 @@ export async function loopAgentSteps( } return cachedAdditionalToolDefinitions }, - textOverride: textOverride, - runId, - agentState: currentAgentState, - prompt: currentPrompt, - spawnParams: currentParams, - system, - n, }) if (newAgentState.runId) { diff --git a/packages/agent-runtime/src/run-programmatic-step.ts b/packages/agent-runtime/src/run-programmatic-step.ts index f8cda7edf..d7b3fcd56 100644 --- a/packages/agent-runtime/src/run-programmatic-step.ts +++ b/packages/agent-runtime/src/run-programmatic-step.ts @@ -1,13 +1,17 @@ -import { getToolCallString } from '@codebuff/common/tools/utils' import { getErrorObject } from '@codebuff/common/util/error' import { assistantMessage } from '@codebuff/common/util/messages' import { cloneDeep } from 'lodash' import { executeToolCall } from './tools/tool-executor' +import { parseTextWithToolCalls } from './util/parse-tool-calls-from-text' + +import type { ParsedSegment } from './util/parse-tool-calls-from-text' import type { FileProcessingState } from './tools/handlers/tool/write-file' import type { ExecuteToolCallParams } from './tools/tool-executor' import type { CodebuffToolCall } from '@codebuff/common/tools/list' +import { HandleStepsYieldValueSchema } from '@codebuff/common/types/agent-template' + import type { AgentTemplate, StepGenerator, @@ -21,10 +25,12 @@ import type { AddAgentStepFn } from '@codebuff/common/types/contracts/database' import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { ToolMessage } from '@codebuff/common/types/messages/codebuff-message' -import type { ToolResultOutput } from '@codebuff/common/types/messages/content-part' +import type { + ToolCallPart, + ToolResultOutput, +} from '@codebuff/common/types/messages/content-part' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { AgentState } from '@codebuff/common/types/session-state' - // Maintains generator state for all agents. Generator state can't be serialized, so we store it in memory. const runIdToGenerator: Record = {} export const runIdToStepAll: Set = new Set() @@ -40,26 +46,26 @@ export function clearAgentGeneratorCache(params: { logger: Logger }) { // Function to handle programmatic agents export async function runProgrammaticStep( params: { + addAgentStep: AddAgentStepFn agentState: AgentState - template: AgentTemplate + clientSessionId: string + fingerprintId: string + handleStepsLogChunk: HandleStepsLogChunkFn + localAgentTemplates: Record + logger: Logger + nResponses?: string[] + onResponseChunk: (chunk: string | PrintModeEvent) => void prompt: string | undefined - toolCallParams: Record | undefined - system: string | undefined - userId: string | undefined repoId: string | undefined repoUrl: string | undefined - userInputId: string - fingerprintId: string - clientSessionId: string - onResponseChunk: (chunk: string | PrintModeEvent) => void - localAgentTemplates: Record - stepsComplete: boolean stepNumber: number - handleStepsLogChunk: HandleStepsLogChunkFn + stepsComplete: boolean + template: AgentTemplate + toolCallParams: Record | undefined sendAction: SendActionFn - addAgentStep: AddAgentStepFn - logger: Logger - nResponses?: string[] + system: string | undefined + userId: string | undefined + userInputId: string } & Omit< ExecuteToolCallParams, | 'toolName' @@ -89,7 +95,6 @@ export async function runProgrammaticStep( >, ): Promise<{ agentState: AgentState - textOverride: string | null endTurn: boolean stepNumber: number generateN?: number @@ -170,7 +175,7 @@ export async function runProgrammaticStep( // Clear the STEP_ALL mode. Stepping can continue if handleSteps doesn't return. runIdToStepAll.delete(agentState.runId) } else { - return { agentState, textOverride: null, endTurn: false, stepNumber } + return { agentState, endTurn: false, stepNumber } } } @@ -205,7 +210,6 @@ export async function runProgrammaticStep( let toolResult: ToolResultOutput[] | undefined = undefined let endTurn = false - let textOverride: string | null = null let generateN: number | undefined = undefined let startTime = new Date() @@ -232,6 +236,16 @@ export async function runProgrammaticStep( endTurn = true break } + + // Validate the yield value from handleSteps + const parseResult = HandleStepsYieldValueSchema.safeParse(result.value) + if (!parseResult.success) { + throw new Error( + `Invalid yield value from handleSteps in agent ${template.id}: ${parseResult.error.message}. ` + + `Received: ${JSON.stringify(result.value)}`, + ) + } + if (result.value === 'STEP') { break } @@ -241,7 +255,25 @@ export async function runProgrammaticStep( } if ('type' in result.value && result.value.type === 'STEP_TEXT') { - textOverride = result.value.text + // Parse text and tool calls, preserving interleaved order + const segments = parseTextWithToolCalls(result.value.text) + + if (segments.length > 0) { + // Execute segments (text and tool calls) in order + toolResult = await executeSegmentsArray(segments, { + ...params, + agentContext, + agentStepId, + agentTemplate: template, + agentState, + fileProcessingState, + fullResponse: '', + previousToolCallFinished: Promise.resolve(), + toolCalls, + toolResults, + onResponseChunk, + }) + } break } @@ -254,121 +286,22 @@ export async function runProgrammaticStep( } // Process tool calls yielded by the generator - const toolCallWithoutId = result.value - const toolCallId = crypto.randomUUID() - const toolCall = { - ...toolCallWithoutId, - toolCallId, - } as CodebuffToolCall & { - includeToolCall?: boolean - } + const toolCall = result.value as ToolCallToExecute - // Note: We don't check if the tool is available for the agent template anymore. - // You can run any tool from handleSteps now! - // if (!template.toolNames.includes(toolCall.toolName)) { - // throw new Error( - // `Tool ${toolCall.toolName} is not available for agent ${template.id}. Available tools: ${template.toolNames.join(', ')}`, - // ) - // } - - const excludeToolFromMessageHistory = toolCall?.includeToolCall === false - // Add assistant message with the tool call before executing it - if (!excludeToolFromMessageHistory) { - const toolCallString = getToolCallString( - toolCall.toolName, - toolCall.input, - ) - onResponseChunk(toolCallString) - agentState.messageHistory.push(assistantMessage(toolCallString)) - // Optional call handles both top-level and nested agents - sendSubagentChunk({ - userInputId, - agentId: agentState.agentId, - agentType: agentState.agentType!, - chunk: toolCallString, - forwardToPrompt: !agentState.parentId, - }) - } - - // Execute the tool synchronously and get the result immediately - // Wrap onResponseChunk to add parentAgentId to nested agent events - await executeToolCall({ + toolResult = await executeSingleToolCall(toolCall, { ...params, - toolName: toolCall.toolName, - input: toolCall.input, - autoInsertEndStepParam: true, - excludeToolFromMessageHistory, - fromHandleSteps: true, - agentContext, agentStepId, agentTemplate: template, + agentState, fileProcessingState, fullResponse: '', previousToolCallFinished: Promise.resolve(), - toolCallId, toolCalls, toolResults, - toolResultsToAddAfterStream: [], - - onResponseChunk: (chunk: string | PrintModeEvent) => { - if (typeof chunk === 'string') { - onResponseChunk(chunk) - return - } - - // Only add parentAgentId if this programmatic agent has a parent (i.e., it's nested) - // This ensures we don't add parentAgentId to top-level spawns - if (agentState.parentId) { - const parentAgentId = agentState.agentId - - switch (chunk.type) { - case 'subagent_start': - case 'subagent_finish': - if (!chunk.parentAgentId) { - onResponseChunk({ - ...chunk, - parentAgentId, - }) - return - } - break - case 'tool_call': - case 'tool_result': { - if (!chunk.parentAgentId) { - const debugPayload = - chunk.type === 'tool_call' - ? { - eventType: chunk.type, - agentId: chunk.agentId, - parentId: parentAgentId, - } - : { - eventType: chunk.type, - parentId: parentAgentId, - } - onResponseChunk({ - ...chunk, - parentAgentId, - }) - return - } - break - } - default: - break - } - } - - // For other events or top-level spawns, send as-is - onResponseChunk(chunk) - }, + onResponseChunk, }) - // Get the latest tool result - const latestToolResult = toolResults[toolResults.length - 1] - toolResult = latestToolResult?.content - if (agentState.runId) { await addAgentStep({ ...params, @@ -393,7 +326,6 @@ export async function runProgrammaticStep( return { agentState, - textOverride, endTurn, stepNumber, generateN, @@ -437,7 +369,6 @@ export async function runProgrammaticStep( return { agentState, - textOverride: null, endTurn, stepNumber, generateN: undefined, @@ -462,3 +393,170 @@ export const getPublicAgentState = ( output, } } + +/** + * Represents a tool call to be executed. + * Can optionally include `includeToolCall: false` to exclude from message history. + */ +type ToolCallToExecute = { + toolName: string + input: Record + includeToolCall?: boolean +} + +/** + * Parameters for executing an array of tool calls. + */ +type ExecuteToolCallsArrayParams = Omit< + ExecuteToolCallParams, + | 'toolName' + | 'input' + | 'autoInsertEndStepParam' + | 'excludeToolFromMessageHistory' + | 'toolCallId' + | 'toolResultsToAddAfterStream' +> & { + agentState: AgentState + onResponseChunk: (chunk: string | PrintModeEvent) => void +} + +/** + * Executes a single tool call. + * Adds the tool call as an assistant message and then executes it. + * + * @returns The tool result from the executed tool call. + */ +async function executeSingleToolCall( + toolCallToExecute: ToolCallToExecute, + params: ExecuteToolCallsArrayParams, +): Promise { + const { agentState, onResponseChunk, toolResults } = params + + // Note: We don't check if the tool is available for the agent template anymore. + // You can run any tool from handleSteps now! + // if (!template.toolNames.includes(toolCall.toolName)) { + // throw new Error( + // `Tool ${toolCall.toolName} is not available for agent ${template.id}. Available tools: ${template.toolNames.join(', ')}`, + // ) + // } + + const toolCallId = crypto.randomUUID() + const excludeToolFromMessageHistory = + toolCallToExecute.includeToolCall === false + + // Add assistant message with the tool call before executing it + if (!excludeToolFromMessageHistory) { + const toolCallPart: ToolCallPart = { + type: 'tool-call', + toolCallId, + toolName: toolCallToExecute.toolName, + input: toolCallToExecute.input, + } + // onResponseChunk({ + // ...toolCallPart, + // type: 'tool_call', + // agentId: agentState.agentId, + // parentAgentId: agentState.parentId, + // }) + // NOTE(James): agentState.messageHistory is readonly for some reason (?!). Recreating the array is a workaround. + agentState.messageHistory = [...agentState.messageHistory] + agentState.messageHistory.push(assistantMessage(toolCallPart)) + // Optional call handles both top-level and nested agents + // sendSubagentChunk({ + // userInputId, + // agentId: agentState.agentId, + // agentType: agentState.agentType!, + // chunk: toolCallString, + // forwardToPrompt: !agentState.parentId, + // }) + } + + // Execute the tool call + await executeToolCall({ + ...params, + toolName: toolCallToExecute.toolName as any, + input: toolCallToExecute.input, + autoInsertEndStepParam: true, + excludeToolFromMessageHistory, + fromHandleSteps: true, + toolCallId, + toolResultsToAddAfterStream: [], + + onResponseChunk: (chunk: string | PrintModeEvent) => { + if (typeof chunk === 'string') { + onResponseChunk(chunk) + return + } + + // Only add parentAgentId if this programmatic agent has a parent (i.e., it's nested) + // This ensures we don't add parentAgentId to top-level spawns + if (agentState.parentId) { + const parentAgentId = agentState.agentId + + switch (chunk.type) { + case 'subagent_start': + case 'subagent_finish': + if (!chunk.parentAgentId) { + onResponseChunk({ + ...chunk, + parentAgentId, + }) + return + } + break + case 'tool_call': + case 'tool_result': { + if (!chunk.parentAgentId) { + onResponseChunk({ + ...chunk, + parentAgentId, + }) + return + } + break + } + default: + break + } + } + + // For other events or top-level spawns, send as-is + onResponseChunk(chunk) + }, + }) + + // Get the latest tool result + return toolResults[toolResults.length - 1]?.content +} + +/** + * Executes an array of segments (text and tool calls) sequentially. + * Text segments are added as assistant messages. + * Tool calls are added as assistant messages and then executed. + * + * @returns The tool result from the last executed tool call. + */ +async function executeSegmentsArray( + segments: ParsedSegment[], + params: ExecuteToolCallsArrayParams, +): Promise { + const { agentState } = params + + let toolResults: ToolResultOutput[] = [] + + for (const segment of segments) { + if (segment.type === 'text') { + // Add text as an assistant message + agentState.messageHistory = [...agentState.messageHistory] + agentState.messageHistory.push(assistantMessage(segment.text)) + } else { + // Handle tool call segment + const toolResult = await executeSingleToolCall(segment, params) + if (toolResult) { + toolResults.push(...toolResult) + } + } + } + + return toolResults +} diff --git a/packages/agent-runtime/src/templates/prompts.ts b/packages/agent-runtime/src/templates/prompts.ts index e1cb77d0a..ab86aaad0 100644 --- a/packages/agent-runtime/src/templates/prompts.ts +++ b/packages/agent-runtime/src/templates/prompts.ts @@ -73,17 +73,6 @@ Notes: - There are two types of input arguments for agents: prompt and params. The prompt is a string, and the params is a json object. Some agents require only one or the other, some require both, and some require none. - Below are the *only* available agents by their agent_type. Other agents may be referenced earlier in the conversation, but they are not available to you. -Example: - -${getToolCallString('spawn_agents', { - agents: [ - { - agent_type: 'example-agent', - prompt: 'Do an example task for me', - }, - ], -})} - Spawn only the below agents: ${agentsDescription}` diff --git a/packages/agent-runtime/src/templates/strings.ts b/packages/agent-runtime/src/templates/strings.ts index 2f7c4e75f..0fd3bc20c 100644 --- a/packages/agent-runtime/src/templates/strings.ts +++ b/packages/agent-runtime/src/templates/strings.ts @@ -11,11 +11,6 @@ import { getProjectFileTreePrompt, getSystemInfoPrompt, } from '../system-prompt/prompts' -import { - fullToolList, - getShortToolInstructions, - getToolsInstructions, -} from '../tools/prompts' import { parseUserMessage } from '../util/messages' import type { AgentTemplate, PlaceholderValue } from './types' @@ -113,8 +108,7 @@ export async function formatPrompt( [PLACEHOLDER.REMAINING_STEPS]: () => `${agentState.stepsRemaining!}`, [PLACEHOLDER.PROJECT_ROOT]: () => fileContext.projectRoot, [PLACEHOLDER.SYSTEM_INFO_PROMPT]: () => getSystemInfoPrompt(fileContext), - [PLACEHOLDER.TOOLS_PROMPT]: async () => - getToolsInstructions(tools, (await additionalToolDefinitions()) ?? {}), + [PLACEHOLDER.TOOLS_PROMPT]: async () => '', [PLACEHOLDER.AGENTS_PROMPT]: () => buildSpawnableAgentsDescription(params), [PLACEHOLDER.USER_CWD]: () => fileContext.cwd, [PLACEHOLDER.USER_INPUT_PROMPT]: () => escapeString(lastUserInput ?? ''), @@ -204,15 +198,7 @@ export async function getAgentPrompt( // Add tool instructions, spawnable agents, and output schema prompts to instructionsPrompt if (promptType.type === 'instructionsPrompt' && agentState.agentType) { - const toolsInstructions = agentTemplate.inheritParentSystemPrompt - ? fullToolList(agentTemplate.toolNames, await additionalToolDefinitions()) - : getShortToolInstructions( - agentTemplate.toolNames, - await additionalToolDefinitions(), - ) addendum += - '\n\n' + - toolsInstructions + '\n\n' + (await buildSpawnableAgentsDescription({ ...params, diff --git a/packages/agent-runtime/src/tool-stream-parser.old.ts b/packages/agent-runtime/src/tool-stream-parser.old.ts new file mode 100644 index 000000000..e7e07ca43 --- /dev/null +++ b/packages/agent-runtime/src/tool-stream-parser.old.ts @@ -0,0 +1,217 @@ +import { AnalyticsEvent } from '@codebuff/common/constants/analytics-events' +import { + endsAgentStepParam, + endToolTag, + startToolTag, + toolNameParam, +} from '@codebuff/common/tools/constants' + +import type { Model } from '@codebuff/common/old-constants' +import type { TrackEventFn } from '@codebuff/common/types/contracts/analytics' +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { + PrintModeError, + PrintModeText, +} from '@codebuff/common/types/print-mode' + +const toolExtractionPattern = new RegExp( + `${startToolTag}(.*?)${endToolTag}`, + 'gs', +) + +const completionSuffix = `${JSON.stringify(endsAgentStepParam)}: true\n}${endToolTag}` + +export async function* processStreamWithTags(params: { + stream: AsyncGenerator + processors: Record< + string, + { + onTagStart: (tagName: string, attributes: Record) => void + onTagEnd: (tagName: string, params: Record) => void + } + > + defaultProcessor: (toolName: string) => { + onTagStart: (tagName: string, attributes: Record) => void + onTagEnd: (tagName: string, params: Record) => void + } + onError: (tagName: string, errorMessage: string) => void + onResponseChunk: (chunk: PrintModeText | PrintModeError) => void + logger: Logger + loggerOptions?: { + userId?: string + model?: Model + agentName?: string + } + trackEvent: TrackEventFn +}): AsyncGenerator { + const { + stream, + processors, + defaultProcessor, + onError, + onResponseChunk, + logger, + loggerOptions, + trackEvent, + } = params + + let streamCompleted = false + let buffer = '' + let autocompleted = false + + function extractToolCalls(): string[] { + const matches: string[] = [] + let lastIndex = 0 + for (const match of buffer.matchAll(toolExtractionPattern)) { + if (match.index > lastIndex) { + onResponseChunk({ + type: 'text', + text: buffer.slice(lastIndex, match.index), + }) + } + lastIndex = match.index + match[0].length + matches.push(match[1]) + } + + buffer = buffer.slice(lastIndex) + return matches + } + + function processToolCallContents(contents: string): void { + let parsedParams: any + try { + parsedParams = JSON.parse(contents) + } catch (error: any) { + trackEvent({ + event: AnalyticsEvent.MALFORMED_TOOL_CALL_JSON, + userId: loggerOptions?.userId ?? '', + properties: { + contents: JSON.stringify(contents), + model: loggerOptions?.model, + agent: loggerOptions?.agentName, + error: { + name: error.name, + message: error.message, + stack: error.stack, + }, + autocompleted, + }, + logger, + }) + const shortenedContents = + contents.length < 200 + ? contents + : contents.slice(0, 100) + '...' + contents.slice(-100) + const errorMessage = `Invalid JSON: ${JSON.stringify(shortenedContents)}\nError: ${error.message}` + onResponseChunk({ + type: 'error', + message: errorMessage, + }) + onError('parse_error', errorMessage) + return + } + + const toolName = parsedParams[toolNameParam] as keyof typeof processors + const processor = + typeof toolName === 'string' + ? processors[toolName] ?? defaultProcessor(toolName) + : undefined + if (!processor) { + trackEvent({ + event: AnalyticsEvent.UNKNOWN_TOOL_CALL, + userId: loggerOptions?.userId ?? '', + properties: { + contents, + toolName, + model: loggerOptions?.model, + agent: loggerOptions?.agentName, + autocompleted, + }, + logger, + }) + onError( + 'parse_error', + `Unknown tool ${JSON.stringify(toolName)} for tool call: ${contents}`, + ) + return + } + + trackEvent({ + event: AnalyticsEvent.TOOL_USE, + userId: loggerOptions?.userId ?? '', + properties: { + toolName, + contents, + parsedParams, + autocompleted, + model: loggerOptions?.model, + agent: loggerOptions?.agentName, + }, + logger, + }) + delete parsedParams[toolNameParam] + + processor.onTagStart(toolName, {}) + processor.onTagEnd(toolName, parsedParams) + } + + function extractToolsFromBufferAndProcess(forceFlush = false) { + const matches = extractToolCalls() + matches.forEach(processToolCallContents) + if (forceFlush) { + onResponseChunk({ + type: 'text', + text: buffer, + }) + buffer = '' + } + } + + function* processChunk( + chunk: StreamChunk | undefined, + ): Generator { + if (chunk !== undefined && chunk.type === 'text') { + buffer += chunk.text + } + extractToolsFromBufferAndProcess() + + if (chunk === undefined) { + streamCompleted = true + if (buffer.includes(startToolTag)) { + buffer += completionSuffix + chunk = { + type: 'text', + text: completionSuffix, + } + autocompleted = true + } + extractToolsFromBufferAndProcess(true) + } + + if (chunk) { + yield chunk + } + } + + let messageId: string | null = null + while (true) { + const { value, done } = await stream.next() + if (done) { + messageId = value + break + } + if (streamCompleted) { + break + } + + yield* processChunk(value) + } + + if (!streamCompleted) { + // After the stream ends, try parsing one last time in case there's leftover text + yield* processChunk(undefined) + } + + return messageId +} diff --git a/packages/agent-runtime/src/tool-stream-parser.ts b/packages/agent-runtime/src/tool-stream-parser.ts index 0191596c4..2f096695d 100644 --- a/packages/agent-runtime/src/tool-stream-parser.ts +++ b/packages/agent-runtime/src/tool-stream-parser.ts @@ -1,10 +1,4 @@ import { AnalyticsEvent } from '@codebuff/common/constants/analytics-events' -import { - endsAgentStepParam, - endToolTag, - startToolTag, - toolNameParam, -} from '@codebuff/common/tools/constants' import type { Model } from '@codebuff/common/old-constants' import type { TrackEventFn } from '@codebuff/common/types/contracts/analytics' @@ -13,17 +7,9 @@ import type { Logger } from '@codebuff/common/types/contracts/logger' import type { PrintModeError, PrintModeText, - PrintModeToolCall, } from '@codebuff/common/types/print-mode' -const toolExtractionPattern = new RegExp( - `${startToolTag}(.*?)${endToolTag}`, - 'gs', -) - -const completionSuffix = `${JSON.stringify(endsAgentStepParam)}: true\n}${endToolTag}` - -export async function* processStreamWithTags(params: { +export async function* processStreamWithTools(params: { stream: AsyncGenerator processors: Record< string, @@ -37,9 +23,7 @@ export async function* processStreamWithTags(params: { onTagEnd: (tagName: string, params: Record) => void } onError: (tagName: string, errorMessage: string) => void - onResponseChunk: ( - chunk: PrintModeText | PrintModeToolCall | PrintModeError, - ) => void + onResponseChunk: (chunk: PrintModeText | PrintModeError) => void logger: Logger loggerOptions?: { userId?: string @@ -58,87 +42,18 @@ export async function* processStreamWithTags(params: { loggerOptions, trackEvent, } = params - let streamCompleted = false let buffer = '' let autocompleted = false - function extractToolCalls(): string[] { - const matches: string[] = [] - let lastIndex = 0 - for (const match of buffer.matchAll(toolExtractionPattern)) { - if (match.index > lastIndex) { - onResponseChunk({ - type: 'text', - text: buffer.slice(lastIndex, match.index), - }) - } - lastIndex = match.index + match[0].length - matches.push(match[1]) - } + function processToolCallObject(params: { + toolName: string + input: any + contents?: string + }): void { + const { toolName, input, contents } = params - buffer = buffer.slice(lastIndex) - return matches - } - - function processToolCallContents(contents: string): void { - let parsedParams: any - try { - parsedParams = JSON.parse(contents) - } catch (error: any) { - trackEvent({ - event: AnalyticsEvent.MALFORMED_TOOL_CALL_JSON, - userId: loggerOptions?.userId ?? '', - properties: { - contents: JSON.stringify(contents), - model: loggerOptions?.model, - agent: loggerOptions?.agentName, - error: { - name: error.name, - message: error.message, - stack: error.stack, - }, - autocompleted, - }, - logger, - }) - const shortenedContents = - contents.length < 200 - ? contents - : contents.slice(0, 100) + '...' + contents.slice(-100) - const errorMessage = `Invalid JSON: ${JSON.stringify(shortenedContents)}\nError: ${error.message}` - onResponseChunk({ - type: 'error', - message: errorMessage, - }) - onError('parse_error', errorMessage) - return - } - - const toolName = parsedParams[toolNameParam] as keyof typeof processors - const processor = - typeof toolName === 'string' - ? processors[toolName] ?? defaultProcessor(toolName) - : undefined - if (!processor) { - trackEvent({ - event: AnalyticsEvent.UNKNOWN_TOOL_CALL, - userId: loggerOptions?.userId ?? '', - properties: { - contents, - toolName, - model: loggerOptions?.model, - agent: loggerOptions?.agentName, - autocompleted, - }, - logger, - }) - onError( - 'parse_error', - `Unknown tool ${JSON.stringify(toolName)} for tool call: ${contents}`, - ) - return - } + const processor = processors[toolName] ?? defaultProcessor(toolName) trackEvent({ event: AnalyticsEvent.TOOL_USE, @@ -146,55 +61,48 @@ export async function* processStreamWithTags(params: { properties: { toolName, contents, - parsedParams, + parsedParams: input, autocompleted, model: loggerOptions?.model, agent: loggerOptions?.agentName, }, logger, }) - delete parsedParams[toolNameParam] processor.onTagStart(toolName, {}) - processor.onTagEnd(toolName, parsedParams) + processor.onTagEnd(toolName, input) } - function extractToolsFromBufferAndProcess(forceFlush = false) { - const matches = extractToolCalls() - matches.forEach(processToolCallContents) - if (forceFlush) { + function flush() { + if (buffer) { onResponseChunk({ type: 'text', text: buffer, }) - buffer = '' } + buffer = '' } function* processChunk( chunk: StreamChunk | undefined, ): Generator { - if (chunk !== undefined && chunk.type === 'text') { - buffer += chunk.text - } - extractToolsFromBufferAndProcess() - if (chunk === undefined) { + flush() streamCompleted = true - if (buffer.includes(startToolTag)) { - buffer += completionSuffix - chunk = { - type: 'text', - text: completionSuffix, - } - autocompleted = true - } - extractToolsFromBufferAndProcess(true) + return } - if (chunk) { - yield chunk + if (chunk.type === 'text') { + buffer += chunk.text + } else { + flush() } + + if (chunk.type === 'tool-call') { + processToolCallObject(chunk) + } + + yield chunk } let messageId: string | null = null @@ -207,14 +115,11 @@ export async function* processStreamWithTags(params: { if (streamCompleted) { break } - yield* processChunk(value) } - if (!streamCompleted) { // After the stream ends, try parsing one last time in case there's leftover text yield* processChunk(undefined) } - return messageId } diff --git a/packages/agent-runtime/src/tools/stream-parser.ts b/packages/agent-runtime/src/tools/stream-parser.ts index f47d41a67..435afebd2 100644 --- a/packages/agent-runtime/src/tools/stream-parser.ts +++ b/packages/agent-runtime/src/tools/stream-parser.ts @@ -7,7 +7,7 @@ import { import { generateCompactId } from '@codebuff/common/util/string' import { cloneDeep } from 'lodash' -import { processStreamWithTags } from '../tool-stream-parser' +import { processStreamWithTools } from '../tool-stream-parser' import { executeCustomToolCall, executeToolCall } from './tool-executor' import { expireMessages } from '../util/messages' @@ -33,7 +33,7 @@ export type ToolCallError = { error: string } & Omit -export async function processStreamWithTools( +export async function processStream( params: { agentContext: Record agentTemplate: AgentTemplate @@ -65,7 +65,7 @@ export async function processStreamWithTools( | 'toolResultsToAddAfterStream' > & ParamsExcluding< - typeof processStreamWithTags, + typeof processStreamWithTools, 'processors' | 'defaultProcessor' | 'onError' | 'loggerOptions' >, ) { @@ -80,12 +80,14 @@ export async function processStreamWithTools( runId, signal, userId, + logger, } = params const fullResponseChunks: string[] = [fullResponse] const toolResults: ToolMessage[] = [] const toolResultsToAddAfterStream: ToolMessage[] = [] const toolCalls: (CodebuffToolCall | CustomToolCall)[] = [] + const assistantMessages: Message[] = [] const { promise: streamDonePromise, resolve: resolveStreamDonePromise } = Promise.withResolvers() let previousToolCallFinished = streamDonePromise @@ -122,6 +124,14 @@ export async function processStreamWithTools( toolResultsToAddAfterStream, onCostCalculated, + onResponseChunk: (chunk) => { + if (typeof chunk !== 'string' && chunk.type === 'tool_call') { + assistantMessages.push( + assistantMessage({ ...chunk, type: 'tool-call' }), + ) + } + return onResponseChunk(chunk) + }, }) }, } @@ -147,19 +157,27 @@ export async function processStreamWithTools( toolCalls, toolResults, toolResultsToAddAfterStream, + + onResponseChunk: (chunk) => { + if (typeof chunk !== 'string' && chunk.type === 'tool_call') { + assistantMessages.push( + assistantMessage({ ...chunk, type: 'tool-call' }), + ) + } + return onResponseChunk(chunk) + }, }) }, } } - const streamWithTags = processStreamWithTags({ + const streamWithTags = processStreamWithTools({ ...params, processors: Object.fromEntries([ ...toolNames.map((toolName) => [toolName, toolCallback(toolName)]), - ...Object.keys(fileContext.customToolDefinitions ?? {}).map((toolName) => [ - toolName, - customToolCallback(toolName), - ]), + ...Object.keys(fileContext.customToolDefinitions ?? {}).map( + (toolName) => [toolName, customToolCallback(toolName)], + ), ]), defaultProcessor: customToolCallback, onError: (toolName, error) => { @@ -179,6 +197,21 @@ export async function processStreamWithTools( model: agentTemplate.model, agentName: agentTemplate.id, }, + onResponseChunk: (chunk) => { + if (chunk.type === 'text') { + if (chunk.text) { + assistantMessages.push(assistantMessage(chunk.text)) + } + } else if (chunk.type === 'error') { + // do nothing + } else { + chunk satisfies never + throw new Error( + `Internal error: unhandled chunk type: ${(chunk as any).type}`, + ) + } + return onResponseChunk(chunk) + }, }) let messageId: string | null = null @@ -204,15 +237,17 @@ export async function processStreamWithTools( fullResponseChunks.push(chunk.text) } else if (chunk.type === 'error') { onResponseChunk(chunk) + } else if (chunk.type === 'tool-call') { + // Do nothing, the onResponseChunk for tool is handled in the processor } else { chunk satisfies never + throw new Error(`Unhandled chunk type: ${(chunk as any).type}`) } } agentState.messageHistory = buildArray([ ...expireMessages(agentState.messageHistory, 'agentStep'), - fullResponseChunks.length > 0 && - assistantMessage(fullResponseChunks.join('')), + ...assistantMessages, ...toolResultsToAddAfterStream, ]) diff --git a/packages/agent-runtime/src/tools/tool-executor.ts b/packages/agent-runtime/src/tools/tool-executor.ts index 1baa2b774..4304daed2 100644 --- a/packages/agent-runtime/src/tools/tool-executor.ts +++ b/packages/agent-runtime/src/tools/tool-executor.ts @@ -3,7 +3,6 @@ import { toolParams } from '@codebuff/common/tools/list' import { jsonToolResult } from '@codebuff/common/util/messages' import { generateCompactId } from '@codebuff/common/util/string' import { cloneDeep } from 'lodash' -import z from 'zod/v4' import { checkLiveUserInput } from '../live-user-inputs' import { getMCPToolData } from '../mcp' @@ -66,24 +65,28 @@ export function parseRawToolCall(params: { } const validName = toolName as T - const processedParameters: Record = {} - for (const [param, val] of Object.entries(rawToolCall.input ?? {})) { - processedParameters[param] = val - } + // const processedParameters: Record = {} + // for (const [param, val] of Object.entries(rawToolCall.input ?? {})) { + // processedParameters[param] = val + // } // Add the required codebuff_end_step parameter with the correct value for this tool if requested - if (autoInsertEndStepParam) { - processedParameters[endsAgentStepParam] = - toolParams[validName].endsAgentStep - } + // if (autoInsertEndStepParam) { + // processedParameters[endsAgentStepParam] = + // toolParams[validName].endsAgentStep + // } + + // const paramsSchema = toolParams[validName].endsAgentStep + // ? ( + // toolParams[validName].inputSchema satisfies z.ZodObject as z.ZodObject + // ).extend({ + // [endsAgentStepParam]: z.literal(toolParams[validName].endsAgentStep), + // }) + // : toolParams[validName].inputSchema + + const processedParameters = rawToolCall.input + const paramsSchema = toolParams[validName].inputSchema - const paramsSchema = toolParams[validName].endsAgentStep - ? ( - toolParams[validName].inputSchema satisfies z.ZodObject as z.ZodObject - ).extend({ - [endsAgentStepParam]: z.literal(toolParams[validName].endsAgentStep), - }) - : toolParams[validName].inputSchema const result = paramsSchema.safeParse(processedParameters) if (!result.success) { @@ -178,10 +181,9 @@ export function executeToolCall( toolCallId, toolName, input, - // Only include agentId for subagents (agents with a parent) - ...(agentState.parentId && { agentId: agentState.agentId }), - // Include includeToolCall flag if explicitly set to false - ...(excludeToolFromMessageHistory && { includeToolCall: false }), + agentId: agentState.agentId, + parentAgentId: agentState.parentId, + includeToolCall: !excludeToolFromMessageHistory, }) const toolCall: CodebuffToolCall | ToolCallError = parseRawToolCall({ diff --git a/packages/agent-runtime/src/util/__tests__/parse-tool-calls-from-text.test.ts b/packages/agent-runtime/src/util/__tests__/parse-tool-calls-from-text.test.ts new file mode 100644 index 000000000..a61e82703 --- /dev/null +++ b/packages/agent-runtime/src/util/__tests__/parse-tool-calls-from-text.test.ts @@ -0,0 +1,363 @@ +import { describe, expect, it } from 'bun:test' + +import { + parseToolCallsFromText, + parseTextWithToolCalls, +} from '../parse-tool-calls-from-text' + +describe('parseToolCallsFromText', () => { + it('should parse a single tool call', () => { + const text = ` +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"] +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0]).toEqual({ + toolName: 'read_files', + input: { paths: ['test.ts'] }, + }) + }) + + it('should parse multiple tool calls', () => { + const text = `Some commentary before + + +{ + "cb_tool_name": "read_files", + "paths": ["file1.ts"] +} + + +Some text between + + +{ + "cb_tool_name": "str_replace", + "path": "file1.ts", + "replacements": [{"old": "foo", "new": "bar"}] +} + + +Some commentary after` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(2) + expect(result[0]).toEqual({ + toolName: 'read_files', + input: { paths: ['file1.ts'] }, + }) + expect(result[1]).toEqual({ + toolName: 'str_replace', + input: { + path: 'file1.ts', + replacements: [{ old: 'foo', new: 'bar' }], + }, + }) + }) + + it('should remove cb_tool_name from input', () => { + const text = ` +{ + "cb_tool_name": "write_file", + "path": "test.ts", + "content": "console.log('hello')" +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0].input).not.toHaveProperty('cb_tool_name') + expect(result[0].input).toEqual({ + path: 'test.ts', + content: "console.log('hello')", + }) + }) + + it('should remove cb_easp from input', () => { + const text = ` +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"], + "cb_easp": true +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0].input).not.toHaveProperty('cb_easp') + expect(result[0].input).toEqual({ paths: ['test.ts'] }) + }) + + it('should skip malformed JSON', () => { + const text = ` +{ + "cb_tool_name": "read_files", + "paths": ["test.ts" +} + + + +{ + "cb_tool_name": "write_file", + "path": "good.ts", + "content": "valid" +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0].toolName).toBe('write_file') + }) + + it('should skip tool calls without cb_tool_name', () => { + const text = ` +{ + "paths": ["test.ts"] +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(0) + }) + + it('should return empty array for text without tool calls', () => { + const text = 'Just some regular text without any tool calls' + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(0) + }) + + it('should return empty array for empty string', () => { + const result = parseToolCallsFromText('') + + expect(result).toHaveLength(0) + }) + + it('should handle complex nested objects in input', () => { + const text = ` +{ + "cb_tool_name": "spawn_agents", + "agents": [ + { + "agent_type": "file-picker", + "prompt": "Find relevant files" + }, + { + "agent_type": "code-searcher", + "params": { + "searchQueries": [ + {"pattern": "function test"} + ] + } + } + ] +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0].toolName).toBe('spawn_agents') + expect(result[0].input.agents).toHaveLength(2) + }) + + it('should handle tool calls with escaped characters in strings', () => { + const text = + '\n' + + '{\n' + + ' "cb_tool_name": "str_replace",\n' + + ' "path": "test.ts",\n' + + ' "replacements": [{"old": "console.log(\\"hello\\")", "new": "console.log(\'world\')"}]\n' + + '}\n' + + '' + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + const replacements = result[0].input.replacements as Array<{ + old: string + new: string + }> + expect(replacements[0].old).toBe('console.log("hello")') + }) + + it('should handle tool calls with newlines in content', () => { + const text = + '\n' + + '{\n' + + ' "cb_tool_name": "write_file",\n' + + ' "path": "test.ts",\n' + + ' "content": "line1\\nline2\\nline3"\n' + + '}\n' + + '' + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0].input.content).toBe('line1\nline2\nline3') + }) +}) + +describe('parseTextWithToolCalls', () => { + it('should parse interleaved text and tool calls', () => { + const text = `Some commentary before + + +{ + "cb_tool_name": "read_files", + "paths": ["file1.ts"] +} + + +Some text between + + +{ + "cb_tool_name": "write_file", + "path": "file2.ts", + "content": "test" +} + + +Some commentary after` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(5) + expect(result[0]).toEqual({ type: 'text', text: 'Some commentary before' }) + expect(result[1]).toEqual({ + type: 'tool_call', + toolName: 'read_files', + input: { paths: ['file1.ts'] }, + }) + expect(result[2]).toEqual({ type: 'text', text: 'Some text between' }) + expect(result[3]).toEqual({ + type: 'tool_call', + toolName: 'write_file', + input: { path: 'file2.ts', content: 'test' }, + }) + expect(result[4]).toEqual({ type: 'text', text: 'Some commentary after' }) + }) + + it('should return only tool call when no surrounding text', () => { + const text = ` +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"] +} +` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(1) + expect(result[0]).toEqual({ + type: 'tool_call', + toolName: 'read_files', + input: { paths: ['test.ts'] }, + }) + }) + + it('should return only text when no tool calls', () => { + const text = 'Just some regular text without any tool calls' + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(1) + expect(result[0]).toEqual({ + type: 'text', + text: 'Just some regular text without any tool calls', + }) + }) + + it('should return empty array for empty string', () => { + const result = parseTextWithToolCalls('') + + expect(result).toHaveLength(0) + }) + + it('should handle text only before tool call', () => { + const text = `Introduction text + + +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"] +} +` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(2) + expect(result[0]).toEqual({ type: 'text', text: 'Introduction text' }) + expect(result[1].type).toBe('tool_call') + }) + + it('should handle text only after tool call', () => { + const text = ` +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"] +} + + +Conclusion text` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(2) + expect(result[0].type).toBe('tool_call') + expect(result[1]).toEqual({ type: 'text', text: 'Conclusion text' }) + }) + + it('should skip malformed tool calls but keep surrounding text', () => { + const text = `Before text + + +{ + "cb_tool_name": "read_files", + "paths": ["test.ts" +} + + +After text` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(2) + expect(result[0]).toEqual({ type: 'text', text: 'Before text' }) + expect(result[1]).toEqual({ type: 'text', text: 'After text' }) + }) + + it('should trim whitespace from text segments', () => { + const text = ` + Text with whitespace + + +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"] +} + + + More text + ` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(3) + expect(result[0]).toEqual({ type: 'text', text: 'Text with whitespace' }) + expect(result[1].type).toBe('tool_call') + expect(result[2]).toEqual({ type: 'text', text: 'More text' }) + }) +}) diff --git a/packages/agent-runtime/src/util/agent-output.ts b/packages/agent-runtime/src/util/agent-output.ts index 624e3ca63..fe3a8da0a 100644 --- a/packages/agent-runtime/src/util/agent-output.ts +++ b/packages/agent-runtime/src/util/agent-output.ts @@ -1,10 +1,49 @@ import type { AgentTemplate } from '@codebuff/common/types/agent-template' -import type { AssistantMessage } from '@codebuff/common/types/messages/codebuff-message' +import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { AgentState, AgentOutput, } from '@codebuff/common/types/session-state' +/** + * Get the last assistant turn messages, which includes the last assistant message + * and any subsequent tool messages that are responses to its tool calls. + */ +function getLastAssistantTurnMessages(messageHistory: Message[]): Message[] { + // Find the index of the last assistant message + let lastAssistantIndex = -1 + for (let i = messageHistory.length - 1; i >= 0; i--) { + if (messageHistory[i].role === 'assistant') { + lastAssistantIndex = i + break + } + } + + for (let i = lastAssistantIndex; i >= 0; i--) { + if (messageHistory[i].role === 'assistant') { + lastAssistantIndex = i + } else break + } + + if (lastAssistantIndex === -1) { + return [] + } + + // Collect the assistant message and all subsequent tool messages + const result: Message[] = [] + for (let i = lastAssistantIndex; i < messageHistory.length; i++) { + const message = messageHistory[i] + if (message.role === 'assistant' || message.role === 'tool') { + result.push(message) + } else { + // Stop if we hit a user or system message + break + } + } + + return result +} + export function getAgentOutput( agentState: AgentState, agentTemplate: AgentTemplate, @@ -16,11 +55,10 @@ export function getAgentOutput( } } if (agentTemplate.outputMode === 'last_message') { - const assistantMessages = agentState.messageHistory.filter( - (message): message is AssistantMessage => message.role === 'assistant', + const lastTurnMessages = getLastAssistantTurnMessages( + agentState.messageHistory, ) - const lastAssistantMessage = assistantMessages[assistantMessages.length - 1] - if (!lastAssistantMessage) { + if (lastTurnMessages.length === 0) { return { type: 'error', message: 'No response from agent', @@ -28,7 +66,7 @@ export function getAgentOutput( } return { type: 'lastMessage', - value: lastAssistantMessage.content, + value: lastTurnMessages, } } if (agentTemplate.outputMode === 'all_messages') { diff --git a/packages/agent-runtime/src/util/parse-tool-calls-from-text.ts b/packages/agent-runtime/src/util/parse-tool-calls-from-text.ts new file mode 100644 index 000000000..4f9900a9e --- /dev/null +++ b/packages/agent-runtime/src/util/parse-tool-calls-from-text.ts @@ -0,0 +1,117 @@ +import { + startToolTag, + endToolTag, + toolNameParam, +} from '@codebuff/common/tools/constants' + +export type ParsedToolCallFromText = { + type: 'tool_call' + toolName: string + input: Record +} + +export type ParsedTextSegment = { + type: 'text' + text: string +} + +export type ParsedSegment = ParsedToolCallFromText | ParsedTextSegment + +/** + * Parses text containing tool calls in the XML format, + * returning interleaved text and tool call segments in order. + * + * Example input: + * ``` + * Some text before + * + * { + * "cb_tool_name": "read_files", + * "paths": ["file.ts"] + * } + * + * Some text after + * ``` + * + * @param text - The text containing tool calls in XML format + * @returns Array of segments (text and tool calls) in order of appearance + */ +export function parseTextWithToolCalls(text: string): ParsedSegment[] { + const segments: ParsedSegment[] = [] + + // Match ... blocks + const toolExtractionPattern = new RegExp( + `${escapeRegex(startToolTag)}([\\s\\S]*?)${escapeRegex(endToolTag)}`, + 'gs', + ) + + let lastIndex = 0 + + for (const match of text.matchAll(toolExtractionPattern)) { + // Add any text before this tool call + if (match.index !== undefined && match.index > lastIndex) { + const textBefore = text.slice(lastIndex, match.index).trim() + if (textBefore) { + segments.push({ type: 'text', text: textBefore }) + } + } + + const jsonContent = match[1].trim() + + try { + const parsed = JSON.parse(jsonContent) + const toolName = parsed[toolNameParam] + + if (typeof toolName === 'string') { + // Remove the tool name param from the input + const input = { ...parsed } + delete input[toolNameParam] + + // Also remove cb_easp if present + delete input['cb_easp'] + + segments.push({ + type: 'tool_call', + toolName, + input, + }) + } + } catch { + // Skip malformed JSON - don't add segment + } + + // Update lastIndex to after this match + if (match.index !== undefined) { + lastIndex = match.index + match[0].length + } + } + + // Add any remaining text after the last tool call + if (lastIndex < text.length) { + const textAfter = text.slice(lastIndex).trim() + if (textAfter) { + segments.push({ type: 'text', text: textAfter }) + } + } + + return segments +} + +/** + * Parses tool calls from text in the XML format. + * This is a convenience function that returns only tool calls (no text segments). + * + * @param text - The text containing tool calls in XML format + * @returns Array of parsed tool calls with toolName and input + */ +export function parseToolCallsFromText( + text: string, +): Omit[] { + return parseTextWithToolCalls(text) + .filter((segment): segment is ParsedToolCallFromText => segment.type === 'tool_call') + .map(({ toolName, input }) => ({ toolName, input })) +} + +function escapeRegex(string: string): string { + return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&') +} diff --git a/sdk/src/__tests__/run-with-retry.test.ts b/sdk/src/__tests__/run-with-retry.test.ts index cf0351cf5..e240b8cff 100644 --- a/sdk/src/__tests__/run-with-retry.test.ts +++ b/sdk/src/__tests__/run-with-retry.test.ts @@ -1,10 +1,12 @@ +import { assistantMessage } from '@codebuff/common/util/messages' import { afterEach, describe, expect, it, mock, spyOn } from 'bun:test' -import { ErrorCodes, NetworkError } from '../errors' +import { ErrorCodes } from '../errors' import { run } from '../run' import * as runModule from '../run' import type { RunState } from '../run-state' +import type { SessionState } from '@codebuff/common/types/session-state' const baseOptions = { apiKey: 'test-key', @@ -19,8 +21,13 @@ describe('run retry wrapper', () => { }) it('returns immediately on success without retrying', async () => { - const expectedState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'hi' } } as RunState - const runSpy = spyOn(runModule, 'runOnce').mockResolvedValueOnce(expectedState) + const expectedState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('hi')] }, + } + const runSpy = spyOn(runModule, 'runOnce').mockResolvedValueOnce( + expectedState, + ) const result = await run(baseOptions) @@ -29,11 +36,14 @@ describe('run retry wrapper', () => { }) it('retries once on retryable error output and then succeeds', async () => { - const errorState = { - sessionState: {} as any, - output: { type: 'error', message: 'NetworkError: Service unavailable' } - } as RunState - const successState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'hi' } } as RunState + const errorState: RunState = { + sessionState: {} as SessionState, + output: { type: 'error', message: 'NetworkError: Service unavailable' }, + } + const successState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('hi')] }, + } const runSpy = spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) @@ -51,7 +61,7 @@ describe('run retry wrapper', () => { it('stops after max retries are exhausted and returns error output', async () => { const errorState = { sessionState: {} as any, - output: { type: 'error', message: 'NetworkError: Connection timeout' } + output: { type: 'error', message: 'NetworkError: Connection timeout' }, } as RunState const runSpy = spyOn(runModule, 'runOnce').mockResolvedValue(errorState) @@ -73,7 +83,7 @@ describe('run retry wrapper', () => { it('does not retry non-retryable error outputs', async () => { const errorState = { sessionState: {} as any, - output: { type: 'error', message: 'Invalid input' } + output: { type: 'error', message: 'Invalid input' }, } as RunState const runSpy = spyOn(runModule, 'runOnce').mockResolvedValue(errorState) @@ -91,7 +101,7 @@ describe('run retry wrapper', () => { it('skips retry when retry is false even for retryable error outputs', async () => { const errorState = { sessionState: {} as any, - output: { type: 'error', message: 'NetworkError: Connection failed' } + output: { type: 'error', message: 'NetworkError: Connection failed' }, } as RunState const runSpy = spyOn(runModule, 'runOnce').mockResolvedValue(errorState) @@ -106,11 +116,14 @@ describe('run retry wrapper', () => { }) it('retries when provided custom retryableErrorCodes set', async () => { - const errorState = { + const errorState: RunState = { sessionState: {} as any, - output: { type: 'error', message: 'Server error (500)' } - } as RunState - const successState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'hi' } } as RunState + output: { type: 'error', message: 'Server error (500)' }, + } + const successState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('hi')] }, + } const runSpy = spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) @@ -149,11 +162,14 @@ describe('run retry wrapper', () => { }) it('calls onRetry callback with correct parameters on error output', async () => { - const errorState = { - sessionState: {} as any, - output: { type: 'error', message: 'Service unavailable (503)' } - } as RunState - const successState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'done' } } as RunState + const errorState: RunState = { + sessionState: {} as SessionState, + output: { type: 'error', message: 'Service unavailable (503)' }, + } + const successState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('done')] }, + } const runSpy = spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) @@ -178,7 +194,7 @@ describe('run retry wrapper', () => { it('calls onRetryExhausted after all retries fail', async () => { const errorState = { sessionState: {} as any, - output: { type: 'error', message: 'NetworkError: timeout' } + output: { type: 'error', message: 'NetworkError: timeout' }, } as RunState spyOn(runModule, 'runOnce').mockResolvedValue(errorState) @@ -200,7 +216,7 @@ describe('run retry wrapper', () => { it('returns error output without sessionState on first attempt failure', async () => { const errorState = { - output: { type: 'error', message: 'Not retryable' } + output: { type: 'error', message: 'Not retryable' }, } as RunState spyOn(runModule, 'runOnce').mockResolvedValue(errorState) @@ -216,14 +232,14 @@ describe('run retry wrapper', () => { it('preserves sessionState from previousRun on retry', async () => { const previousSession = { fileContext: { cwd: '/test' } } as any - const errorState = { - sessionState: { fileContext: { cwd: '/new' } } as any, - output: { type: 'error', message: 'Service unavailable' } - } as RunState - const successState = { - sessionState: { fileContext: { cwd: '/final' } } as any, - output: { type: 'lastMessage', value: 'ok' } - } as RunState + const errorState: RunState = { + sessionState: { fileContext: { cwd: '/new' } } as SessionState, + output: { type: 'error', message: 'Service unavailable' }, + } + const successState: RunState = { + sessionState: { fileContext: { cwd: '/final' } } as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('ok')] }, + } const runSpy = spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) @@ -231,7 +247,10 @@ describe('run retry wrapper', () => { const result = await run({ ...baseOptions, - previousRun: { sessionState: previousSession, output: { type: 'lastMessage', value: 'prev' } }, + previousRun: { + sessionState: previousSession, + output: { type: 'lastMessage', value: [assistantMessage('prev')] }, + }, retry: { backoffBaseMs: 1, backoffMaxMs: 2 }, }) @@ -240,11 +259,17 @@ describe('run retry wrapper', () => { }) it('handles 503 Service Unavailable errors as retryable', async () => { - const errorState = { - sessionState: {} as any, - output: { type: 'error', message: 'Error from AI SDK: 503 Service Unavailable' } - } as RunState - const successState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'ok' } } as RunState + const errorState: RunState = { + sessionState: {} as SessionState, + output: { + type: 'error', + message: 'Error from AI SDK: 503 Service Unavailable', + }, + } + const successState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('ok')] }, + } const runSpy = spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) @@ -260,11 +285,14 @@ describe('run retry wrapper', () => { }) it('applies exponential backoff correctly', async () => { - const errorState = { - sessionState: {} as any, - output: { type: 'error', message: 'NetworkError: Connection refused' } + const errorState: RunState = { + sessionState: {} as SessionState, + output: { type: 'error', message: 'NetworkError: Connection refused' }, } as RunState - const successState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'ok' } } as RunState + const successState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('ok')] }, + } spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) diff --git a/sdk/src/impl/llm.ts b/sdk/src/impl/llm.ts index 07f9563c0..f743a70b7 100644 --- a/sdk/src/impl/llm.ts +++ b/sdk/src/impl/llm.ts @@ -351,6 +351,9 @@ export async function* promptAiSdkStream( } } } + if (chunk.type === 'tool-call') { + yield chunk + } } const flushed = stopSequenceHandler.flush() if (flushed) { diff --git a/web/src/app/api/v1/chat/completions/_post.ts b/web/src/app/api/v1/chat/completions/_post.ts index 0762f3f1b..45f99d675 100644 --- a/web/src/app/api/v1/chat/completions/_post.ts +++ b/web/src/app/api/v1/chat/completions/_post.ts @@ -12,6 +12,7 @@ import { import { handleOpenRouterNonStream, handleOpenRouterStream, + OpenRouterError, } from '@/llm-api/openrouter' import { extractApiKeyFromHeader } from '@/util/auth' @@ -339,6 +340,12 @@ export async function postChatCompletions(params: { }, logger, }) + + // Pass through OpenRouter provider-specific errors + if (error instanceof OpenRouterError) { + return NextResponse.json(error.toJSON(), { status: error.statusCode }) + } + return NextResponse.json( { error: 'Failed to process request' }, { status: 500 }, diff --git a/web/src/llm-api/openrouter.ts b/web/src/llm-api/openrouter.ts index d9a85ed64..173eb9bfc 100644 --- a/web/src/llm-api/openrouter.ts +++ b/web/src/llm-api/openrouter.ts @@ -6,7 +6,10 @@ import { extractRequestMetadata, insertMessageToBigQuery, } from './helpers' -import { OpenRouterStreamChatCompletionChunkSchema } from './type/openrouter' +import { + OpenRouterErrorResponseSchema, + OpenRouterStreamChatCompletionChunkSchema, +} from './type/openrouter' import type { UsageData } from './helpers' import type { OpenRouterStreamChatCompletionChunk } from './type/openrouter' @@ -14,7 +17,6 @@ import type { InsertMessageBigqueryFn } from '@codebuff/common/types/contracts/b import type { Logger } from '@codebuff/common/types/contracts/logger' type StreamState = { responseText: string; reasoningText: string } - function createOpenRouterRequest(params: { body: any openrouterApiKey: string | null @@ -93,9 +95,9 @@ export async function handleOpenRouterNonStream({ const responses = await Promise.all(requests) if (responses.every((r) => !r.ok)) { - throw new Error( - `Failed to make all ${n} requests: ${responses.map((r) => r.statusText).join(', ')}`, - ) + // Return provider-specific error from the first failed response + const firstFailedResponse = responses[0] + throw await parseOpenRouterError(firstFailedResponse) } const allData = await Promise.all(responses.map((r) => r.json())) @@ -183,9 +185,7 @@ export async function handleOpenRouterNonStream({ }) if (!response.ok) { - throw new Error( - `OpenRouter API error (${response.statusText}): ${await response.text()}`, - ) + throw await parseOpenRouterError(response) } const data = await response.json() @@ -261,9 +261,7 @@ export async function handleOpenRouterStream({ }) if (!response.ok) { - throw new Error( - `OpenRouter API error (${response.statusText}): ${await response.text()}`, - ) + throw await parseOpenRouterError(response) } const reader = response.body?.getReader() @@ -532,3 +530,84 @@ async function handleStreamChunk({ state.reasoningText += choice.delta?.reasoning ?? '' return state } + +/** + * Custom error class for OpenRouter API errors that preserves provider-specific details. + */ +export class OpenRouterError extends Error { + constructor( + public readonly statusCode: number, + public readonly statusText: string, + public readonly errorBody: { + error: { + message: string + code: string | number | null + type?: string | null + param?: unknown + metadata?: { + raw?: string + provider_name?: string + } + } + }, + ) { + super(errorBody.error.message) + this.name = 'OpenRouterError' + } + + /** + * Returns the error in a format suitable for API responses. + */ + toJSON() { + return { + error: { + message: this.errorBody.error.message, + code: this.errorBody.error.code, + type: this.errorBody.error.type, + param: this.errorBody.error.param, + metadata: this.errorBody.error.metadata, + }, + } + } +} + +/** + * Parses an error response from OpenRouter and returns an OpenRouterError. + */ +async function parseOpenRouterError( + response: Response, +): Promise { + const errorText = await response.text() + let errorBody: OpenRouterError['errorBody'] + try { + const parsed = JSON.parse(errorText) + const validated = OpenRouterErrorResponseSchema.safeParse(parsed) + if (validated.success) { + errorBody = { + error: { + message: validated.data.error.message, + code: validated.data.error.code ?? null, + type: validated.data.error.type, + param: validated.data.error.param, + // metadata is not in the schema but OpenRouter includes it for provider errors + metadata: (parsed as any).error?.metadata, + }, + } + } else { + errorBody = { + error: { + message: errorText || response.statusText, + code: response.status, + }, + } + } + } catch { + errorBody = { + error: { + message: errorText || response.statusText, + code: response.status, + }, + } + } + return new OpenRouterError(response.status, response.statusText, errorBody) +}