diff --git a/bridge.go b/bridge.go index 6d4e38d..7033b89 100644 --- a/bridge.go +++ b/bridge.go @@ -203,13 +203,14 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC interceptor.Setup(logger, asyncRecorder, mcpProxy) if err := rec.RecordInterception(ctx, &recorder.InterceptionRecord{ - Client: guessClient(r), - ID: interceptor.ID().String(), - InitiatorID: actor.ID, - Metadata: actor.Metadata, - Model: interceptor.Model(), - Provider: p.Name(), - UserAgent: r.UserAgent(), + Client: guessClient(r), + ID: interceptor.ID().String(), + InitiatorID: actor.ID, + Metadata: actor.Metadata, + Model: interceptor.Model(), + Provider: p.Name(), + UserAgent: r.UserAgent(), + CorrelatingToolCallID: interceptor.CorrelatingToolCallID(), }); err != nil { span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err)) logger.Warn(ctx, "failed to record interception", slog.Error(err)) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 068474c..bf83264 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -64,16 +64,19 @@ func TestAnthropicMessages(t *testing.T) { streaming bool expectedInputTokens int expectedOutputTokens int + expectedToolCallID string }{ { streaming: true, expectedInputTokens: 2, expectedOutputTokens: 66, + expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo", }, { streaming: false, expectedInputTokens: 5, expectedOutputTokens: 84, + expectedToolCallID: "toolu_01AusGgY5aKFhzWrFBv9JfHq", }, } @@ -134,6 +137,7 @@ func TestAnthropicMessages(t *testing.T) { toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) assert.Equal(t, "Read", toolUsages[0].Tool) + assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) require.IsType(t, json.RawMessage{}, toolUsages[0].Args) var args map[string]any require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args)) @@ -278,16 +282,19 @@ func TestOpenAIChatCompletions(t *testing.T) { cases := []struct { streaming bool expectedInputTokens, expectedOutputTokens int + expectedToolCallID string }{ { streaming: true, expectedInputTokens: 60, expectedOutputTokens: 15, + expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B", }, { streaming: false, expectedInputTokens: 60, expectedOutputTokens: 15, + expectedToolCallID: "call_KjzAbhiZC6nk81tQzL7pwlpc", }, } @@ -347,6 +354,7 @@ func TestOpenAIChatCompletions(t *testing.T) { toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) assert.Equal(t, "read_file", toolUsages[0].Tool) + assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) require.IsType(t, map[string]any{}, toolUsages[0].Args) require.Contains(t, toolUsages[0].Args, "path") assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index ac7476b..aed8f88 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -63,6 +63,19 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, i.mcpProxy = mcpProxy } +func (i *interceptionBase) CorrelatingToolCallID() *string { + if len(i.req.Messages) == 0 { + return nil + } + + // The tool result should be the last input message. + msg := i.req.Messages[len(i.req.Messages)-1] + if msg.OfTool == nil { + return nil + } + return &msg.OfTool.ToolCallID +} + func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { return []attribute.KeyValue{ attribute.String(tracing.RequestPath, r.URL.Path), diff --git a/intercept/chatcompletions/base_test.go b/intercept/chatcompletions/base_test.go new file mode 100644 index 0000000..1647a2d --- /dev/null +++ b/intercept/chatcompletions/base_test.go @@ -0,0 +1,76 @@ +package chatcompletions + +import ( + "testing" + + "github.com/coder/aibridge/utils" + "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/require" +) + +func TestScanForCorrelatingToolCallID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + messages []openai.ChatCompletionMessageParamUnion + expected *string + }{ + { + name: "no messages", + messages: nil, + expected: nil, + }, + { + name: "no tool messages", + messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + openai.AssistantMessage("hi there"), + }, + expected: nil, + }, + { + name: "single tool message", + messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + openai.ToolMessage("result", "call_abc"), + }, + expected: utils.PtrTo("call_abc"), + }, + { + name: "multiple tool messages returns last", + messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + openai.ToolMessage("first result", "call_first"), + openai.AssistantMessage("thinking"), + openai.ToolMessage("second result", "call_second"), + }, + expected: utils.PtrTo("call_second"), + }, + { + name: "last message is not a tool message", + messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + openai.ToolMessage("first result", "call_first"), + openai.AssistantMessage("thinking"), + }, + expected: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + base := &interceptionBase{ + req: &ChatCompletionNewParamsWrapper{ + ChatCompletionNewParams: openai.ChatCompletionNewParams{ + Messages: tc.messages, + }, + }, + } + + require.Equal(t, tc.expected, base.CorrelatingToolCallID()) + }) + } +} diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index c650ade..9a84d14 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -124,6 +124,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: completion.ID, + ToolCallID: toolCall.ID, Tool: toolCall.Function.Name, Args: i.unmarshalArgs(toolCall.Function.Arguments), Injected: false, @@ -161,6 +162,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: completion.ID, + ToolCallID: tc.ID, ServerURL: &tool.ServerURL, Tool: tool.Name, Args: args, diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index fb18f82..ff3b78c 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -173,10 +173,12 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: processor.getMsgID(), + ToolCallID: toolCall.ID, Tool: toolCall.Name, Args: i.unmarshalArgs(toolCall.Arguments), Injected: false, }) + toolCall = nil } else { // When the provider responds with only tool calls (no text content), @@ -284,6 +286,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: processor.getMsgID(), + ToolCallID: id, ServerURL: &tool.ServerURL, Tool: tool.Name, Args: args, diff --git a/intercept/interceptor.go b/intercept/interceptor.go index 39fe615..cbd29d6 100644 --- a/intercept/interceptor.go +++ b/intercept/interceptor.go @@ -25,4 +25,13 @@ type Interceptor interface { Streaming() bool // TraceAttributes returns tracing attributes for this [Interceptor] TraceAttributes(*http.Request) []attribute.KeyValue + // CorrelatingToolCallID returns the ID of a tool call result submitted + // in the request, if present. This is used to correlate the current + // interception back to the previous interception that issued those tool + // calls. If multiple tool use results are present, we use the last one + // (most recent). Both Anthropic's /v1/messages and OpenAI's /v1/responses + // require that ALL tool results are submitted for tool choices returned + // by the model, so any single tool call ID is sufficient to identify the + // parent interception. + CorrelatingToolCallID() *string } diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 6f4f01f..387591d 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -57,6 +57,21 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, i.mcpProxy = mcpProxy } +func (i *interceptionBase) CorrelatingToolCallID() *string { + if len(i.req.Messages) == 0 { + return nil + } + content := i.req.Messages[len(i.req.Messages)-1].Content + for idx := len(content) - 1; idx >= 0; idx-- { + block := content[idx] + if block.OfToolResult == nil { + continue + } + return &block.OfToolResult.ToolUseID + } + return nil +} + func (i *interceptionBase) Model() string { if i.req == nil { return "coder-aibridge-unknown" diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index 5413a7d..cca890e 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -8,10 +8,107 @@ import ( "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge/config" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/utils" mcpgo "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" ) +func TestScanForCorrelatingToolCallID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + messages []anthropic.MessageParam + expected *string + }{ + { + name: "no messages", + messages: nil, + expected: nil, + }, + { + name: "last message has no tool_result blocks", + messages: []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock("hello")), + }, + expected: nil, + }, + { + name: "single tool_result block", + messages: []anthropic.MessageParam{ + anthropic.NewUserMessage( + anthropic.ContentBlockParamUnion{ + OfToolResult: &anthropic.ToolResultBlockParam{ + ToolUseID: "toolu_abc", + Content: []anthropic.ToolResultBlockParamContentUnion{ + {OfText: &anthropic.TextBlockParam{Text: "result"}}, + }, + }, + }, + ), + }, + expected: utils.PtrTo("toolu_abc"), + }, + { + name: "multiple tool_result blocks returns last", + messages: []anthropic.MessageParam{ + anthropic.NewUserMessage( + anthropic.ContentBlockParamUnion{ + OfToolResult: &anthropic.ToolResultBlockParam{ + ToolUseID: "toolu_first", + Content: []anthropic.ToolResultBlockParamContentUnion{ + {OfText: &anthropic.TextBlockParam{Text: "first"}}, + }, + }, + }, + anthropic.NewTextBlock("some text"), + anthropic.ContentBlockParamUnion{ + OfToolResult: &anthropic.ToolResultBlockParam{ + ToolUseID: "toolu_second", + Content: []anthropic.ToolResultBlockParamContentUnion{ + {OfText: &anthropic.TextBlockParam{Text: "second"}}, + }, + }, + }, + ), + }, + expected: utils.PtrTo("toolu_second"), + }, + { + name: "last message is not a tool result", + messages: []anthropic.MessageParam{ + anthropic.NewUserMessage( + anthropic.ContentBlockParamUnion{ + OfToolResult: &anthropic.ToolResultBlockParam{ + ToolUseID: "toolu_first", + Content: []anthropic.ToolResultBlockParamContentUnion{ + {OfText: &anthropic.TextBlockParam{Text: "first"}}, + }, + }, + }), + anthropic.NewUserMessage(anthropic.NewTextBlock("some text")), + }, + expected: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + base := &interceptionBase{ + req: &MessageNewParamsWrapper{ + MessageNewParams: anthropic.MessageNewParams{ + Messages: tc.messages, + }, + }, + } + + require.Equal(t, tc.expected, base.CorrelatingToolCallID()) + }) + } +} + func TestAWSBedrockValidation(t *testing.T) { t.Parallel() diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 7ab2bed..e22b97f 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -152,10 +152,12 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: resp.ID, + ToolCallID: toolUse.ID, Tool: toolUse.Name, Args: toolUse.Input, Injected: false, }) + } // If no injected tool calls, we're done. @@ -188,6 +190,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: resp.ID, + ToolCallID: tc.ID, ServerURL: &tool.ServerURL, Tool: tool.Name, Args: tc.Input, diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 4fc19fd..4e87fd8 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -298,6 +298,7 @@ newStream: _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: message.ID, + ToolCallID: id, ServerURL: &tool.ServerURL, Tool: tool.Name, Args: input, @@ -411,6 +412,7 @@ newStream: _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: message.ID, + ToolCallID: variant.ID, Tool: variant.Name, Args: variant.Input, Injected: false, diff --git a/intercept/responses/base.go b/intercept/responses/base.go index dcd72a0..b531a71 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -80,6 +80,19 @@ func (i *responsesInterceptionBase) Model() string { return i.model } +func (i *responsesInterceptionBase) CorrelatingToolCallID() *string { + if len(i.req.Input.OfInputItemList) == 0 { + return nil + } + + // The tool result should be the last input message. + item := i.req.Input.OfInputItemList[len(i.req.Input.OfInputItemList)-1] + if item.OfFunctionCallOutput == nil { + return nil + } + return &item.OfFunctionCallOutput.CallID +} + func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { return []attribute.KeyValue{ attribute.String(tracing.RequestPath, r.URL.Path), @@ -263,6 +276,7 @@ func (i *responsesInterceptionBase) recordNonInjectedToolUsage(ctx context.Conte if err := i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: response.ID, + ToolCallID: item.CallID, Tool: item.Name, Args: args, Injected: false, diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index de72010..ad0d59a 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -9,11 +9,110 @@ import ( "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/recorder" + "github.com/coder/aibridge/utils" "github.com/google/uuid" oairesponses "github.com/openai/openai-go/v3/responses" "github.com/stretchr/testify/require" ) +func TestScanForCorrelatingToolCallID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input []oairesponses.ResponseInputItemUnionParam + expected *string + }{ + { + name: "no input items", + input: nil, + expected: nil, + }, + { + name: "no function_call_output items", + input: []oairesponses.ResponseInputItemUnionParam{ + { + OfMessage: &oairesponses.EasyInputMessageParam{ + Role: "user", + }, + }, + }, + expected: nil, + }, + { + name: "single function_call_output", + input: []oairesponses.ResponseInputItemUnionParam{ + { + OfMessage: &oairesponses.EasyInputMessageParam{ + Role: "user", + }, + }, + { + OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ + CallID: "call_abc", + }, + }, + }, + expected: utils.PtrTo("call_abc"), + }, + { + name: "multiple function_call_outputs returns last", + input: []oairesponses.ResponseInputItemUnionParam{ + { + OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ + CallID: "call_first", + }, + }, + { + OfMessage: &oairesponses.EasyInputMessageParam{ + Role: "user", + }, + }, + { + OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ + CallID: "call_second", + }, + }, + }, + expected: utils.PtrTo("call_second"), + }, + { + name: "last input is not a tool result", + input: []oairesponses.ResponseInputItemUnionParam{ + { + OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ + CallID: "call_first", + }, + }, + { + OfMessage: &oairesponses.EasyInputMessageParam{ + Role: "user", + }, + }, + }, + expected: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + base := &responsesInterceptionBase{ + req: &ResponsesNewParamsWrapper{ + ResponseNewParams: oairesponses.ResponseNewParams{ + Input: oairesponses.ResponseNewParamsInputUnion{ + OfInputItemList: tc.input, + }, + }, + }, + } + + require.Equal(t, tc.expected, base.CorrelatingToolCallID()) + }) + } +} + func TestLastUserPrompt(t *testing.T) { t.Parallel() @@ -265,6 +364,7 @@ func TestRecordToolUsage(t *testing.T) { Output: []oairesponses.ResponseOutputItemUnion{ { Type: "function_call", + CallID: "call_abc", Name: "get_weather", Arguments: "", }, @@ -274,6 +374,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_456", + ToolCallID: "call_abc", Tool: "get_weather", Args: "", Injected: false, @@ -287,11 +388,13 @@ func TestRecordToolUsage(t *testing.T) { Output: []oairesponses.ResponseOutputItemUnion{ { Type: "function_call", + CallID: "call_1", Name: "get_weather", Arguments: `{"location": "NYC"}`, }, { Type: "function_call", + CallID: "call_2", Name: "bad_json_args", Arguments: `{"bad": args`, }, @@ -301,12 +404,14 @@ func TestRecordToolUsage(t *testing.T) { Role: "assistant", }, { - Type: "custom_tool_call", - Name: "search", - Input: `{\"query\": \"test\"}`, + Type: "custom_tool_call", + CallID: "call_3", + Name: "search", + Input: `{\"query\": \"test\"}`, }, { Type: "function_call", + CallID: "call_4", Name: "calculate", Arguments: `{"a": 1, "b": 2}`, }, @@ -316,6 +421,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ToolCallID: "call_1", Tool: "get_weather", Args: map[string]any{"location": "NYC"}, Injected: false, @@ -323,6 +429,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ToolCallID: "call_2", Tool: "bad_json_args", Args: `{"bad": args`, Injected: false, @@ -330,6 +437,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ToolCallID: "call_3", Tool: "search", Args: `{\"query\": \"test\"}`, Injected: false, @@ -337,6 +445,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ToolCallID: "call_4", Tool: "calculate", Args: map[string]any{"a": float64(1), "b": float64(2)}, Injected: false, diff --git a/intercept/responses/injected_tools.go b/intercept/responses/injected_tools.go index e97c885..8a47801 100644 --- a/intercept/responses/injected_tools.go +++ b/intercept/responses/injected_tools.go @@ -212,6 +212,7 @@ func (i *responsesInterceptionBase) invokeInjectedTool(ctx context.Context, resp _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: responseID, + ToolCallID: fc.CallID, ServerURL: &tool.ServerURL, Tool: tool.Name, Args: args, diff --git a/internal/testutil/mock_recorder.go b/internal/testutil/mock_recorder.go index d4c9c8e..ac39006 100644 --- a/internal/testutil/mock_recorder.go +++ b/internal/testutil/mock_recorder.go @@ -6,7 +6,6 @@ import ( "slices" "sync" "testing" - "time" "github.com/coder/aibridge/recorder" "github.com/stretchr/testify/require" @@ -21,7 +20,7 @@ type MockRecorder struct { tokenUsages []*recorder.TokenUsageRecord userPrompts []*recorder.PromptUsageRecord toolUsages []*recorder.ToolUsageRecord - interceptionsEnd map[string]time.Time + interceptionsEnd map[string]*recorder.InterceptionRecordEnded } func (m *MockRecorder) RecordInterception(ctx context.Context, req *recorder.InterceptionRecord) error { @@ -35,12 +34,12 @@ func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorde m.mu.Lock() defer m.mu.Unlock() if m.interceptionsEnd == nil { - m.interceptionsEnd = make(map[string]time.Time) + m.interceptionsEnd = make(map[string]*recorder.InterceptionRecordEnded) } if !slices.ContainsFunc(m.interceptions, func(intc *recorder.InterceptionRecord) bool { return intc.ID == req.ID }) { return fmt.Errorf("id not found") } - m.interceptionsEnd[req.ID] = req.EndedAt + m.interceptionsEnd[req.ID] = req return nil } @@ -107,6 +106,14 @@ func (m *MockRecorder) ToolUsages() []*recorder.ToolUsageRecord { return m.toolUsages } +// RecordedInterceptionEnd returns the stored InterceptionRecordEnded for the +// given interception ID, or nil if not found. +func (m *MockRecorder) RecordedInterceptionEnd(id string) *recorder.InterceptionRecordEnded { + m.mu.Lock() + defer m.mu.Unlock() + return m.interceptionsEnd[id] +} + // VerifyAllInterceptionsEnded verifies all recorded interceptions have been marked as completed. func (m *MockRecorder) VerifyAllInterceptionsEnded(t *testing.T) { t.Helper() diff --git a/recorder/types.go b/recorder/types.go index 3dc61c9..82c34d0 100644 --- a/recorder/types.go +++ b/recorder/types.go @@ -26,14 +26,15 @@ type ToolArgs any type Metadata map[string]any type InterceptionRecord struct { - Client string - ID string - InitiatorID string - Metadata Metadata - Model string - Provider string - StartedAt time.Time - UserAgent string + Client string + ID string + InitiatorID string + Metadata Metadata + Model string + Provider string + StartedAt time.Time + UserAgent string + CorrelatingToolCallID *string } type InterceptionRecordEnded struct { @@ -63,6 +64,7 @@ type PromptUsageRecord struct { type ToolUsageRecord struct { InterceptionID string MsgID, Tool string + ToolCallID string ServerURL *string Args ToolArgs Injected bool diff --git a/responses_integration_test.go b/responses_integration_test.go index 2521253..4b82bfb 100644 --- a/responses_integration_test.go +++ b/responses_integration_test.go @@ -71,10 +71,11 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectModel: "gpt-4.1", expectPromptRecorded: "Is 3 + 5 a prime number? Use the add function to calculate the sum.", expectToolRecorded: &recorder.ToolUsageRecord{ - MsgID: "resp_0da6045a8b68fa5200695fa23dcc2c81a19c849f627abf8a31", - Tool: "add", - Args: map[string]any{"a": float64(3), "b": float64(5)}, - Injected: false, + MsgID: "resp_0da6045a8b68fa5200695fa23dcc2c81a19c849f627abf8a31", + Tool: "add", + ToolCallID: "call_CJSaa2u51JG996575oVljuNq", + Args: map[string]any{"a": float64(3), "b": float64(5)}, + Injected: false, }, expectTokenUsage: &recorder.TokenUsageRecord{ MsgID: "resp_0da6045a8b68fa5200695fa23dcc2c81a19c849f627abf8a31", @@ -111,10 +112,11 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectModel: "gpt-5", expectPromptRecorded: "Use the code_exec tool to print hello world to the console.", expectToolRecorded: &recorder.ToolUsageRecord{ - MsgID: "resp_09c614364030cdf000696942589da081a0af07f5859acb7308", - Tool: "code_exec", - Args: "print(\"hello world\")", - Injected: false, + MsgID: "resp_09c614364030cdf000696942589da081a0af07f5859acb7308", + Tool: "code_exec", + ToolCallID: "call_haf8njtwrVZ1754Gm6fjAtuA", + Args: "print(\"hello world\")", + Injected: false, }, expectTokenUsage: &recorder.TokenUsageRecord{ MsgID: "resp_09c614364030cdf000696942589da081a0af07f5859acb7308", @@ -207,10 +209,11 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectModel: "gpt-4.1", expectPromptRecorded: "Is 3 + 5 a prime number? Use the add function to calculate the sum.", expectToolRecorded: &recorder.ToolUsageRecord{ - MsgID: "resp_0c3fb28cfcf463a500695fa2f0239481a095ec6ce3dfe4d458", - Tool: "add", - Args: map[string]any{"a": float64(3), "b": float64(5)}, - Injected: false, + MsgID: "resp_0c3fb28cfcf463a500695fa2f0239481a095ec6ce3dfe4d458", + Tool: "add", + ToolCallID: "call_7VaiUXZYuuuwWwviCrckxq6t", + Args: map[string]any{"a": float64(3), "b": float64(5)}, + Injected: false, }, expectTokenUsage: &recorder.TokenUsageRecord{ MsgID: "resp_0c3fb28cfcf463a500695fa2f0239481a095ec6ce3dfe4d458", @@ -249,10 +252,11 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectModel: "gpt-5", expectPromptRecorded: "Use the code_exec tool to print hello world to the console.", expectToolRecorded: &recorder.ToolUsageRecord{ - MsgID: "resp_0c26996bc41c2a0500696942e83634819fb71b2b8ff8a4a76c", - Tool: "code_exec", - Args: "print(\"hello world\")", - Injected: false, + MsgID: "resp_0c26996bc41c2a0500696942e83634819fb71b2b8ff8a4a76c", + Tool: "code_exec", + ToolCallID: "call_2gSnF58IEhXLwlbnqbm5XKMd", + Args: "print(\"hello world\")", + Injected: false, }, expectTokenUsage: &recorder.TokenUsageRecord{ MsgID: "resp_0c26996bc41c2a0500696942e83634819fb71b2b8ff8a4a76c",