Skip to content

Commit 498272d

Browse files
authored
feat: record and correlate tool call IDs across interceptions (#188)
Track provider-supplied tool call IDs with each tool use response. Add `CorrelatingToolCallID` to the `Interceptor` interface for interception lineage tracking. Each interceptor inspects the last message of the request for the most recent tool call result (for Anthropic, scanning that message's content blocks backward), identifying the parent interception that triggered the current one; a request whose last message carries no tool result is not correlated. Adapted from the [aibridge `prompt_provenance_poc`](main...prompt_provenance_poc) branch. Downstream: [coder/coder#22246](coder/coder#22246). Closes #165
1 parent cc60987 commit 498272d

17 files changed

Lines changed: 404 additions & 38 deletions

bridge.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,14 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
203203
interceptor.Setup(logger, asyncRecorder, mcpProxy)
204204

205205
if err := rec.RecordInterception(ctx, &recorder.InterceptionRecord{
206-
Client: guessClient(r),
207-
ID: interceptor.ID().String(),
208-
InitiatorID: actor.ID,
209-
Metadata: actor.Metadata,
210-
Model: interceptor.Model(),
211-
Provider: p.Name(),
212-
UserAgent: r.UserAgent(),
206+
Client: guessClient(r),
207+
ID: interceptor.ID().String(),
208+
InitiatorID: actor.ID,
209+
Metadata: actor.Metadata,
210+
Model: interceptor.Model(),
211+
Provider: p.Name(),
212+
UserAgent: r.UserAgent(),
213+
CorrelatingToolCallID: interceptor.CorrelatingToolCallID(),
213214
}); err != nil {
214215
span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err))
215216
logger.Warn(ctx, "failed to record interception", slog.Error(err))

bridge_integration_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,19 @@ func TestAnthropicMessages(t *testing.T) {
6464
streaming bool
6565
expectedInputTokens int
6666
expectedOutputTokens int
67+
expectedToolCallID string
6768
}{
6869
{
6970
streaming: true,
7071
expectedInputTokens: 2,
7172
expectedOutputTokens: 66,
73+
expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo",
7274
},
7375
{
7476
streaming: false,
7577
expectedInputTokens: 5,
7678
expectedOutputTokens: 84,
79+
expectedToolCallID: "toolu_01AusGgY5aKFhzWrFBv9JfHq",
7780
},
7881
}
7982

@@ -134,6 +137,7 @@ func TestAnthropicMessages(t *testing.T) {
134137
toolUsages := recorderClient.RecordedToolUsages()
135138
require.Len(t, toolUsages, 1)
136139
assert.Equal(t, "Read", toolUsages[0].Tool)
140+
assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID)
137141
require.IsType(t, json.RawMessage{}, toolUsages[0].Args)
138142
var args map[string]any
139143
require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args))
@@ -278,16 +282,19 @@ func TestOpenAIChatCompletions(t *testing.T) {
278282
cases := []struct {
279283
streaming bool
280284
expectedInputTokens, expectedOutputTokens int
285+
expectedToolCallID string
281286
}{
282287
{
283288
streaming: true,
284289
expectedInputTokens: 60,
285290
expectedOutputTokens: 15,
291+
expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B",
286292
},
287293
{
288294
streaming: false,
289295
expectedInputTokens: 60,
290296
expectedOutputTokens: 15,
297+
expectedToolCallID: "call_KjzAbhiZC6nk81tQzL7pwlpc",
291298
},
292299
}
293300

@@ -347,6 +354,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
347354
toolUsages := recorderClient.RecordedToolUsages()
348355
require.Len(t, toolUsages, 1)
349356
assert.Equal(t, "read_file", toolUsages[0].Tool)
357+
assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID)
350358
require.IsType(t, map[string]any{}, toolUsages[0].Args)
351359
require.Contains(t, toolUsages[0].Args, "path")
352360
assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"])

intercept/chatcompletions/base.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,19 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder,
6363
i.mcpProxy = mcpProxy
6464
}
6565

