Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 7 additions & 47 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,6 @@ const (
recordingTimeout = time.Second * 5
)

const (
// Possible values for the "client" field in interception records.
// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
ClientClaude = "Claude Code"
ClientCodex = "Codex"
ClientCursor = "Cursor"
ClientCopilotVSC = "GitHub Copilot (VS Code)"
ClientCopilotCLI = "GitHub Copilot (CLI)"
ClientKilo = "Kilo Code"
ClientMux = "Mux"
ClientRoo = "Roo Code"
ClientZed = "Zed"
ClientUnknown = "Unknown"
)

// RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs;
// specifically, OpenAI's & Anthropic's at present.
// RequestBridge intercepts requests to - and responses from - these upstream services to provide
Expand Down Expand Up @@ -167,6 +152,11 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
ctx, span := tracer.Start(r.Context(), "Intercept")
defer span.End()

// We execute this before CreateInterceptor since the interceptors
// read the request body and don't reset them.
client := guessClient(r)
sessionID := guessSessionID(client, r)

interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer)
if err != nil {
span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err))
Expand Down Expand Up @@ -203,13 +193,14 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
interceptor.Setup(logger, asyncRecorder, mcpProxy)

if err := rec.RecordInterception(ctx, &recorder.InterceptionRecord{
Client: guessClient(r),
ID: interceptor.ID().String(),
InitiatorID: actor.ID,
Metadata: actor.Metadata,
Model: interceptor.Model(),
Provider: p.Name(),
UserAgent: r.UserAgent(),
Client: string(client),
ClientSessionID: sessionID,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should we maybe add this to an integration test? The unit tests for guessSessionID are great, but it's missing a test for how it integrates with interceptions. Maybe just an assertion that checks the ClientSessionID matches the expected value

CorrelatingToolCallID: interceptor.CorrelatingToolCallID(),
}); err != nil {
span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err))
Expand Down Expand Up @@ -338,34 +329,3 @@ func mergeContexts(base, other context.Context) context.Context {
}()
return ctx
}

// guessClient attempts to guess the client application from the request headers.
// Not all clients set proper user agent headers, so this is a best-effort approach.
// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101.
func guessClient(r *http.Request) string {
userAgent := strings.ToLower(r.UserAgent())
originator := r.Header.Get("originator")

// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
switch {
case strings.HasPrefix(userAgent, "mux/"):
return ClientMux
case strings.HasPrefix(userAgent, "claude"):
return ClientClaude
case strings.HasPrefix(userAgent, "codex"):
return ClientCodex
case strings.HasPrefix(userAgent, "zed/"):
return ClientZed
case strings.HasPrefix(userAgent, "githubcopilotchat/"):
return ClientCopilotVSC
case strings.HasPrefix(userAgent, "copilot/"):
return ClientCopilotCLI
case strings.HasPrefix(userAgent, "kilo-code/") || originator == "kilo-code":
return ClientKilo
case strings.HasPrefix(userAgent, "roo-code/") || originator == "roo-code":
return ClientRoo
case r.Header.Get("x-cursor-client-version") != "":
return ClientCursor
}
return ClientUnknown
}
6 changes: 3 additions & 3 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ func TestSimple(t *testing.T) {
createRequest func(*testing.T, string, []byte) *http.Request
expectedMsgID string
userAgent string
expectedClient string
expectedClient aibridge.Client
}{
{
name: config.ProviderAnthropic,
Expand All @@ -561,7 +561,7 @@ func TestSimple(t *testing.T) {
createRequest: createAnthropicMessagesReq,
expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn",
userAgent: "claude-cli/2.0.67 (external, cli)",
expectedClient: aibridge.ClientClaude,
expectedClient: aibridge.ClientClaudeCode,
},
{
name: config.ProviderOpenAI,
Expand Down Expand Up @@ -671,7 +671,7 @@ func TestSimple(t *testing.T) {
interceptions := recorderClient.RecordedInterceptions()
require.Len(t, interceptions, 1, "expected exactly one interception, got: %v", interceptions)
assert.Equal(t, tc.userAgent, interceptions[0].UserAgent)
assert.Equal(t, tc.expectedClient, interceptions[0].Client)
assert.Equal(t, string(tc.expectedClient), interceptions[0].Client)

recorderClient.VerifyAllInterceptionsEnded(t)
})
Expand Down
100 changes: 0 additions & 100 deletions bridge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,103 +104,3 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
})
}
}

