diff --git a/bridge.go b/bridge.go index 463eeb7..355a61e 100644 --- a/bridge.go +++ b/bridge.go @@ -29,21 +29,6 @@ const ( recordingTimeout = time.Second * 5 ) -const ( - // Possible values for the "client" field in interception records. - // Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44 - ClientClaude = "Claude Code" - ClientCodex = "Codex" - ClientCursor = "Cursor" - ClientCopilotVSC = "GitHub Copilot (VS Code)" - ClientCopilotCLI = "GitHub Copilot (CLI)" - ClientKilo = "Kilo Code" - ClientMux = "Mux" - ClientRoo = "Roo Code" - ClientZed = "Zed" - ClientUnknown = "Unknown" -) - // RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs; // specifically, OpenAI's & Anthropic's at present. // RequestBridge intercepts requests to - and responses from - these upstream services to provide @@ -167,6 +152,11 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() + // We execute this before CreateInterceptor since the interceptors + // read the request body and don't reset them. + client := guessClient(r) + sessionID := guessSessionID(client, r) + interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer) if err != nil { span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err)) @@ -203,13 +193,14 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC interceptor.Setup(logger, asyncRecorder, mcpProxy) if err := rec.RecordInterception(ctx, &recorder.InterceptionRecord{ - Client: guessClient(r), ID: interceptor.ID().String(), InitiatorID: actor.ID, Metadata: actor.Metadata, Model: interceptor.Model(), Provider: p.Name(), UserAgent: r.UserAgent(), + Client: string(client), + ClientSessionID: sessionID, CorrelatingToolCallID: interceptor.CorrelatingToolCallID(), }); err != nil { span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err)) @@ -338,34 +329,3 @@ func mergeContexts(base, other context.Context) context.Context { }() return ctx } - -// guessClient attempts to guess the client application from the request headers. -// Not all clients set proper user agent headers, so this is a best-effort approach. -// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101. -func guessClient(r *http.Request) string { - userAgent := strings.ToLower(r.UserAgent()) - originator := r.Header.Get("originator") - - // Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44 - switch { - case strings.HasPrefix(userAgent, "mux/"): - return ClientMux - case strings.HasPrefix(userAgent, "claude"): - return ClientClaude - case strings.HasPrefix(userAgent, "codex"): - return ClientCodex - case strings.HasPrefix(userAgent, "zed/"): - return ClientZed - case strings.HasPrefix(userAgent, "githubcopilotchat/"): - return ClientCopilotVSC - case strings.HasPrefix(userAgent, "copilot/"): - return ClientCopilotCLI - case strings.HasPrefix(userAgent, "kilo-code/") || originator == "kilo-code": - return ClientKilo - case strings.HasPrefix(userAgent, "roo-code/") || originator == "roo-code": - return ClientRoo - case r.Header.Get("x-cursor-client-version") != "": - return ClientCursor - } - return ClientUnknown -} diff --git a/bridge_integration_test.go b/bridge_integration_test.go index bf83264..60af5b1 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -549,7 +549,7 @@ func TestSimple(t *testing.T) { createRequest func(*testing.T, string, []byte) *http.Request expectedMsgID string userAgent string - expectedClient string + expectedClient aibridge.Client }{ { name: config.ProviderAnthropic, @@ -561,7 +561,7 @@ func TestSimple(t *testing.T) { createRequest: createAnthropicMessagesReq, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", userAgent: "claude-cli/2.0.67 (external, cli)", - expectedClient: aibridge.ClientClaude, + expectedClient: aibridge.ClientClaudeCode, }, { name: config.ProviderOpenAI, @@ -671,7 +671,7 @@ func TestSimple(t *testing.T) { interceptions := recorderClient.RecordedInterceptions() require.Len(t, interceptions, 1, "expected exactly one interception, got: %v", interceptions) assert.Equal(t, tc.userAgent, interceptions[0].UserAgent) - assert.Equal(t, tc.expectedClient, interceptions[0].Client) + assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) recorderClient.VerifyAllInterceptionsEnded(t) }) diff --git a/bridge_test.go b/bridge_test.go index 2e58c0c..1709be1 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -104,103 +104,3 @@ func TestPassthroughRoutesForProviders(t *testing.T) { }) } } - -func TestGuessClient(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - userAgent string - headers map[string]string - wantClient string - }{ - { - name: "mux", - userAgent: "mux/0.19.0-next.2.gcceff159 ai-sdk/openai/3.0.36 ai-sdk/provider-utils/4.0.15 runtime/node.js/22", - wantClient: ClientMux, - }, - { - name: "claude_code", - userAgent: "claude-cli/2.0.67 (external, cli)", - wantClient: ClientClaude, - }, - { - name: "codex_cli", - userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64) ghostty/1.3.0-main_250877ef", - wantClient: ClientCodex, - }, - { - name: "zed", - userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)", - wantClient: ClientZed, - }, - { - name: "github_copilot_vsc", - userAgent: "GitHubCopilotChat/0.37.2026011603", - wantClient: ClientCopilotVSC, - }, - { - name: "github_copilot_cli", - userAgent: "copilot/0.0.403 (client/cli linux v24.11.1)", - wantClient: ClientCopilotCLI, - }, - { - name: "kilo_code_user_agent", - userAgent: "kilo-code/5.1.0 (darwin 25.2.0; arm64) node/22.21.1", - wantClient: ClientKilo, - }, - { - name: "kilo_code_originator", - headers: map[string]string{"Originator": "kilo-code"}, - wantClient: ClientKilo, - }, - { - name: "roo_code_user_agent", - userAgent: "roo-code/3.45.0 (darwin 25.2.0; arm64) node/22.21.1", - wantClient: ClientRoo, - }, - { - name: "roo_code_originator", - headers: map[string]string{"Originator": "roo-code"}, - wantClient: ClientRoo, - }, - { - name: "cursor_x_cursor_client_version", - userAgent: "connect-es/1.6.1", - headers: map[string]string{"X-Cursor-client-version": "0.50.0"}, - wantClient: ClientCursor, - }, - { - name: "cursor_x_cursor_some_other_header", - headers: map[string]string{"x-cursor-client-version": "abc123"}, - wantClient: ClientCursor, - }, - { - name: "unknown_client", - userAgent: "ccclaude-cli/calude-with-wrong-prefix", - wantClient: ClientUnknown, - }, - { - name: "empty_user_agent", - userAgent: "", - wantClient: ClientUnknown, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - req, err := http.NewRequest(http.MethodGet, "", nil) - require.NoError(t, err) - - req.Header.Set("User-Agent", tt.userAgent) - for key, value := range tt.headers { - req.Header.Set(key, value) - } - - got := guessClient(req) - require.Equal(t, tt.wantClient, got) - }) - } -} diff --git a/client.go b/client.go new file mode 100644 index 0000000..f7da258 --- /dev/null +++ b/client.go @@ -0,0 +1,54 @@ +package aibridge + +import ( + "net/http" + "strings" +) + +type Client string + +const ( + // Possible values for the "client" field in interception records. + // Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44 + ClientClaudeCode Client = "Claude Code" + ClientCodex Client = "Codex" + ClientZed Client = "Zed" + ClientCopilotVSC Client = "GitHub Copilot (VS Code)" + ClientCopilotCLI Client = "GitHub Copilot (CLI)" + ClientKilo Client = "Kilo Code" + ClientMux Client = "Mux" + ClientRoo Client = "Roo Code" + ClientCursor Client = "Cursor" + ClientUnknown Client = "Unknown" +) + +// guessClient attempts to guess the client application from the request headers. +// Not all clients set proper user agent headers, so this is a best-effort approach. +// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101. +func guessClient(r *http.Request) Client { + userAgent := strings.ToLower(r.UserAgent()) + originator := r.Header.Get("originator") + + // Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44 + switch { + case strings.HasPrefix(userAgent, "mux/"): + return ClientMux + case strings.HasPrefix(userAgent, "claude"): + return ClientClaudeCode + case strings.HasPrefix(userAgent, "codex"): + return ClientCodex + case strings.HasPrefix(userAgent, "zed/"): + return ClientZed + case strings.HasPrefix(userAgent, "githubcopilotchat/"): + return ClientCopilotVSC + case strings.HasPrefix(userAgent, "copilot/"): + return ClientCopilotCLI + case strings.HasPrefix(userAgent, "kilo-code/") || originator == "kilo-code": + return ClientKilo + case strings.HasPrefix(userAgent, "roo-code/") || originator == "roo-code": + return ClientRoo + case r.Header.Get("x-cursor-client-version") != "": + return ClientCursor + } + return ClientUnknown +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..42188e1 --- /dev/null +++ b/client_test.go @@ -0,0 +1,108 @@ +package aibridge + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGuessClient(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userAgent string + headers map[string]string + wantClient Client + }{ + { + name: "mux", + userAgent: "mux/0.19.0-next.2.gcceff159 ai-sdk/openai/3.0.36 ai-sdk/provider-utils/4.0.15 runtime/node.js/22", + wantClient: ClientMux, + }, + { + name: "claude_code", + userAgent: "claude-cli/2.0.67 (external, cli)", + wantClient: ClientClaudeCode, + }, + { + name: "codex_cli", + userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64) ghostty/1.3.0-main_250877ef", + wantClient: ClientCodex, + }, + { + name: "zed", + userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)", + wantClient: ClientZed, + }, + { + name: "github_copilot_vsc", + userAgent: "GitHubCopilotChat/0.37.2026011603", + wantClient: ClientCopilotVSC, + }, + { + name: "github_copilot_cli", + userAgent: "copilot/0.0.403 (client/cli linux v24.11.1)", + wantClient: ClientCopilotCLI, + }, + { + name: "kilo_code_user_agent", + userAgent: "kilo-code/5.1.0 (darwin 25.2.0; arm64) node/22.21.1", + wantClient: ClientKilo, + }, + { + name: "kilo_code_originator", + headers: map[string]string{"Originator": "kilo-code"}, + wantClient: ClientKilo, + }, + { + name: "roo_code_user_agent", + userAgent: "roo-code/3.45.0 (darwin 25.2.0; arm64) node/22.21.1", + wantClient: ClientRoo, + }, + { + name: "roo_code_originator", + headers: map[string]string{"Originator": "roo-code"}, + wantClient: ClientRoo, + }, + { + name: "cursor_x_cursor_client_version", + userAgent: "connect-es/1.6.1", + headers: map[string]string{"X-Cursor-client-version": "0.50.0"}, + wantClient: ClientCursor, + }, + { + name: "cursor_x_cursor_some_other_header", + headers: map[string]string{"x-cursor-client-version": "abc123"}, + wantClient: ClientCursor, + }, + { + name: "unknown_client", + userAgent: "ccclaude-cli/calude-with-wrong-prefix", + wantClient: ClientUnknown, + }, + { + name: "empty_user_agent", + userAgent: "", + wantClient: ClientUnknown, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodGet, "", nil) + require.NoError(t, err) + + req.Header.Set("User-Agent", tt.userAgent) + for key, value := range tt.headers { + req.Header.Set(key, value) + } + + got := guessClient(req) + require.Equal(t, tt.wantClient, got) + }) + } +} diff --git a/provider/openai.go b/provider/openai.go index 730fc68..43d6811 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -116,7 +116,7 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace return nil, fmt.Errorf("read body: %w", err) } var req responses.ResponsesNewParamsWrapper - if err := json.Unmarshal(payload, &req); err != nil { + if err := json.Unmarshal(payload, &req); err != nil { // TODO: should probably change to json.NewDecoder. return nil, fmt.Errorf("unmarshal request body: %w", err) } if req.Stream { diff --git a/recorder/types.go b/recorder/types.go index 82c34d0..b33494d 100644 --- a/recorder/types.go +++ b/recorder/types.go @@ -26,13 +26,14 @@ type ToolArgs any type Metadata map[string]any type InterceptionRecord struct { - Client string ID string InitiatorID string Metadata Metadata Model string Provider string StartedAt time.Time + ClientSessionID *string + Client string UserAgent string CorrelatingToolCallID *string } diff --git a/responses_integration_test.go b/responses_integration_test.go index 4b82bfb..5cde9c9 100644 --- a/responses_integration_test.go +++ b/responses_integration_test.go @@ -45,7 +45,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectToolRecorded *recorder.ToolUsageRecord expectTokenUsage *recorder.TokenUsageRecord userAgent string - expectedClient string + expectedClient aibridge.Client }{ { name: "blocking_simple", @@ -63,7 +63,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { }, }, userAgent: "claude-cli/2.0.67 (external, cli)", - expectedClient: aibridge.ClientClaude, + expectedClient: aibridge.ClientClaudeCode, }, { name: "blocking_builtin_tool", @@ -369,7 +369,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.Equal(t, intc.Provider, config.ProviderOpenAI) require.Equal(t, intc.Model, tc.expectModel) require.Equal(t, tc.userAgent, intc.UserAgent) - require.Equal(t, tc.expectedClient, intc.Client) + require.Equal(t, string(tc.expectedClient), intc.Client) recordedPrompts := mockRecorder.RecordedPromptUsages() if tc.expectPromptRecorded != "" { diff --git a/session.go b/session.go new file mode 100644 index 0000000..b00a929 --- /dev/null +++ b/session.go @@ -0,0 +1,84 @@ +package aibridge + +import ( + "bytes" + "io" + "net/http" + "regexp" + "strings" + + "github.com/coder/aibridge/utils" + "github.com/tidwall/gjson" +) + +var claudeCodePattern = regexp.MustCompile(`_session_(.+)$`) // Save compilation on each call. + +// guessSessionID attempts to retrieve a session ID which may have been sent by +// the client. We only attempt to retrieve sessions using methods recognized for +// the given client. +func guessSessionID(client Client, r *http.Request) *string { + switch client { + case ClientClaudeCode: + /* Claude Code adds the session ID into the `metadata.user_id` field in the JSON body. + { + ... + "metadata": { + "user_id": "user_{sha256}_account_{account_id}_session_{uuid_v4}" + }, + ... + } */ + payload, err := io.ReadAll(r.Body) + if err != nil { + // Failing silently is suitable here; if the body cannot be read, we won't be able to do much more. + return nil + } + _ = r.Body.Close() + + // Restore the request body. + r.Body = io.NopCloser(bytes.NewReader(payload)) + userID := gjson.GetBytes(payload, "metadata.user_id") + if !userID.Exists() { + return nil + } + + matches := claudeCodePattern.FindStringSubmatch(userID.String()) + if len(matches) < 2 { + return nil + } + return cleanRef(matches[1]) + case ClientCodex: + return cleanRef(r.Header.Get("session_id")) + case ClientMux: + return cleanRef(r.Header.Get("X-Mux-Workspace-Id")) + case ClientZed: + return nil // Zed does not send a session ID from Zed Agent or Text Thread. + case ClientCopilotVSC: + // This does not map precisely to what we consider a session, but it's close enough. + // Most other providers' equivalent of this would persist for the duration of a + // conversation; it does seem to persist across an agentic loop though, which is + // all we really need. + // + // There's also `vscode-sessionid` but that's persistent for the duration of the + // VS Code window. + return cleanRef(r.Header.Get("x-interaction-id")) + case ClientCopilotCLI: + return cleanRef(r.Header.Get("X-Client-Session-Id")) + case ClientKilo: + return cleanRef(r.Header.Get("X-KILOCODE-TASKID")) + case ClientRoo: + return nil // RooCode doesn't send a session ID. + case ClientCursor: + return nil // Cursor is not currently supported. + default: + return nil + } +} + +func cleanRef(str string) *string { + str = strings.TrimSpace(str) + if str == "" { + return nil + } + + return utils.PtrTo(str) +} diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..44305f9 --- /dev/null +++ b/session_test.go @@ -0,0 +1,178 @@ +package aibridge + +import ( + "io" + "net/http" + "strings" + "testing" + + "github.com/coder/aibridge/utils" + "github.com/stretchr/testify/require" +) + +func TestGuessSessionID(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + client Client + body string + headers map[string]string + sessionID *string + }{ + // Claude Code. + { + name: "claude_code_with_valid_session", + client: ClientClaudeCode, + body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`, + sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"), + }, + { + name: "claude_code_missing_metadata", + client: ClientClaudeCode, + body: `{"model":"claude-3"}`, + }, + { + name: "claude_code_missing_user_id", + client: ClientClaudeCode, + body: `{"metadata":{}}`, + }, + { + name: "claude_code_user_id_without_session", + client: ClientClaudeCode, + body: `{"metadata":{"user_id":"user_abc123_account_456"}}`, + }, + { + name: "claude_code_empty_body", + client: ClientClaudeCode, + body: ``, + }, + { + name: "claude_code_invalid_json", + client: ClientClaudeCode, + body: `not json at all`, + }, + // Codex. + { + name: "codex_with_session_header", + client: ClientCodex, + headers: map[string]string{"session_id": "codex-session-123"}, + sessionID: utils.PtrTo("codex-session-123"), + }, + { + name: "codex_with_whitespace_in_header", + client: ClientCodex, + headers: map[string]string{"session_id": " codex-session-123 "}, + sessionID: utils.PtrTo("codex-session-123"), + }, + { + name: "codex_without_session_header", + client: ClientCodex, + }, + // Other clients shouldn't use others' logic. + { + name: "unknown_client_returns_empty", + client: ClientUnknown, + body: `{"metadata":{"user_id":"user_abc_account_456_session_some-id"}}`, + }, + { + name: "zed_returns_empty", + client: ClientZed, + headers: map[string]string{"session_id": "zed-session"}, + body: `{"metadata":{"user_id":"user_abc_account_456_session_some-id"}}`, + }, + // Mux. + { + name: "mux_with_workspace_header", + client: ClientMux, + headers: map[string]string{"X-Mux-Workspace-Id": "ws-abc-123"}, + sessionID: utils.PtrTo("ws-abc-123"), + }, + { + name: "mux_without_workspace_header", + client: ClientMux, + }, + // Copilot VS Code. + { + name: "copilot_vsc_with_interaction_id", + client: ClientCopilotVSC, + headers: map[string]string{"x-interaction-id": "interaction-xyz"}, + sessionID: utils.PtrTo("interaction-xyz"), + }, + { + name: "copilot_vsc_without_interaction_id", + client: ClientCopilotVSC, + }, + // Copilot CLI. + { + name: "copilot_cli_with_session_header", + client: ClientCopilotCLI, + headers: map[string]string{"X-Client-Session-Id": "cli-sess-456"}, + sessionID: utils.PtrTo("cli-sess-456"), + }, + { + name: "copilot_cli_without_session_header", + client: ClientCopilotCLI, + }, + // Kilo. + { + name: "kilo_with_task_id", + client: ClientKilo, + headers: map[string]string{"X-KILOCODE-TASKID": "task-789"}, + sessionID: utils.PtrTo("task-789"), + }, + { + name: "kilo_without_task_id", + client: ClientKilo, + }, + // Roo. + { + name: "roo_returns_empty", + client: ClientRoo, + }, + // Cursor. + { + name: "cursor_returns_empty", + client: ClientCursor, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + body := tc.body + req, err := http.NewRequest(http.MethodPost, "http://localhost", strings.NewReader(body)) + require.NoError(t, err) + + for key, value := range tc.headers { + req.Header.Set(key, value) + } + + got := guessSessionID(tc.client, req) + require.Equal(t, tc.sessionID, got) + + // Verify the body was restored and can be read again. + restored, err := io.ReadAll(req.Body) + require.NoError(t, err) + require.Equal(t, body, string(restored)) + }) + } +} + +func TestUnreadableBody(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodPost, "http://localhost", &errReader{}) + require.NoError(t, err) + + got := guessSessionID(ClientClaudeCode, req) + require.Nil(t, got) +} + +// errReader is an io.Reader that always returns an error. +type errReader struct{} + +func (e *errReader) Read([]byte) (int, error) { + return 0, io.ErrUnexpectedEOF +}