66+
func (i *interceptionBase) CorrelatingToolCallID() *string {
67+
if len(i.req.Messages) == 0 {
68+
return nil
69+
}
70+
71+
// The tool result should be the last input message.
72+
msg := i.req.Messages[len(i.req.Messages)-1]
73+
if msg.OfTool == nil {
74+
return nil
75+
}
76+
return &msg.OfTool.ToolCallID
77+
}
78+
6679
func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
6780
return []attribute.KeyValue{
6881
attribute.String(tracing.RequestPath, r.URL.Path),
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package chatcompletions
2+
3+
import (
4+
"testing"
5+
6+
"github.com/coder/aibridge/utils"
7+
"github.com/openai/openai-go/v3"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestScanForCorrelatingToolCallID(t *testing.T) {
12+
t.Parallel()
13+
14+
tests := []struct {
15+
name string
16+
messages []openai.ChatCompletionMessageParamUnion
17+
expected *string
18+
}{
19+
{
20+
name: "no messages",
21+
messages: nil,
22+
expected: nil,
23+
},
24+
{
25+
name: "no tool messages",
26+
messages: []openai.ChatCompletionMessageParamUnion{
27+
openai.UserMessage("hello"),
28+
openai.AssistantMessage("hi there"),
29+
},
30+
expected: nil,
31+
},
32+
{
33+
name: "single tool message",
34+
messages: []openai.ChatCompletionMessageParamUnion{
35+
openai.UserMessage("hello"),
36+
openai.ToolMessage("result", "call_abc"),
37+
},
38+
expected: utils.PtrTo("call_abc"),
39+
},
40+
{
41+
name: "multiple tool messages returns last",
42+
messages: []openai.ChatCompletionMessageParamUnion{
43+
openai.UserMessage("hello"),
44+
openai.ToolMessage("first result", "call_first"),
45+
openai.AssistantMessage("thinking"),
46+
openai.ToolMessage("second result", "call_second"),
47+
},
48+
expected: utils.PtrTo("call_second"),
49+
},
50+
{
51+
name: "last message is not a tool message",
52+
messages: []openai.ChatCompletionMessageParamUnion{
53+
openai.UserMessage("hello"),
54+
openai.ToolMessage("first result", "call_first"),
55+
openai.AssistantMessage("thinking"),
56+
},
57+
expected: nil,
58+
},
59+
}
60+
61+
for _, tc := range tests {
62+
t.Run(tc.name, func(t *testing.T) {
63+
t.Parallel()
64+
65+
base := &interceptionBase{
66+
req: &ChatCompletionNewParamsWrapper{
67+
ChatCompletionNewParams: openai.ChatCompletionNewParams{
68+
Messages: tc.messages,
69+
},
70+
},
71+
}
72+
73+
require.Equal(t, tc.expected, base.CorrelatingToolCallID())
74+
})
75+
}
76+
}

intercept/chatcompletions/blocking.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
124124
_ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{
125125
InterceptionID: i.ID().String(),
126126
MsgID: completion.ID,
127+
ToolCallID: toolCall.ID,
127128
Tool: toolCall.Function.Name,
128129
Args: i.unmarshalArgs(toolCall.Function.Arguments),
129130
Injected: false,
@@ -161,6 +162,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
161162
_ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{
162163
InterceptionID: i.ID().String(),
163164
MsgID: completion.ID,
165+
ToolCallID: tc.ID,
164166
ServerURL: &tool.ServerURL,
165167
Tool: tool.Name,
166168
Args: args,

intercept/chatcompletions/streaming.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,12 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
173173
_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
174174
InterceptionID: i.ID().String(),
175175
MsgID: processor.getMsgID(),
176+
ToolCallID: toolCall.ID,
176177
Tool: toolCall.Name,
177178
Args: i.unmarshalArgs(toolCall.Arguments),
178179
Injected: false,
179180
})
181+
180182
toolCall = nil
181183
} else {
182184
// When the provider responds with only tool calls (no text content),
@@ -284,6 +286,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
284286
_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
285287
InterceptionID: i.ID().String(),
286288
MsgID: processor.getMsgID(),
289+
ToolCallID: id,
287290
ServerURL: &tool.ServerURL,
288291
Tool: tool.Name,
289292
Args: args,

intercept/interceptor.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,13 @@ type Interceptor interface {
2525
Streaming() bool
2626
// TraceAttributes returns tracing attributes for this [Interceptor]
2727
TraceAttributes(*http.Request) []attribute.KeyValue
28+
// CorrelatingToolCallID returns the ID of a tool call result submitted
29+
// in the request, if present. This is used to correlate the current
30+
// interception back to the previous interception that issued those tool
31+
// calls. If multiple tool use results are present, we use the last one
32+
// (most recent). Both Anthropic's /v1/messages and OpenAI's /v1/responses
33+
// require that ALL tool results are submitted for tool choices returned
34+
// by the model, so any single tool call ID is sufficient to identify the
35+
// parent interception.
36+
CorrelatingToolCallID() *string
2837
}

intercept/messages/base.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,21 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder,
5757
i.mcpProxy = mcpProxy
5858
}
5959

60+
func (i *interceptionBase) CorrelatingToolCallID() *string {
61+
if len(i.req.Messages) == 0 {
62+
return nil
63+
}
64+
content := i.req.Messages[len(i.req.Messages)-1].Content
65+
for idx := len(content) - 1; idx >= 0; idx-- {
66+
block := content[idx]
67+
if block.OfToolResult == nil {
68+
continue
69+
}
70+
return &block.OfToolResult.ToolUseID
71+
}
72+
return nil
73+
}
74+
6075
func (i *interceptionBase) Model() string {
6176
if i.req == nil {
6277
return "coder-aibridge-unknown"

intercept/messages/base_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,107 @@ import (
88
"github.com/anthropics/anthropic-sdk-go/shared/constant"
99
"github.com/coder/aibridge/config"
1010
"github.com/coder/aibridge/mcp"
11+
"github.com/coder/aibridge/utils"
1112
mcpgo "github.com/mark3labs/mcp-go/mcp"
1213
"github.com/stretchr/testify/require"
1314
)
1415

16+
func TestScanForCorrelatingToolCallID(t *testing.T) {
17+
t.Parallel()
18+
19+
tests := []struct {
20+
name string
21+
messages []anthropic.MessageParam
22+
expected *string
23+
}{
24+
{
25+
name: "no messages",
26+
messages: nil,
27+
expected: nil,
28+
},
29+
{
30+
name: "last message has no tool_result blocks",
31+
messages: []anthropic.MessageParam{
32+
anthropic.NewUserMessage(anthropic.NewTextBlock("hello")),
33+
},
34+
expected: nil,
35+
},
36+
{
37+
name: "single tool_result block",
38+
messages: []anthropic.MessageParam{
39+
anthropic.NewUserMessage(
40+
anthropic.ContentBlockParamUnion{
41+
OfToolResult: &anthropic.ToolResultBlockParam{
42+
ToolUseID: "toolu_abc",
43+
Content: []anthropic.ToolResultBlockParamContentUnion{
44+
{OfText: &anthropic.TextBlockParam{Text: "result"}},
45+
},
46+
},
47+
},
48+
),
49+
},
50+
expected: utils.PtrTo("toolu_abc"),
51+
},
52+
{
53+
name: "multiple tool_result blocks returns last",
54+
messages: []anthropic.MessageParam{
55+
anthropic.NewUserMessage(
56+
anthropic.ContentBlockParamUnion{
57+
OfToolResult: &anthropic.ToolResultBlockParam{
58+
ToolUseID: "toolu_first",
59+
Content: []anthropic.ToolResultBlockParamContentUnion{
60+
{OfText: &anthropic.TextBlockParam{Text: "first"}},
61+
},
62+
},
63+
},
64+
anthropic.NewTextBlock("some text"),
65+
anthropic.ContentBlockParamUnion{
66+
OfToolResult: &anthropic.ToolResultBlockParam{
67+
ToolUseID: "toolu_second",
68+
Content: []anthropic.ToolResultBlockParamContentUnion{
69+
{OfText: &anthropic.TextBlockParam{Text: "second"}},
70+
},
71+
},
72+
},
73+
),
74+
},
75+
expected: utils.PtrTo("toolu_second"),
76+
},
77+
{
78+
name: "last message is not a tool result",
79+
messages: []anthropic.MessageParam{
80+
anthropic.NewUserMessage(
81+
anthropic.ContentBlockParamUnion{
82+
OfToolResult: &anthropic.ToolResultBlockParam{
83+
ToolUseID: "toolu_first",
84+
Content: []anthropic.ToolResultBlockParamContentUnion{
85+
{OfText: &anthropic.TextBlockParam{Text: "first"}},
86+
},
87+
},
88+
}),
89+
anthropic.NewUserMessage(anthropic.NewTextBlock("some text")),
90+
},
91+
expected: nil,
92+
},
93+
}
94+
95+
for _, tc := range tests {
96+
t.Run(tc.name, func(t *testing.T) {
97+
t.Parallel()
98+
99+
base := &interceptionBase{
100+
req: &MessageNewParamsWrapper{
101+
MessageNewParams: anthropic.MessageNewParams{
102+
Messages: tc.messages,
103+
},
104+
},
105+
}
106+
107+
require.Equal(t, tc.expected, base.CorrelatingToolCallID())
108+
})
109+
}
110+
}
111+
15112
func TestAWSBedrockValidation(t *testing.T) {
16113
t.Parallel()
17114

intercept/messages/blocking.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,12 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
152152
_ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{
153153
InterceptionID: i.ID().String(),
154154
MsgID: resp.ID,
155+
ToolCallID: toolUse.ID,
155156
Tool: toolUse.Name,
156157
Args: toolUse.Input,
157158
Injected: false,
158159
})
160+
159161
}
160162

161163
// If no injected tool calls, we're done.
@@ -188,6 +190,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
188190
_ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{
189191
InterceptionID: i.ID().String(),
190192
MsgID: resp.ID,
193+
ToolCallID: tc.ID,
191194
ServerURL: &tool.ServerURL,
192195
Tool: tool.Name,
193196
Args: tc.Input,

0 commit comments

Comments
 (0)