func TestGuessClient(t *testing.T) {
t.Parallel()

tests := []struct {
name string
userAgent string
headers map[string]string
wantClient string
}{
{
name: "mux",
userAgent: "mux/0.19.0-next.2.gcceff159 ai-sdk/openai/3.0.36 ai-sdk/provider-utils/4.0.15 runtime/node.js/22",
wantClient: ClientMux,
},
{
name: "claude_code",
userAgent: "claude-cli/2.0.67 (external, cli)",
wantClient: ClientClaude,
},
{
name: "codex_cli",
userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64) ghostty/1.3.0-main_250877ef",
wantClient: ClientCodex,
},
{
name: "zed",
userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)",
wantClient: ClientZed,
},
{
name: "github_copilot_vsc",
userAgent: "GitHubCopilotChat/0.37.2026011603",
wantClient: ClientCopilotVSC,
},
{
name: "github_copilot_cli",
userAgent: "copilot/0.0.403 (client/cli linux v24.11.1)",
wantClient: ClientCopilotCLI,
},
{
name: "kilo_code_user_agent",
userAgent: "kilo-code/5.1.0 (darwin 25.2.0; arm64) node/22.21.1",
wantClient: ClientKilo,
},
{
name: "kilo_code_originator",
headers: map[string]string{"Originator": "kilo-code"},
wantClient: ClientKilo,
},
{
name: "roo_code_user_agent",
userAgent: "roo-code/3.45.0 (darwin 25.2.0; arm64) node/22.21.1",
wantClient: ClientRoo,
},
{
name: "roo_code_originator",
headers: map[string]string{"Originator": "roo-code"},
wantClient: ClientRoo,
},
{
name: "cursor_x_cursor_client_version",
userAgent: "connect-es/1.6.1",
headers: map[string]string{"X-Cursor-client-version": "0.50.0"},
wantClient: ClientCursor,
},
{
name: "cursor_x_cursor_some_other_header",
headers: map[string]string{"x-cursor-client-version": "abc123"},
wantClient: ClientCursor,
},
{
name: "unknown_client",
userAgent: "ccclaude-cli/calude-with-wrong-prefix",
wantClient: ClientUnknown,
},
{
name: "empty_user_agent",
userAgent: "",
wantClient: ClientUnknown,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

req, err := http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)

req.Header.Set("User-Agent", tt.userAgent)
for key, value := range tt.headers {
req.Header.Set(key, value)
}

got := guessClient(req)
require.Equal(t, tt.wantClient, got)
})
}
}
54 changes: 54 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package aibridge

import (
"net/http"
"strings"
)

type Client string

const (
// Possible values for the "client" field in interception records.
// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
ClientClaudeCode Client = "Claude Code"
ClientCodex Client = "Codex"
ClientZed Client = "Zed"
ClientCopilotVSC Client = "GitHub Copilot (VS Code)"
ClientCopilotCLI Client = "GitHub Copilot (CLI)"
ClientKilo Client = "Kilo Code"
ClientMux Client = "Mux"
ClientRoo Client = "Roo Code"
ClientCursor Client = "Cursor"
ClientUnknown Client = "Unknown"
)

