Skip to content
15 changes: 8 additions & 7 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,14 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
interceptor.Setup(logger, asyncRecorder, mcpProxy)

if err := rec.RecordInterception(ctx, &recorder.InterceptionRecord{
Client: guessClient(r),
ID: interceptor.ID().String(),
InitiatorID: actor.ID,
Metadata: actor.Metadata,
Model: interceptor.Model(),
Provider: p.Name(),
UserAgent: r.UserAgent(),
Client: guessClient(r),
ID: interceptor.ID().String(),
InitiatorID: actor.ID,
Metadata: actor.Metadata,
Model: interceptor.Model(),
Provider: p.Name(),
UserAgent: r.UserAgent(),
CorrelatingToolCallID: interceptor.CorrelatingToolCallID(),
}); err != nil {
span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err))
logger.Warn(ctx, "failed to record interception", slog.Error(err))
Expand Down
8 changes: 8 additions & 0 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,19 @@ func TestAnthropicMessages(t *testing.T) {
streaming bool
expectedInputTokens int
expectedOutputTokens int
expectedToolCallID string
}{
{
streaming: true,
expectedInputTokens: 2,
expectedOutputTokens: 66,
expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo",
},
{
streaming: false,
expectedInputTokens: 5,
expectedOutputTokens: 84,
expectedToolCallID: "toolu_01AusGgY5aKFhzWrFBv9JfHq",
},
}

Expand Down Expand Up @@ -134,6 +137,7 @@ func TestAnthropicMessages(t *testing.T) {
toolUsages := recorderClient.RecordedToolUsages()
require.Len(t, toolUsages, 1)
assert.Equal(t, "Read", toolUsages[0].Tool)
assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID)
require.IsType(t, json.RawMessage{}, toolUsages[0].Args)
var args map[string]any
require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args))
Expand Down Expand Up @@ -278,16 +282,19 @@ func TestOpenAIChatCompletions(t *testing.T) {
cases := []struct {
streaming bool
expectedInputTokens, expectedOutputTokens int
expectedToolCallID string
}{
{
streaming: true,
expectedInputTokens: 60,
expectedOutputTokens: 15,
expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B",
},
{
streaming: false,
expectedInputTokens: 60,
expectedOutputTokens: 15,
expectedToolCallID: "call_KjzAbhiZC6nk81tQzL7pwlpc",
},
}

Expand Down Expand Up @@ -347,6 +354,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
toolUsages := recorderClient.RecordedToolUsages()
require.Len(t, toolUsages, 1)
assert.Equal(t, "read_file", toolUsages[0].Tool)
assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID)
require.IsType(t, map[string]any{}, toolUsages[0].Args)
require.Contains(t, toolUsages[0].Args, "path")
assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"])
Expand Down
13 changes: 13 additions & 0 deletions intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder,
i.mcpProxy = mcpProxy
}

func (i *interceptionBase) CorrelatingToolCallID() *string {
if len(i.req.Messages) == 0 {
return nil
}

// The tool result should be the last input message.
msg := i.req.Messages[len(i.req.Messages)-1]
if msg.OfTool == nil {
return nil
}
return &msg.OfTool.ToolCallID
}

func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
return []attribute.KeyValue{
attribute.String(tracing.RequestPath, r.URL.Path),
Expand Down
76 changes: 76 additions & 0 deletions intercept/chatcompletions/base_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package chatcompletions

import (
"testing"

"github.com/coder/aibridge/utils"
"github.com/openai/openai-go/v3"
"github.com/stretchr/testify/require"
)

func TestScanForCorrelatingToolCallID(t *testing.T) {
t.Parallel()

tests := []struct {
name string
messages []openai.ChatCompletionMessageParamUnion
expected *string
}{
{
name: "no messages",
messages: nil,
expected: nil,
},
{
name: "no tool messages",
messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("hello"),
openai.AssistantMessage("hi there"),
},
expected: nil,
},
{
name: "single tool message",
messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("hello"),
openai.ToolMessage("result", "call_abc"),
},
expected: utils.PtrTo("call_abc"),
},
{
name: "multiple tool messages returns last",
messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("hello"),
openai.ToolMessage("first result", "call_first"),
openai.AssistantMessage("thinking"),
openai.ToolMessage("second result", "call_second"),
},
expected: utils.PtrTo("call_second"),
},
{
name: "last message is not a tool message",
messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("hello"),
openai.ToolMessage("first result", "call_first"),
openai.AssistantMessage("thinking"),
},
expected: nil,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

base := &interceptionBase{
req: &ChatCompletionNewParamsWrapper{
ChatCompletionNewParams: openai.ChatCompletionNewParams{
Messages: tc.messages,
},
},
}

require.Equal(t, tc.expected, base.CorrelatingToolCallID())
})
}
}
2 changes: 2 additions & 0 deletions intercept/chatcompletions/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
_ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: completion.ID,
ToolCallID: toolCall.ID,
Tool: toolCall.Function.Name,
Args: i.unmarshalArgs(toolCall.Function.Arguments),
Injected: false,
Expand Down Expand Up @@ -161,6 +162,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
_ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: completion.ID,
ToolCallID: tc.ID,
ServerURL: &tool.ServerURL,
Tool: tool.Name,
Args: args,
Expand Down
3 changes: 3 additions & 0 deletions intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,12 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: processor.getMsgID(),
ToolCallID: toolCall.ID,
Tool: toolCall.Name,
Args: i.unmarshalArgs(toolCall.Arguments),
Injected: false,
})

