From 0957029d33a8917d7f4a3f0e13714eeee21f958f Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Thu, 19 Feb 2026 18:07:04 +0200 Subject: [PATCH 1/9] feat: record and correlate tool call IDs across interceptions Add CorrelatingToolCallID to the Interceptor interface, enabling interception lineage tracking. Each interceptor now scans backward through request messages/input to find the most recent tool call result, correctly identifying the parent interception that triggered the current one. Also record ToolCallID in non-injected tool usage for the responses API, which was previously missing and prevented lineage queries from finding parent interceptions. Changes: - Add CorrelatingToolCallID() to Interceptor interface - Add correlatingToolCallID field to all interceptor base types - Add backward scan loops in chatcompletions, messages, responses - Add scanForCorrelatingToolCallID() for responses API - Record ToolCallID in both injected and non-injected tool usage - Add CorrelatingToolCallID to InterceptionRecordEnded - Add ToolCallID to ToolUsageRecord - Pass CorrelatingToolCallID when ending interceptions in bridge.go - Update tests and integration test expectations Signed-off-by: Danny Kopping --- bridge.go | 2 +- intercept/chatcompletions/base.go | 6 ++++ intercept/chatcompletions/blocking.go | 12 ++++++++ intercept/chatcompletions/streaming.go | 12 ++++++++ intercept/interceptor.go | 9 ++++++ intercept/messages/base.go | 6 ++++ intercept/messages/blocking.go | 15 ++++++++++ intercept/messages/streaming.go | 17 +++++++++++ intercept/responses/base.go | 40 +++++++++++++++++++------- intercept/responses/base_test.go | 16 +++++++++-- intercept/responses/blocking.go | 1 + intercept/responses/injected_tools.go | 1 + intercept/responses/streaming.go | 1 + recorder/types.go | 6 ++-- responses_integration_test.go | 36 ++++++++++++----------- 15 files changed, 148 insertions(+), 32 deletions(-) diff --git a/bridge.go b/bridge.go index 6d4e38d..a2b8613 100644 --- a/bridge.go +++ b/bridge.go @@ -250,7 +250,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC log.Debug(ctx, "interception ended") } - asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()}) + asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String(), CorrelatingToolCallID: interceptor.CorrelatingToolCallID()}) // Ensure all recording have completed before completing request. asyncRecorder.Wait() diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index ac7476b..899cde3 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -34,6 +34,8 @@ type interceptionBase struct { recorder recorder.Recorder mcpProxy mcp.ServerProxier + + correlatingToolCallID string } func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService { @@ -63,6 +65,10 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, i.mcpProxy = mcpProxy } +func (i *interceptionBase) CorrelatingToolCallID() string { + return i.correlatingToolCallID +} + func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { return []attribute.KeyValue{ attribute.String(tracing.RequestPath, r.URL.Path), diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index c650ade..e989ff4 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -68,6 +68,18 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req i.injectTools() + // Scan the request for tool results; we use these to correlate requests + // together. We iterate backward so we find the last (most recent) tool + // result, which correctly identifies the parent interception. + for idx := len(i.req.Messages) - 1; idx >= 0; idx-- { + msg := i.req.Messages[idx] + if msg.OfTool == nil { + continue + } + i.correlatingToolCallID = msg.OfTool.ToolCallID + break + } + prompt, err := i.req.lastUserPrompt() if err != nil { logger.Warn(ctx, "failed to retrieve last user prompt", slog.Error(err)) diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index fb18f82..2c8f2cc 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -79,6 +79,18 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re i.injectTools() + // Scan the request for tool results; we use these to correlate requests + // together. We iterate backward so we find the last (most recent) tool + // result, which correctly identifies the parent interception. + for idx := len(i.req.Messages) - 1; idx >= 0; idx-- { + msg := i.req.Messages[idx] + if msg.OfTool == nil { + continue + } + i.correlatingToolCallID = msg.OfTool.ToolCallID + break + } + // Allow us to interrupt watch via cancel. ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/intercept/interceptor.go b/intercept/interceptor.go index 39fe615..c767022 100644 --- a/intercept/interceptor.go +++ b/intercept/interceptor.go @@ -25,4 +25,13 @@ type Interceptor interface { Streaming() bool // TraceAttributes returns tracing attributes for this [Interceptor] TraceAttributes(*http.Request) []attribute.KeyValue + // CorrelatingToolCallID returns the ID of a tool call result submitted + // in the request, if present. This is used to correlate the current + // interception back to the previous interception that issued those tool + // calls. If multiple tool use results are present, we use the last one + // (most recent). Both Anthropic's /v1/messages and OpenAI's /v1/responses + // require that ALL tool results are submitted for tool choices returned + // by the model, so any single tool call ID is sufficient to identify the + // parent interception. + CorrelatingToolCallID() string } diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 6f4f01f..2471c6d 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -45,6 +45,8 @@ type interceptionBase struct { recorder recorder.Recorder mcpProxy mcp.ServerProxier + + correlatingToolCallID string } func (i *interceptionBase) ID() uuid.UUID { @@ -57,6 +59,10 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, i.mcpProxy = mcpProxy } +func (i *interceptionBase) CorrelatingToolCallID() string { + return i.correlatingToolCallID +} + func (i *interceptionBase) Model() string { if i.req == nil { return "coder-aibridge-unknown" diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 7ab2bed..8534128 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -88,6 +88,21 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req messages := i.req.MessageNewParams logger := i.logger.With(slog.F("model", i.req.Model)) + // Scan the request for tool results; we use these to correlate requests + // together. We iterate backward so we find the last (most recent) tool + // result, which correctly identifies the parent interception. + if len(messages.Messages) > 0 { + content := messages.Messages[len(messages.Messages)-1].Content + for idx := len(content) - 1; idx >= 0; idx-- { + block := content[idx] + if block.OfToolResult == nil { + continue + } + i.correlatingToolCallID = block.OfToolResult.ToolUseID + break + } + } + var resp *anthropic.Message // Accumulate usage across the entire streaming interaction (including tool reinvocations). var cumulativeUsage anthropic.Usage diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 4fc19fd..3176ac8 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -148,6 +148,23 @@ newStream: stream := i.newStream(streamCtx, svc, messages) + // Scan the request for tool results; we use these to correlate + // requests together. We iterate backward so we find the last + // (most recent) tool result, which correctly identifies the + // parent interception. + if len(messages.Messages) > 0 { + content := messages.Messages[len(messages.Messages)-1].Content + for idx := len(content) - 1; idx >= 0; idx-- { + block := content[idx] + if block.OfToolResult == nil { + continue + } + + i.correlatingToolCallID = block.OfToolResult.ToolUseID + break + } + } + var message anthropic.Message var lastToolName string diff --git a/intercept/responses/base.go b/intercept/responses/base.go index dcd72a0..73686c4 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -37,16 +37,17 @@ const ( ) type responsesInterceptionBase struct { - id uuid.UUID - req *ResponsesNewParamsWrapper - reqPayload []byte - cfg config.OpenAI - model string - recorder recorder.Recorder - mcpProxy mcp.ServerProxier - logger slog.Logger - metrics metrics.Metrics - tracer trace.Tracer + id uuid.UUID + req *ResponsesNewParamsWrapper + reqPayload []byte + cfg config.OpenAI + model string + recorder recorder.Recorder + mcpProxy mcp.ServerProxier + logger slog.Logger + metrics metrics.Metrics + tracer trace.Tracer + correlatingToolCallID string } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { @@ -80,6 +81,24 @@ func (i *responsesInterceptionBase) Model() string { return i.model } +func (i *responsesInterceptionBase) CorrelatingToolCallID() string { + return i.correlatingToolCallID +} + +// scanForCorrelatingToolCallID scans the request input for function call +// output items and sets correlatingToolCallID to the CallID of the last one +// found, which correctly identifies the most recent parent interception. +func (i *responsesInterceptionBase) scanForCorrelatingToolCallID() { + for idx := len(i.req.Input.OfInputItemList) - 1; idx >= 0; idx-- { + item := i.req.Input.OfInputItemList[idx] + if item.OfFunctionCallOutput == nil { + continue + } + i.correlatingToolCallID = item.OfFunctionCallOutput.CallID + return + } +} + func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { return []attribute.KeyValue{ attribute.String(tracing.RequestPath, r.URL.Path), @@ -263,6 +282,7 @@ func (i *responsesInterceptionBase) recordNonInjectedToolUsage(ctx context.Conte if err := i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: response.ID, + ToolCallID: item.CallID, Tool: item.Name, Args: args, Injected: false, diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index de72010..b28bd6e 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -265,6 +265,7 @@ func TestRecordToolUsage(t *testing.T) { Output: []oairesponses.ResponseOutputItemUnion{ { Type: "function_call", + CallID: "call_abc", Name: "get_weather", Arguments: "", }, @@ -274,6 +275,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_456", + ToolCallID: "call_abc", Tool: "get_weather", Args: "", Injected: false, @@ -287,11 +289,13 @@ func TestRecordToolUsage(t *testing.T) { Output: []oairesponses.ResponseOutputItemUnion{ { Type: "function_call", + CallID: "call_1", Name: "get_weather", Arguments: `{"location": "NYC"}`, }, { Type: "function_call", + CallID: "call_2", Name: "bad_json_args", Arguments: `{"bad": args`, }, @@ -301,12 +305,14 @@ func TestRecordToolUsage(t *testing.T) { Role: "assistant", }, { - Type: "custom_tool_call", - Name: "search", - Input: `{\"query\": \"test\"}`, + Type: "custom_tool_call", + CallID: "call_3", + Name: "search", + Input: `{\"query\": \"test\"}`, }, { Type: "function_call", + CallID: "call_4", Name: "calculate", Arguments: `{"a": 1, "b": 2}`, }, @@ -316,6 +322,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ToolCallID: "call_1", Tool: "get_weather", Args: map[string]any{"location": "NYC"}, Injected: false, @@ -323,6 +330,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ToolCallID: "call_2", Tool: "bad_json_args", Args: `{"bad": args`, Injected: false, @@ -330,6 +338,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ToolCallID: "call_3", Tool: "search", Args: `{\"query\": \"test\"}`, Injected: false, @@ -337,6 +346,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ToolCallID: "call_4", Tool: "calculate", Args: map[string]any{"a": float64(1), "b": float64(2)}, Injected: false, diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 3e94a6c..87f5ef5 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -60,6 +60,7 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * i.injectTools() i.disableParallelToolCalls() + i.scanForCorrelatingToolCallID() var ( response *responses.Response diff --git a/intercept/responses/injected_tools.go b/intercept/responses/injected_tools.go index e97c885..8a47801 100644 --- a/intercept/responses/injected_tools.go +++ b/intercept/responses/injected_tools.go @@ -212,6 +212,7 @@ func (i *responsesInterceptionBase) invokeInjectedTool(ctx context.Context, resp _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: responseID, + ToolCallID: fc.CallID, ServerURL: &tool.ServerURL, Tool: tool.Name, Args: args, diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 6925d86..8f1b6a1 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -71,6 +71,7 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r i.injectTools() i.disableParallelToolCalls() + i.scanForCorrelatingToolCallID() events := eventstream.NewEventStream(ctx, i.logger.Named("sse-sender"), nil) go events.Start(w, r) diff --git a/recorder/types.go b/recorder/types.go index 3dc61c9..5d3e462 100644 --- a/recorder/types.go +++ b/recorder/types.go @@ -37,8 +37,9 @@ type InterceptionRecord struct { } type InterceptionRecordEnded struct { - ID string - EndedAt time.Time + ID string + EndedAt time.Time + CorrelatingToolCallID string } type TokenUsageRecord struct { @@ -63,6 +64,7 @@ type PromptUsageRecord struct { type ToolUsageRecord struct { InterceptionID string MsgID, Tool string + ToolCallID string ServerURL *string Args ToolArgs Injected bool diff --git a/responses_integration_test.go b/responses_integration_test.go index 2521253..4b82bfb 100644 --- a/responses_integration_test.go +++ b/responses_integration_test.go @@ -71,10 +71,11 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectModel: "gpt-4.1", expectPromptRecorded: "Is 3 + 5 a prime number? Use the add function to calculate the sum.", expectToolRecorded: &recorder.ToolUsageRecord{ - MsgID: "resp_0da6045a8b68fa5200695fa23dcc2c81a19c849f627abf8a31", - Tool: "add", - Args: map[string]any{"a": float64(3), "b": float64(5)}, - Injected: false, + MsgID: "resp_0da6045a8b68fa5200695fa23dcc2c81a19c849f627abf8a31", + Tool: "add", + ToolCallID: "call_CJSaa2u51JG996575oVljuNq", + Args: map[string]any{"a": float64(3), "b": float64(5)}, + Injected: false, }, expectTokenUsage: &recorder.TokenUsageRecord{ MsgID: "resp_0da6045a8b68fa5200695fa23dcc2c81a19c849f627abf8a31", @@ -111,10 +112,11 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectModel: "gpt-5", expectPromptRecorded: "Use the code_exec tool to print hello world to the console.", expectToolRecorded: &recorder.ToolUsageRecord{ - MsgID: "resp_09c614364030cdf000696942589da081a0af07f5859acb7308", - Tool: "code_exec", - Args: "print(\"hello world\")", - Injected: false, + MsgID: "resp_09c614364030cdf000696942589da081a0af07f5859acb7308", + Tool: "code_exec", + ToolCallID: "call_haf8njtwrVZ1754Gm6fjAtuA", + Args: "print(\"hello world\")", + Injected: false, }, expectTokenUsage: &recorder.TokenUsageRecord{ MsgID: "resp_09c614364030cdf000696942589da081a0af07f5859acb7308", @@ -207,10 +209,11 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectModel: "gpt-4.1", expectPromptRecorded: "Is 3 + 5 a prime number? Use the add function to calculate the sum.", expectToolRecorded: &recorder.ToolUsageRecord{ - MsgID: "resp_0c3fb28cfcf463a500695fa2f0239481a095ec6ce3dfe4d458", - Tool: "add", - Args: map[string]any{"a": float64(3), "b": float64(5)}, - Injected: false, + MsgID: "resp_0c3fb28cfcf463a500695fa2f0239481a095ec6ce3dfe4d458", + Tool: "add", + ToolCallID: "call_7VaiUXZYuuuwWwviCrckxq6t", + Args: map[string]any{"a": float64(3), "b": float64(5)}, + Injected: false, }, expectTokenUsage: &recorder.TokenUsageRecord{ MsgID: "resp_0c3fb28cfcf463a500695fa2f0239481a095ec6ce3dfe4d458", @@ -249,10 +252,11 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectModel: "gpt-5", expectPromptRecorded: "Use the code_exec tool to print hello world to the console.", expectToolRecorded: &recorder.ToolUsageRecord{ - MsgID: "resp_0c26996bc41c2a0500696942e83634819fb71b2b8ff8a4a76c", - Tool: "code_exec", - Args: "print(\"hello world\")", - Injected: false, + MsgID: "resp_0c26996bc41c2a0500696942e83634819fb71b2b8ff8a4a76c", + Tool: "code_exec", + ToolCallID: "call_2gSnF58IEhXLwlbnqbm5XKMd", + Args: "print(\"hello world\")", + Injected: false, }, expectTokenUsage: &recorder.TokenUsageRecord{ MsgID: "resp_0c26996bc41c2a0500696942e83634819fb71b2b8ff8a4a76c", From a402786c92f13668243a22a62106e7ee003a1de4 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Mon, 23 Feb 2026 15:09:37 +0200 Subject: [PATCH 2/9] test: add tests for tool call ID correlation and extract scan methods Extract scanForCorrelatingToolCallID() into methods on the base structs for chatcompletions and messages interceptors, matching the existing pattern in responses. Add TestScanForCorrelatingToolCallID table-driven tests for all three interceptors. Update MockRecorder to store full InterceptionRecordEnded structs instead of just timestamps, enabling richer test assertions. Co-Authored-By: Claude Opus 4.6 --- intercept/chatcompletions/base.go | 15 +++++ intercept/chatcompletions/base_test.go | 67 +++++++++++++++++++++ intercept/chatcompletions/blocking.go | 16 ++--- intercept/chatcompletions/streaming.go | 16 ++--- intercept/messages/base.go | 19 ++++++ intercept/messages/base_test.go | 81 +++++++++++++++++++++++++ intercept/messages/blocking.go | 19 ++---- intercept/messages/streaming.go | 21 ++----- intercept/responses/base_test.go | 83 ++++++++++++++++++++++++++ internal/testutil/mock_recorder.go | 15 +++-- 10 files changed, 296 insertions(+), 56 deletions(-) create mode 100644 intercept/chatcompletions/base_test.go diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 899cde3..21a0eeb 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -69,6 +69,21 @@ func (i *interceptionBase) CorrelatingToolCallID() string { return i.correlatingToolCallID } +// scanForCorrelatingToolCallID scans the request messages for tool +// result messages and sets correlatingToolCallID to the ToolCallID +// of the last one found, which correctly identifies the most recent +// parent interception. +func (i *interceptionBase) scanForCorrelatingToolCallID() { + for idx := len(i.req.Messages) - 1; idx >= 0; idx-- { + msg := i.req.Messages[idx] + if msg.OfTool == nil { + continue + } + i.correlatingToolCallID = msg.OfTool.ToolCallID + return + } +} + func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { return []attribute.KeyValue{ attribute.String(tracing.RequestPath, r.URL.Path), diff --git a/intercept/chatcompletions/base_test.go b/intercept/chatcompletions/base_test.go new file mode 100644 index 0000000..bdc671d --- /dev/null +++ b/intercept/chatcompletions/base_test.go @@ -0,0 +1,67 @@ +package chatcompletions + +import ( + "testing" + + "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/require" +) + +func TestScanForCorrelatingToolCallID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + messages []openai.ChatCompletionMessageParamUnion + expected string + }{ + { + name: "no messages", + messages: nil, + expected: "", + }, + { + name: "no tool messages", + messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + openai.AssistantMessage("hi there"), + }, + expected: "", + }, + { + name: "single tool message", + messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + openai.ToolMessage("result", "call_abc"), + }, + expected: "call_abc", + }, + { + name: "multiple tool messages returns last", + messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + openai.ToolMessage("first result", "call_first"), + openai.AssistantMessage("thinking"), + openai.ToolMessage("second result", "call_second"), + }, + expected: "call_second", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + base := &interceptionBase{ + req: &ChatCompletionNewParamsWrapper{ + ChatCompletionNewParams: openai.ChatCompletionNewParams{ + Messages: tc.messages, + }, + }, + } + + base.scanForCorrelatingToolCallID() + require.Equal(t, tc.expected, base.CorrelatingToolCallID()) + }) + } +} diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index e989ff4..d1753e7 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -68,17 +68,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req i.injectTools() - // Scan the request for tool results; we use these to correlate requests - // together. We iterate backward so we find the last (most recent) tool - // result, which correctly identifies the parent interception. - for idx := len(i.req.Messages) - 1; idx >= 0; idx-- { - msg := i.req.Messages[idx] - if msg.OfTool == nil { - continue - } - i.correlatingToolCallID = msg.OfTool.ToolCallID - break - } + i.scanForCorrelatingToolCallID() prompt, err := i.req.lastUserPrompt() if err != nil { @@ -136,10 +126,12 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: completion.ID, + ToolCallID: toolCall.ID, Tool: toolCall.Function.Name, Args: i.unmarshalArgs(toolCall.Function.Arguments), Injected: false, }) + } } } @@ -173,6 +165,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: completion.ID, + ToolCallID: tc.ID, ServerURL: &tool.ServerURL, Tool: tool.Name, Args: args, @@ -180,6 +173,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req InvocationError: err, }) + if err != nil { // Always provide a tool result even if the tool call failed errorResponse := map[string]interface{}{ diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index 2c8f2cc..b561050 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -79,17 +79,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re i.injectTools() - // Scan the request for tool results; we use these to correlate requests - // together. We iterate backward so we find the last (most recent) tool - // result, which correctly identifies the parent interception. - for idx := len(i.req.Messages) - 1; idx >= 0; idx-- { - msg := i.req.Messages[idx] - if msg.OfTool == nil { - continue - } - i.correlatingToolCallID = msg.OfTool.ToolCallID - break - } + i.scanForCorrelatingToolCallID() // Allow us to interrupt watch via cancel. ctx, cancel := context.WithCancel(ctx) @@ -185,10 +175,12 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: processor.getMsgID(), + ToolCallID: toolCall.ID, Tool: toolCall.Name, Args: i.unmarshalArgs(toolCall.Arguments), Injected: false, }) + toolCall = nil } else { // When the provider responds with only tool calls (no text content), @@ -296,6 +288,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: processor.getMsgID(), + ToolCallID: id, ServerURL: &tool.ServerURL, Tool: tool.Name, Args: args, @@ -303,6 +296,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re InvocationError: toolErr, }) + // Reset. toolCall = nil diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 2471c6d..eeae649 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -63,6 +63,25 @@ func (i *interceptionBase) CorrelatingToolCallID() string { return i.correlatingToolCallID } +// scanForCorrelatingToolCallID scans the last message's content +// blocks for tool result blocks and sets correlatingToolCallID +// to the ToolUseID of the last one found, which correctly +// identifies the most recent parent interception. +func (i *interceptionBase) scanForCorrelatingToolCallID() { + if len(i.req.Messages) == 0 { + return + } + content := i.req.Messages[len(i.req.Messages)-1].Content + for idx := len(content) - 1; idx >= 0; idx-- { + block := content[idx] + if block.OfToolResult == nil { + continue + } + i.correlatingToolCallID = block.OfToolResult.ToolUseID + return + } +} + func (i *interceptionBase) Model() string { if i.req == nil { return "coder-aibridge-unknown" diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index 5413a7d..8a00871 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -12,6 +12,87 @@ import ( "github.com/stretchr/testify/require" ) +func TestScanForCorrelatingToolCallID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + messages []anthropic.MessageParam + expected string + }{ + { + name: "no messages", + messages: nil, + expected: "", + }, + { + name: "last message has no tool_result blocks", + messages: []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock("hello")), + }, + expected: "", + }, + { + name: "single tool_result block", + messages: []anthropic.MessageParam{ + anthropic.NewUserMessage( + anthropic.ContentBlockParamUnion{ + OfToolResult: &anthropic.ToolResultBlockParam{ + ToolUseID: "toolu_abc", + Content: []anthropic.ToolResultBlockParamContentUnion{ + {OfText: &anthropic.TextBlockParam{Text: "result"}}, + }, + }, + }, + ), + }, + expected: "toolu_abc", + }, + { + name: "multiple tool_result blocks returns last", + messages: []anthropic.MessageParam{ + anthropic.NewUserMessage( + anthropic.ContentBlockParamUnion{ + OfToolResult: &anthropic.ToolResultBlockParam{ + ToolUseID: "toolu_first", + Content: []anthropic.ToolResultBlockParamContentUnion{ + {OfText: &anthropic.TextBlockParam{Text: "first"}}, + }, + }, + }, + anthropic.NewTextBlock("some text"), + anthropic.ContentBlockParamUnion{ + OfToolResult: &anthropic.ToolResultBlockParam{ + ToolUseID: "toolu_second", + Content: []anthropic.ToolResultBlockParamContentUnion{ + {OfText: &anthropic.TextBlockParam{Text: "second"}}, + }, + }, + }, + ), + }, + expected: "toolu_second", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + base := &interceptionBase{ + req: &MessageNewParamsWrapper{ + MessageNewParams: anthropic.MessageNewParams{ + Messages: tc.messages, + }, + }, + } + + base.scanForCorrelatingToolCallID() + require.Equal(t, tc.expected, base.CorrelatingToolCallID()) + }) + } +} + func TestAWSBedrockValidation(t *testing.T) { t.Parallel() diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 8534128..2e254ac 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -88,20 +88,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req messages := i.req.MessageNewParams logger := i.logger.With(slog.F("model", i.req.Model)) - // Scan the request for tool results; we use these to correlate requests - // together. We iterate backward so we find the last (most recent) tool - // result, which correctly identifies the parent interception. - if len(messages.Messages) > 0 { - content := messages.Messages[len(messages.Messages)-1].Content - for idx := len(content) - 1; idx >= 0; idx-- { - block := content[idx] - if block.OfToolResult == nil { - continue - } - i.correlatingToolCallID = block.OfToolResult.ToolUseID - break - } - } + i.scanForCorrelatingToolCallID() var resp *anthropic.Message // Accumulate usage across the entire streaming interaction (including tool reinvocations). @@ -167,10 +154,12 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: resp.ID, + ToolCallID: toolUse.ID, Tool: toolUse.Name, Args: toolUse.Input, Injected: false, }) + } // If no injected tool calls, we're done. @@ -203,6 +192,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: resp.ID, + ToolCallID: tc.ID, ServerURL: &tool.ServerURL, Tool: tool.Name, Args: tc.Input, @@ -210,6 +200,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req InvocationError: err, }) + if err != nil { // Always provide a tool_result even if the tool call failed messages.Messages = append(messages.Messages, diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 3176ac8..e5b133b 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -148,22 +148,7 @@ newStream: stream := i.newStream(streamCtx, svc, messages) - // Scan the request for tool results; we use these to correlate - // requests together. We iterate backward so we find the last - // (most recent) tool result, which correctly identifies the - // parent interception. - if len(messages.Messages) > 0 { - content := messages.Messages[len(messages.Messages)-1].Content - for idx := len(content) - 1; idx >= 0; idx-- { - block := content[idx] - if block.OfToolResult == nil { - continue - } - - i.correlatingToolCallID = block.OfToolResult.ToolUseID - break - } - } + i.scanForCorrelatingToolCallID() var message anthropic.Message var lastToolName string @@ -315,6 +300,7 @@ newStream: _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: message.ID, + ToolCallID: id, ServerURL: &tool.ServerURL, Tool: tool.Name, Args: input, @@ -322,6 +308,7 @@ newStream: InvocationError: err, }) + if err != nil { // Always provide a tool_result even if the tool call failed messages.Messages = append(messages.Messages, @@ -428,10 +415,12 @@ newStream: _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: message.ID, + ToolCallID: variant.ID, Tool: variant.Name, Args: variant.Input, Injected: false, }) + } } } diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index b28bd6e..6b3db4b 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -14,6 +14,89 @@ import ( "github.com/stretchr/testify/require" ) +func TestScanForCorrelatingToolCallID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input []oairesponses.ResponseInputItemUnionParam + expected string + }{ + { + name: "no input items", + input: nil, + expected: "", + }, + { + name: "no function_call_output items", + input: []oairesponses.ResponseInputItemUnionParam{ + { + OfMessage: &oairesponses.EasyInputMessageParam{ + Role: "user", + }, + }, + }, + expected: "", + }, + { + name: "single function_call_output", + input: []oairesponses.ResponseInputItemUnionParam{ + { + OfMessage: &oairesponses.EasyInputMessageParam{ + Role: "user", + }, + }, + { + OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ + CallID: "call_abc", + }, + }, + }, + expected: "call_abc", + }, + { + name: "multiple function_call_outputs returns last", + input: []oairesponses.ResponseInputItemUnionParam{ + { + OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ + CallID: "call_first", + }, + }, + { + OfMessage: &oairesponses.EasyInputMessageParam{ + Role: "user", + }, + }, + { + OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ + CallID: "call_second", + }, + }, + }, + expected: "call_second", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + base := &responsesInterceptionBase{ + req: &ResponsesNewParamsWrapper{ + ResponseNewParams: oairesponses.ResponseNewParams{ + Input: oairesponses.ResponseNewParamsInputUnion{ + OfInputItemList: tc.input, + }, + }, + }, + } + + base.scanForCorrelatingToolCallID() + require.Equal(t, tc.expected, base.CorrelatingToolCallID()) + }) + } +} + func TestLastUserPrompt(t *testing.T) { t.Parallel() diff --git a/internal/testutil/mock_recorder.go b/internal/testutil/mock_recorder.go index d4c9c8e..ac39006 100644 --- a/internal/testutil/mock_recorder.go +++ b/internal/testutil/mock_recorder.go @@ -6,7 +6,6 @@ import ( "slices" "sync" "testing" - "time" "github.com/coder/aibridge/recorder" "github.com/stretchr/testify/require" @@ -21,7 +20,7 @@ type MockRecorder struct { tokenUsages []*recorder.TokenUsageRecord userPrompts []*recorder.PromptUsageRecord toolUsages []*recorder.ToolUsageRecord - interceptionsEnd map[string]time.Time + interceptionsEnd map[string]*recorder.InterceptionRecordEnded } func (m *MockRecorder) RecordInterception(ctx context.Context, req *recorder.InterceptionRecord) error { @@ -35,12 +34,12 @@ func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorde m.mu.Lock() defer m.mu.Unlock() if m.interceptionsEnd == nil { - m.interceptionsEnd = make(map[string]time.Time) + m.interceptionsEnd = make(map[string]*recorder.InterceptionRecordEnded) } if !slices.ContainsFunc(m.interceptions, func(intc *recorder.InterceptionRecord) bool { return intc.ID == req.ID }) { return fmt.Errorf("id not found") } - m.interceptionsEnd[req.ID] = req.EndedAt + m.interceptionsEnd[req.ID] = req return nil } @@ -107,6 +106,14 @@ func (m *MockRecorder) ToolUsages() []*recorder.ToolUsageRecord { return m.toolUsages } +// RecordedInterceptionEnd returns the stored InterceptionRecordEnded for the +// given interception ID, or nil if not found. +func (m *MockRecorder) RecordedInterceptionEnd(id string) *recorder.InterceptionRecordEnded { + m.mu.Lock() + defer m.mu.Unlock() + return m.interceptionsEnd[id] +} + // VerifyAllInterceptionsEnded verifies all recorded interceptions have been marked as completed. func (m *MockRecorder) VerifyAllInterceptionsEnded(t *testing.T) { t.Helper() From 9bacc4a20b2cd85137ce72d95c555e415647e1ef Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Mon, 23 Feb 2026 15:19:04 +0200 Subject: [PATCH 3/9] chore: make fmt Signed-off-by: Danny Kopping --- intercept/chatcompletions/blocking.go | 2 -- intercept/chatcompletions/streaming.go | 1 - intercept/messages/blocking.go | 1 - intercept/messages/streaming.go | 2 -- 4 files changed, 6 deletions(-) diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index d1753e7..db2a4f2 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -131,7 +131,6 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req Args: i.unmarshalArgs(toolCall.Function.Arguments), Injected: false, }) - } } } @@ -173,7 +172,6 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req InvocationError: err, }) - if err != nil { // Always provide a tool result even if the tool call failed errorResponse := map[string]interface{}{ diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index b561050..4170af8 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -296,7 +296,6 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re InvocationError: toolErr, }) - // Reset. toolCall = nil diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 2e254ac..ef8bb38 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -200,7 +200,6 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req InvocationError: err, }) - if err != nil { // Always provide a tool_result even if the tool call failed messages.Messages = append(messages.Messages, diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index e5b133b..2f174fa 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -308,7 +308,6 @@ newStream: InvocationError: err, }) - if err != nil { // Always provide a tool_result even if the tool call failed messages.Messages = append(messages.Messages, @@ -420,7 +419,6 @@ newStream: Args: variant.Input, Injected: false, }) - } } } From d654682d18a5122c1994a1414e44313c093c5ad0 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Mon, 23 Feb 2026 15:27:51 +0200 Subject: [PATCH 4/9] fix: move scanForCorrelatingToolCallID before newStream loop The scan must happen once on the original request messages, not on every loop iteration. Inside the loop, appended injected tool results would be picked up instead of the original correlating tool call. Co-Authored-By: Claude Opus 4.6 --- intercept/messages/streaming.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 2f174fa..a4c65e8 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -131,6 +131,8 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re messages := i.req.MessageNewParams + i.scanForCorrelatingToolCallID() + // Accumulate usage across the entire streaming interaction (including tool reinvocations). var cumulativeUsage anthropic.Usage @@ -148,8 +150,6 @@ newStream: stream := i.newStream(streamCtx, svc, messages) - i.scanForCorrelatingToolCallID() - var message anthropic.Message var lastToolName string From e33d68cf212f7d006cf776ab417b02b1b9c381c4 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Tue, 24 Feb 2026 13:50:17 +0200 Subject: [PATCH 5/9] chore: simplify Signed-off-by: Danny Kopping --- intercept/chatcompletions/base.go | 17 ++++--------- intercept/chatcompletions/base_test.go | 1 - intercept/chatcompletions/blocking.go | 2 -- intercept/chatcompletions/streaming.go | 2 -- intercept/messages/base.go | 19 +++++--------- intercept/messages/base_test.go | 1 - intercept/messages/blocking.go | 2 -- intercept/messages/streaming.go | 2 -- intercept/responses/base.go | 35 +++++++++++--------------- intercept/responses/base_test.go | 1 - intercept/responses/blocking.go | 1 - intercept/responses/streaming.go | 1 - 12 files changed, 26 insertions(+), 58 deletions(-) diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 21a0eeb..3e66c27 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -34,8 +34,6 @@ type interceptionBase struct { recorder recorder.Recorder mcpProxy mcp.ServerProxier - - correlatingToolCallID string } func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService { @@ -65,23 +63,18 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, i.mcpProxy = mcpProxy } +// CorrelatingToolCallID scans the request messages for tool result +// messages and returns the ToolCallID of the last one found, which +// correctly identifies the most recent parent interception. func (i *interceptionBase) CorrelatingToolCallID() string { - return i.correlatingToolCallID -} - -// scanForCorrelatingToolCallID scans the request messages for tool -// result messages and sets correlatingToolCallID to the ToolCallID -// of the last one found, which correctly identifies the most recent -// parent interception. -func (i *interceptionBase) scanForCorrelatingToolCallID() { for idx := len(i.req.Messages) - 1; idx >= 0; idx-- { msg := i.req.Messages[idx] if msg.OfTool == nil { continue } - i.correlatingToolCallID = msg.OfTool.ToolCallID - return + return msg.OfTool.ToolCallID } + return "" } func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { diff --git a/intercept/chatcompletions/base_test.go b/intercept/chatcompletions/base_test.go index bdc671d..d5c8a18 100644 --- a/intercept/chatcompletions/base_test.go +++ b/intercept/chatcompletions/base_test.go @@ -60,7 +60,6 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { }, } - base.scanForCorrelatingToolCallID() require.Equal(t, tc.expected, base.CorrelatingToolCallID()) }) } diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index db2a4f2..9a84d14 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -68,8 +68,6 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req i.injectTools() - i.scanForCorrelatingToolCallID() - prompt, err := i.req.lastUserPrompt() if err != nil { logger.Warn(ctx, "failed to retrieve last user prompt", slog.Error(err)) diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index 4170af8..ff3b78c 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -79,8 +79,6 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re i.injectTools() - i.scanForCorrelatingToolCallID() - // Allow us to interrupt watch via cancel. ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/intercept/messages/base.go b/intercept/messages/base.go index eeae649..50f89c3 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -45,8 +45,6 @@ type interceptionBase struct { recorder recorder.Recorder mcpProxy mcp.ServerProxier - - correlatingToolCallID string } func (i *interceptionBase) ID() uuid.UUID { @@ -59,17 +57,12 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, i.mcpProxy = mcpProxy } +// CorrelatingToolCallID scans the last message's content blocks for +// tool result blocks and returns the ToolUseID of the last one found, +// which correctly identifies the most recent parent interception. func (i *interceptionBase) CorrelatingToolCallID() string { - return i.correlatingToolCallID -} - -// scanForCorrelatingToolCallID scans the last message's content -// blocks for tool result blocks and sets correlatingToolCallID -// to the ToolUseID of the last one found, which correctly -// identifies the most recent parent interception. -func (i *interceptionBase) scanForCorrelatingToolCallID() { if len(i.req.Messages) == 0 { - return + return "" } content := i.req.Messages[len(i.req.Messages)-1].Content for idx := len(content) - 1; idx >= 0; idx-- { @@ -77,9 +70,9 @@ func (i *interceptionBase) scanForCorrelatingToolCallID() { if block.OfToolResult == nil { continue } - i.correlatingToolCallID = block.OfToolResult.ToolUseID - return + return block.OfToolResult.ToolUseID } + return "" } func (i *interceptionBase) Model() string { diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index 8a00871..efcb010 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -87,7 +87,6 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { }, } - base.scanForCorrelatingToolCallID() require.Equal(t, tc.expected, base.CorrelatingToolCallID()) }) } diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index ef8bb38..e22b97f 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -88,8 +88,6 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req messages := i.req.MessageNewParams logger := i.logger.With(slog.F("model", i.req.Model)) - i.scanForCorrelatingToolCallID() - var resp *anthropic.Message // Accumulate usage across the entire streaming interaction (including tool reinvocations). var cumulativeUsage anthropic.Usage diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index a4c65e8..4e87fd8 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -131,8 +131,6 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re messages := i.req.MessageNewParams - i.scanForCorrelatingToolCallID() - // Accumulate usage across the entire streaming interaction (including tool reinvocations). var cumulativeUsage anthropic.Usage diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 73686c4..685796e 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -37,17 +37,16 @@ const ( ) type responsesInterceptionBase struct { - id uuid.UUID - req *ResponsesNewParamsWrapper - reqPayload []byte - cfg config.OpenAI - model string - recorder recorder.Recorder - mcpProxy mcp.ServerProxier - logger slog.Logger - metrics metrics.Metrics - tracer trace.Tracer - correlatingToolCallID string + id uuid.UUID + req *ResponsesNewParamsWrapper + reqPayload []byte + cfg config.OpenAI + model string + recorder recorder.Recorder + mcpProxy mcp.ServerProxier + logger slog.Logger + metrics metrics.Metrics + tracer trace.Tracer } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { @@ -81,22 +80,18 @@ func (i *responsesInterceptionBase) Model() string { return i.model } +// CorrelatingToolCallID scans the request input for function call +// output items and returns the CallID of the last one found, which +// correctly identifies the most recent parent interception. func (i *responsesInterceptionBase) CorrelatingToolCallID() string { - return i.correlatingToolCallID -} - -// scanForCorrelatingToolCallID scans the request input for function call -// output items and sets correlatingToolCallID to the CallID of the last one -// found, which correctly identifies the most recent parent interception. -func (i *responsesInterceptionBase) scanForCorrelatingToolCallID() { for idx := len(i.req.Input.OfInputItemList) - 1; idx >= 0; idx-- { item := i.req.Input.OfInputItemList[idx] if item.OfFunctionCallOutput == nil { continue } - i.correlatingToolCallID = item.OfFunctionCallOutput.CallID - return + return item.OfFunctionCallOutput.CallID } + return "" } func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index 6b3db4b..094b055 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -91,7 +91,6 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { }, } - base.scanForCorrelatingToolCallID() require.Equal(t, tc.expected, base.CorrelatingToolCallID()) }) } diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 87f5ef5..3e94a6c 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -60,7 +60,6 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * i.injectTools() i.disableParallelToolCalls() - i.scanForCorrelatingToolCallID() var ( response *responses.Response diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 8f1b6a1..6925d86 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -71,7 +71,6 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r i.injectTools() i.disableParallelToolCalls() - i.scanForCorrelatingToolCallID() events := eventstream.NewEventStream(ctx, i.logger.Named("sse-sender"), nil) go events.Start(w, r) From 6160f1ef75d51aeffeb47e29ad5e10c2b266d384 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Tue, 24 Feb 2026 13:50:27 +0200 Subject: [PATCH 6/9] chore: add integration test assertions Signed-off-by: Danny Kopping --- bridge_integration_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 068474c..bf83264 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -64,16 +64,19 @@ func TestAnthropicMessages(t *testing.T) { streaming bool expectedInputTokens int expectedOutputTokens int + expectedToolCallID string }{ { streaming: true, expectedInputTokens: 2, expectedOutputTokens: 66, + expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo", }, { streaming: false, expectedInputTokens: 5, expectedOutputTokens: 84, + expectedToolCallID: "toolu_01AusGgY5aKFhzWrFBv9JfHq", }, } @@ -134,6 +137,7 @@ func TestAnthropicMessages(t *testing.T) { toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) assert.Equal(t, "Read", toolUsages[0].Tool) + assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) require.IsType(t, json.RawMessage{}, toolUsages[0].Args) var args map[string]any require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args)) @@ -278,16 +282,19 @@ func TestOpenAIChatCompletions(t *testing.T) { cases := []struct { streaming bool expectedInputTokens, expectedOutputTokens int + expectedToolCallID string }{ { streaming: true, expectedInputTokens: 60, expectedOutputTokens: 15, + expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B", }, { streaming: false, expectedInputTokens: 60, expectedOutputTokens: 15, + expectedToolCallID: "call_KjzAbhiZC6nk81tQzL7pwlpc", }, } @@ -347,6 +354,7 @@ func TestOpenAIChatCompletions(t *testing.T) { toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) assert.Equal(t, "read_file", toolUsages[0].Tool) + assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) require.IsType(t, map[string]any{}, toolUsages[0].Args) require.Contains(t, toolUsages[0].Args, "path") assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) From 5d5a42b5a91c91856f09d1f224ddebee5394c959 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Tue, 24 Feb 2026 13:59:42 +0200 Subject: [PATCH 7/9] refactor: return *string from CorrelatingToolCallID() --- intercept/chatcompletions/base.go | 6 +++--- intercept/chatcompletions/base_test.go | 11 ++++++----- intercept/interceptor.go | 2 +- intercept/messages/base.go | 8 ++++---- intercept/messages/base_test.go | 11 ++++++----- intercept/responses/base.go | 6 +++--- intercept/responses/base_test.go | 11 ++++++----- recorder/types.go | 2 +- 8 files changed, 30 insertions(+), 27 deletions(-) diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 3e66c27..c068e83 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -66,15 +66,15 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, // CorrelatingToolCallID scans the request messages for tool result // messages and returns the ToolCallID of the last one found, which // correctly identifies the most recent parent interception. -func (i *interceptionBase) CorrelatingToolCallID() string { +func (i *interceptionBase) CorrelatingToolCallID() *string { for idx := len(i.req.Messages) - 1; idx >= 0; idx-- { msg := i.req.Messages[idx] if msg.OfTool == nil { continue } - return msg.OfTool.ToolCallID + return &msg.OfTool.ToolCallID } - return "" + return nil } func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { diff --git a/intercept/chatcompletions/base_test.go b/intercept/chatcompletions/base_test.go index d5c8a18..2246a7e 100644 --- a/intercept/chatcompletions/base_test.go +++ b/intercept/chatcompletions/base_test.go @@ -3,6 +3,7 @@ package chatcompletions import ( "testing" + "github.com/coder/aibridge/utils" "github.com/openai/openai-go/v3" "github.com/stretchr/testify/require" ) @@ -13,12 +14,12 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { tests := []struct { name string messages []openai.ChatCompletionMessageParamUnion - expected string + expected *string }{ { name: "no messages", messages: nil, - expected: "", + expected: nil, }, { name: "no tool messages", @@ -26,7 +27,7 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { openai.UserMessage("hello"), openai.AssistantMessage("hi there"), }, - expected: "", + expected: nil, }, { name: "single tool message", @@ -34,7 +35,7 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { openai.UserMessage("hello"), openai.ToolMessage("result", "call_abc"), }, - expected: "call_abc", + expected: utils.PtrTo("call_abc"), }, { name: "multiple tool messages returns last", @@ -44,7 +45,7 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { openai.AssistantMessage("thinking"), openai.ToolMessage("second result", "call_second"), }, - expected: "call_second", + expected: utils.PtrTo("call_second"), }, } diff --git a/intercept/interceptor.go b/intercept/interceptor.go index c767022..cbd29d6 100644 --- a/intercept/interceptor.go +++ b/intercept/interceptor.go @@ -33,5 +33,5 @@ type Interceptor interface { // require that ALL tool results are submitted for tool choices returned // by the model, so any single tool call ID is sufficient to identify the // parent interception. - CorrelatingToolCallID() string + CorrelatingToolCallID() *string } diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 50f89c3..0668410 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -60,9 +60,9 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, // CorrelatingToolCallID scans the last message's content blocks for // tool result blocks and returns the ToolUseID of the last one found, // which correctly identifies the most recent parent interception. -func (i *interceptionBase) CorrelatingToolCallID() string { +func (i *interceptionBase) CorrelatingToolCallID() *string { if len(i.req.Messages) == 0 { - return "" + return nil } content := i.req.Messages[len(i.req.Messages)-1].Content for idx := len(content) - 1; idx >= 0; idx-- { @@ -70,9 +70,9 @@ func (i *interceptionBase) CorrelatingToolCallID() string { if block.OfToolResult == nil { continue } - return block.OfToolResult.ToolUseID + return &block.OfToolResult.ToolUseID } - return "" + return nil } func (i *interceptionBase) Model() string { diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index efcb010..a35b656 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -8,6 +8,7 @@ import ( "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge/config" "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/utils" mcpgo "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" ) @@ -18,19 +19,19 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { tests := []struct { name string messages []anthropic.MessageParam - expected string + expected *string }{ { name: "no messages", messages: nil, - expected: "", + expected: nil, }, { name: "last message has no tool_result blocks", messages: []anthropic.MessageParam{ anthropic.NewUserMessage(anthropic.NewTextBlock("hello")), }, - expected: "", + expected: nil, }, { name: "single tool_result block", @@ -46,7 +47,7 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { }, ), }, - expected: "toolu_abc", + expected: utils.PtrTo("toolu_abc"), }, { name: "multiple tool_result blocks returns last", @@ -71,7 +72,7 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { }, ), }, - expected: "toolu_second", + expected: utils.PtrTo("toolu_second"), }, } diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 685796e..ba06fa4 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -83,15 +83,15 @@ func (i *responsesInterceptionBase) Model() string { // CorrelatingToolCallID scans the request input for function call // output items and returns the CallID of the last one found, which // correctly identifies the most recent parent interception. -func (i *responsesInterceptionBase) CorrelatingToolCallID() string { +func (i *responsesInterceptionBase) CorrelatingToolCallID() *string { for idx := len(i.req.Input.OfInputItemList) - 1; idx >= 0; idx-- { item := i.req.Input.OfInputItemList[idx] if item.OfFunctionCallOutput == nil { continue } - return item.OfFunctionCallOutput.CallID + return &item.OfFunctionCallOutput.CallID } - return "" + return nil } func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index 094b055..5ad4bb8 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -9,6 +9,7 @@ import ( "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/recorder" + "github.com/coder/aibridge/utils" "github.com/google/uuid" oairesponses "github.com/openai/openai-go/v3/responses" "github.com/stretchr/testify/require" @@ -20,12 +21,12 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { tests := []struct { name string input []oairesponses.ResponseInputItemUnionParam - expected string + expected *string }{ { name: "no input items", input: nil, - expected: "", + expected: nil, }, { name: "no function_call_output items", @@ -36,7 +37,7 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { }, }, }, - expected: "", + expected: nil, }, { name: "single function_call_output", @@ -52,7 +53,7 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { }, }, }, - expected: "call_abc", + expected: utils.PtrTo("call_abc"), }, { name: "multiple function_call_outputs returns last", @@ -73,7 +74,7 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { }, }, }, - expected: "call_second", + expected: utils.PtrTo("call_second"), }, } diff --git a/recorder/types.go b/recorder/types.go index 5d3e462..d36dacf 100644 --- a/recorder/types.go +++ b/recorder/types.go @@ -39,7 +39,7 @@ type InterceptionRecord struct { type InterceptionRecordEnded struct { ID string EndedAt time.Time - CorrelatingToolCallID string + CorrelatingToolCallID *string } type TokenUsageRecord struct { From 799926009d9678e5a2ede9f858d2a6f9471c58f8 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Wed, 25 Feb 2026 14:33:36 +0200 Subject: [PATCH 8/9] chore: only consider final input message for tool results Signed-off-by: Danny Kopping --- intercept/chatcompletions/base.go | 19 +++++++++---------- intercept/chatcompletions/base_test.go | 9 +++++++++ intercept/messages/base.go | 3 --- intercept/messages/base_test.go | 16 ++++++++++++++++ intercept/responses/base.go | 19 +++++++++---------- intercept/responses/base_test.go | 16 ++++++++++++++++ 6 files changed, 59 insertions(+), 23 deletions(-) diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index c068e83..aed8f88 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -63,18 +63,17 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, i.mcpProxy = mcpProxy } -// CorrelatingToolCallID scans the request messages for tool result -// messages and returns the ToolCallID of the last one found, which -// correctly identifies the most recent parent interception. func (i *interceptionBase) CorrelatingToolCallID() *string { - for idx := len(i.req.Messages) - 1; idx >= 0; idx-- { - msg := i.req.Messages[idx] - if msg.OfTool == nil { - continue - } - return &msg.OfTool.ToolCallID + if len(i.req.Messages) == 0 { + return nil + } + + // The tool result should be the last input message. + msg := i.req.Messages[len(i.req.Messages)-1] + if msg.OfTool == nil { + return nil } - return nil + return &msg.OfTool.ToolCallID } func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { diff --git a/intercept/chatcompletions/base_test.go b/intercept/chatcompletions/base_test.go index 2246a7e..1647a2d 100644 --- a/intercept/chatcompletions/base_test.go +++ b/intercept/chatcompletions/base_test.go @@ -47,6 +47,15 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { }, expected: utils.PtrTo("call_second"), }, + { + name: "last message is not a tool message", + messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + openai.ToolMessage("first result", "call_first"), + openai.AssistantMessage("thinking"), + }, + expected: nil, + }, } for _, tc := range tests { diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 0668410..387591d 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -57,9 +57,6 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, i.mcpProxy = mcpProxy } -// CorrelatingToolCallID scans the last message's content blocks for -// tool result blocks and returns the ToolUseID of the last one found, -// which correctly identifies the most recent parent interception. func (i *interceptionBase) CorrelatingToolCallID() *string { if len(i.req.Messages) == 0 { return nil diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index a35b656..cca890e 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -74,6 +74,22 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { }, expected: utils.PtrTo("toolu_second"), }, + { + name: "last message is not a tool result", + messages: []anthropic.MessageParam{ + anthropic.NewUserMessage( + anthropic.ContentBlockParamUnion{ + OfToolResult: &anthropic.ToolResultBlockParam{ + ToolUseID: "toolu_first", + Content: []anthropic.ToolResultBlockParamContentUnion{ + {OfText: &anthropic.TextBlockParam{Text: "first"}}, + }, + }, + }), + anthropic.NewUserMessage(anthropic.NewTextBlock("some text")), + }, + expected: nil, + }, } for _, tc := range tests { diff --git a/intercept/responses/base.go b/intercept/responses/base.go index ba06fa4..b531a71 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -80,18 +80,17 @@ func (i *responsesInterceptionBase) Model() string { return i.model } -// CorrelatingToolCallID scans the request input for function call -// output items and returns the CallID of the last one found, which -// correctly identifies the most recent parent interception. func (i *responsesInterceptionBase) CorrelatingToolCallID() *string { - for idx := len(i.req.Input.OfInputItemList) - 1; idx >= 0; idx-- { - item := i.req.Input.OfInputItemList[idx] - if item.OfFunctionCallOutput == nil { - continue - } - return &item.OfFunctionCallOutput.CallID + if len(i.req.Input.OfInputItemList) == 0 { + return nil } - return nil + + // The tool result should be the last input message. + item := i.req.Input.OfInputItemList[len(i.req.Input.OfInputItemList)-1] + if item.OfFunctionCallOutput == nil { + return nil + } + return &item.OfFunctionCallOutput.CallID } func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index 5ad4bb8..ad0d59a 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -76,6 +76,22 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { }, expected: utils.PtrTo("call_second"), }, + { + name: "last input is not a tool result", + input: []oairesponses.ResponseInputItemUnionParam{ + { + OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{ + CallID: "call_first", + }, + }, + { + OfMessage: &oairesponses.EasyInputMessageParam{ + Role: "user", + }, + }, + }, + expected: nil, + }, } for _, tc := range tests { From 1f6e9da4201f1821d74da952fd857568ba0244af Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Wed, 25 Feb 2026 15:39:19 +0200 Subject: [PATCH 9/9] feat: move correlating tool call ID to RecordInterception Move CorrelatingToolCallID from InterceptionRecordEnded to InterceptionRecord so that the parent correlation is passed at interception start time, when it is already known from the incoming request's input messages. Co-Authored-By: Claude Opus 4.6 --- bridge.go | 17 +++++++++-------- recorder/types.go | 22 +++++++++++----------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/bridge.go b/bridge.go index a2b8613..7033b89 100644 --- a/bridge.go +++ b/bridge.go @@ -203,13 +203,14 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC interceptor.Setup(logger, asyncRecorder, mcpProxy) if err := rec.RecordInterception(ctx, &recorder.InterceptionRecord{ - Client: guessClient(r), - ID: interceptor.ID().String(), - InitiatorID: actor.ID, - Metadata: actor.Metadata, - Model: interceptor.Model(), - Provider: p.Name(), - UserAgent: r.UserAgent(), + Client: guessClient(r), + ID: interceptor.ID().String(), + InitiatorID: actor.ID, + Metadata: actor.Metadata, + Model: interceptor.Model(), + Provider: p.Name(), + UserAgent: r.UserAgent(), + CorrelatingToolCallID: interceptor.CorrelatingToolCallID(), }); err != nil { span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err)) logger.Warn(ctx, "failed to record interception", slog.Error(err)) @@ -250,7 +251,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC log.Debug(ctx, "interception ended") } - asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String(), CorrelatingToolCallID: interceptor.CorrelatingToolCallID()}) + asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()}) // Ensure all recording have completed before completing request. asyncRecorder.Wait() diff --git a/recorder/types.go b/recorder/types.go index d36dacf..82c34d0 100644 --- a/recorder/types.go +++ b/recorder/types.go @@ -26,20 +26,20 @@ type ToolArgs any type Metadata map[string]any type InterceptionRecord struct { - Client string - ID string - InitiatorID string - Metadata Metadata - Model string - Provider string - StartedAt time.Time - UserAgent string + Client string + ID string + InitiatorID string + Metadata Metadata + Model string + Provider string + StartedAt time.Time + UserAgent string + CorrelatingToolCallID *string } type InterceptionRecordEnded struct { - ID string - EndedAt time.Time - CorrelatingToolCallID *string + ID string + EndedAt time.Time } type TokenUsageRecord struct {