// guessClient attempts to guess the client application from the request headers.
// Not all clients set proper user agent headers, so this is a best-effort approach.
// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101.
func guessClient(r *http.Request) Client {
userAgent := strings.ToLower(r.UserAgent())
originator := r.Header.Get("originator")

// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
switch {
case strings.HasPrefix(userAgent, "mux/"):
return ClientMux
case strings.HasPrefix(userAgent, "claude"):
return ClientClaudeCode
case strings.HasPrefix(userAgent, "codex"):
return ClientCodex
case strings.HasPrefix(userAgent, "zed/"):
return ClientZed
case strings.HasPrefix(userAgent, "githubcopilotchat/"):
return ClientCopilotVSC
case strings.HasPrefix(userAgent, "copilot/"):
return ClientCopilotCLI
case strings.HasPrefix(userAgent, "kilo-code/") || originator == "kilo-code":
return ClientKilo
case strings.HasPrefix(userAgent, "roo-code/") || originator == "roo-code":
return ClientRoo
case r.Header.Get("x-cursor-client-version") != "":
return ClientCursor
}
return ClientUnknown
}
108 changes: 108 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package aibridge

import (
"net/http"
"testing"

"github.com/stretchr/testify/require"
)

func TestGuessClient(t *testing.T) {
t.Parallel()

tests := []struct {
name string
userAgent string
headers map[string]string
wantClient Client
}{
{
name: "mux",
userAgent: "mux/0.19.0-next.2.gcceff159 ai-sdk/openai/3.0.36 ai-sdk/provider-utils/4.0.15 runtime/node.js/22",
wantClient: ClientMux,
},
{
name: "claude_code",
userAgent: "claude-cli/2.0.67 (external, cli)",
wantClient: ClientClaudeCode,
},
{
name: "codex_cli",
userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64) ghostty/1.3.0-main_250877ef",
wantClient: ClientCodex,
},
{
name: "zed",
userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)",
wantClient: ClientZed,
},
{
name: "github_copilot_vsc",
userAgent: "GitHubCopilotChat/0.37.2026011603",
wantClient: ClientCopilotVSC,
},
{
name: "github_copilot_cli",
userAgent: "copilot/0.0.403 (client/cli linux v24.11.1)",
wantClient: ClientCopilotCLI,
},
{
name: "kilo_code_user_agent",
userAgent: "kilo-code/5.1.0 (darwin 25.2.0; arm64) node/22.21.1",
wantClient: ClientKilo,
},
{
name: "kilo_code_originator",
headers: map[string]string{"Originator": "kilo-code"},
wantClient: ClientKilo,
},
{
name: "roo_code_user_agent",
userAgent: "roo-code/3.45.0 (darwin 25.2.0; arm64) node/22.21.1",
wantClient: ClientRoo,
},
{
name: "roo_code_originator",
headers: map[string]string{"Originator": "roo-code"},
wantClient: ClientRoo,
},
{
name: "cursor_x_cursor_client_version",
userAgent: "connect-es/1.6.1",
headers: map[string]string{"X-Cursor-client-version": "0.50.0"},
wantClient: ClientCursor,
},
{
name: "cursor_x_cursor_some_other_header",
headers: map[string]string{"x-cursor-client-version": "abc123"},
wantClient: ClientCursor,
},
{
name: "unknown_client",
userAgent: "ccclaude-cli/calude-with-wrong-prefix",
wantClient: ClientUnknown,
},
{
name: "empty_user_agent",
userAgent: "",
wantClient: ClientUnknown,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

req, err := http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)

req.Header.Set("User-Agent", tt.userAgent)
for key, value := range tt.headers {
req.Header.Set(key, value)
}

got := guessClient(req)
require.Equal(t, tt.wantClient, got)
})
}
}
2 changes: 1 addition & 1 deletion provider/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
return nil, fmt.Errorf("read body: %w", err)
}
var req responses.ResponsesNewParamsWrapper
if err := json.Unmarshal(payload, &req); err != nil {
if err := json.Unmarshal(payload, &req); err != nil { // TODO: should probably change to json.NewDecoder.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: raise issue for follow-up; needs a scaletest done afterwards.

return nil, fmt.Errorf("unmarshal request body: %w", err)
}
if req.Stream {
Expand Down
Loading