toolCall = nil
} else {
// When the provider responds with only tool calls (no text content),
Expand Down Expand Up @@ -284,6 +286,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: processor.getMsgID(),
ToolCallID: id,
ServerURL: &tool.ServerURL,
Tool: tool.Name,
Args: args,
Expand Down
9 changes: 9 additions & 0 deletions intercept/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,13 @@ type Interceptor interface {
Streaming() bool
// TraceAttributes returns tracing attributes for this [Interceptor]
TraceAttributes(*http.Request) []attribute.KeyValue
// CorrelatingToolCallID returns the ID of a tool call result submitted
// in the request, if present. This is used to correlate the current
// interception back to the previous interception that issued those tool
// calls. If multiple tool use results are present, we use the last one
// (most recent). Both Anthropic's /v1/messages and OpenAI's /v1/responses
// require that ALL tool results are submitted for tool choices returned
// by the model, so any single tool call ID is sufficient to identify the
// parent interception.
CorrelatingToolCallID() *string
}
15 changes: 15 additions & 0 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,21 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder,
i.mcpProxy = mcpProxy
}

func (i *interceptionBase) CorrelatingToolCallID() *string {
if len(i.req.Messages) == 0 {
return nil
}
content := i.req.Messages[len(i.req.Messages)-1].Content
for idx := len(content) - 1; idx >= 0; idx-- {
block := content[idx]
if block.OfToolResult == nil {
continue
}
return &block.OfToolResult.ToolUseID
}
return nil
}

func (i *interceptionBase) Model() string {
if i.req == nil {
return "coder-aibridge-unknown"
Expand Down
97 changes: 97 additions & 0 deletions intercept/messages/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,107 @@ import (
"github.com/anthropics/anthropic-sdk-go/shared/constant"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/utils"
mcpgo "github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/require"
)

func TestScanForCorrelatingToolCallID(t *testing.T) {
t.Parallel()

tests := []struct {
name string
messages []anthropic.MessageParam
expected *string
}{
{
name: "no messages",
messages: nil,
expected: nil,
},
{
name: "last message has no tool_result blocks",
messages: []anthropic.MessageParam{
anthropic.NewUserMessage(anthropic.NewTextBlock("hello")),
},
expected: nil,
},
{
name: "single tool_result block",
messages: []anthropic.MessageParam{
anthropic.NewUserMessage(
anthropic.ContentBlockParamUnion{
OfToolResult: &anthropic.ToolResultBlockParam{
ToolUseID: "toolu_abc",
Content: []anthropic.ToolResultBlockParamContentUnion{
{OfText: &anthropic.TextBlockParam{Text: "result"}},
},
},
},
),
},
expected: utils.PtrTo("toolu_abc"),
},
{
name: "multiple tool_result blocks returns last",
messages: []anthropic.MessageParam{
anthropic.NewUserMessage(
anthropic.ContentBlockParamUnion{
OfToolResult: &anthropic.ToolResultBlockParam{
ToolUseID: "toolu_first",
Content: []anthropic.ToolResultBlockParamContentUnion{
{OfText: &anthropic.TextBlockParam{Text: "first"}},
},
},
},
anthropic.NewTextBlock("some text"),
anthropic.ContentBlockParamUnion{
OfToolResult: &anthropic.ToolResultBlockParam{
ToolUseID: "toolu_second",
Content: []anthropic.ToolResultBlockParamContentUnion{
{OfText: &anthropic.TextBlockParam{Text: "second"}},
},
},
},
),
},
expected: utils.PtrTo("toolu_second"),
},
{
name: "last message is not a tool result",
messages: []anthropic.MessageParam{
anthropic.NewUserMessage(
anthropic.ContentBlockParamUnion{
OfToolResult: &anthropic.ToolResultBlockParam{
ToolUseID: "toolu_first",
Content: []anthropic.ToolResultBlockParamContentUnion{
{OfText: &anthropic.TextBlockParam{Text: "first"}},
},
},
}),
anthropic.NewUserMessage(anthropic.NewTextBlock("some text")),
},
expected: nil,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

base := &interceptionBase{
req: &MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Messages: tc.messages,
},
},
}

require.Equal(t, tc.expected, base.CorrelatingToolCallID())
})
}
}

func TestAWSBedrockValidation(t *testing.T) {
t.Parallel()

Expand Down
3 changes: 3 additions & 0 deletions intercept/messages/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,12 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
_ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: resp.ID,
ToolCallID: toolUse.ID,
Tool: toolUse.Name,
Args: toolUse.Input,
Injected: false,
})

}

// If no injected tool calls, we're done.
Expand Down Expand Up @@ -188,6 +190,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
_ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: resp.ID,
ToolCallID: tc.ID,
ServerURL: &tool.ServerURL,
Tool: tool.Name,
Args: tc.Input,
Expand Down
2 changes: 2 additions & 0 deletions intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ newStream:
_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: message.ID,
ToolCallID: id,
ServerURL: &tool.ServerURL,
Tool: tool.Name,
Args: input,
Expand Down Expand Up @@ -411,6 +412,7 @@ newStream:
_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: message.ID,
ToolCallID: variant.ID,
Tool: variant.Name,
Args: variant.Input,
Injected: false,
Expand Down
Loading