diff --git a/README.md b/README.md index 9ccd942..abb0fda 100644 --- a/README.md +++ b/README.md @@ -91,11 +91,16 @@ Important `proxy.json` fields: | `local_token` | generated | Required bearer token for local proxy requests. | | `pinned_claude_account` | unset | Optional Claude account email or UUID to force proxy selection. | | `diagnostics_log` | unset | Optional JSONL routing diagnostics log path for advanced local debugging. | +| `payload_diagnostics_log` | unset | Optional JSONL payload diagnostics log path. **Disabled by default.** See warning below. | | `headroom` | `false` | Enables the headroom compression bridge when true. | | `headroom_mode` | `cache` | Compression strategy when set; valid values are `cache` and `token`. | Routing diagnostics are disabled by default. To enable them, set `diagnostics_log` in `proxy.json` to a local file path and restart the proxy. The log is append-only JSONL containing redacted route metadata such as method, path, provider, route kind, status, latency, selected-account hint, failover flag, and safe error code. It is intended for advanced local debugging and UAT, and enabling it does not change routing policy. +**Payload diagnostics** (`payload_diagnostics_log`) are disabled by default. To enable them, set `payload_diagnostics_log` in `proxy.json` to a local file path and **restart the proxy** (hot reload is not supported). Each entry in the log is a JSONL record with fields including `time`, `method`, `path`, `provider`, `route_kind`, `model`, `client_kind`, `session_key`, `session_source`, `session_signal`, `frame_index`, `body_bytes`, and `body`. Codex WebSocket client text frames are logged with `route_kind: codex_websocket_frame` so native Codex sessions can expose signals such as new session, continuation, long session, clear, and compact transitions. + +> **WARNING:** The payload diagnostics log contains **raw request bodies** including prompts, tool inputs, system prompts, compact summaries, and message content. It is unsafe to share without careful review. The log does not record headers, tokens, or credential values by itself, but the body content may contain sensitive information. The `session_key` and `session_source` fields are derived correlation metadata — `session_key` is a deterministic 12-hex-character hash of a session identifier (never the raw value), and `session_source` identifies which signal was used (e.g. `x-claude-code-session-id`, `session_id`, `x-codex-window-id`, `body:conversation_id`, `body:thread_id`, `body:previous_response_id`, `ws:thread_id`, `ws:previous_response_id`). Payload diagnostics can derive the key from known conversation/thread identifiers in JSON request bodies or Codex WebSocket frames when otherwise-identical local client sessions have no differentiating session header. + ## Model Registry `cq models` manages the local model registry used by the proxy, Claude Code model caches, and Codex model cache integration. diff --git a/cmd/cq/proxy.go b/cmd/cq/proxy.go index 70573e2..3798121 100644 --- a/cmd/cq/proxy.go +++ b/cmd/cq/proxy.go @@ -164,7 +164,8 @@ func runProxyStart(opts proxyCommandOptions) error { claudeProvider := claudeprov.New(refreshClient) quotaCache := proxy.NewQuotaCache(claudeProvider.FetchAccountUsage, cache.DefaultDir()) baseSelector := proxy.NewAccountSelector(discover, activeEmail, quotaCache) - selector := proxy.NewPinnedClaudeSelector(baseSelector, discover, cfg.PinnedClaudeAccount, quotaCache) + affinitySelector := proxy.NewSessionAffinitySelector(baseSelector, discover, quotaCache) + selector := proxy.NewPinnedClaudeSelector(affinitySelector, discover, cfg.PinnedClaudeAccount, quotaCache) selector.SetPinExpireFunc(clearPersistedClaudePin) if cfg.PinnedClaudeAccount != "" { fmt.Fprintf(os.Stderr, "cq: pinned claude account: %s\n", cfg.PinnedClaudeAccount) @@ -340,6 +341,21 @@ func runProxyStart(opts proxyCommandOptions) error { } } + var payloadDiag *proxy.PayloadWriter + if cfg.PayloadDiagnosticsLog != "" { + payloadDiag, err = proxy.OpenPayloadWriter(cfg.PayloadDiagnosticsLog) + if err != nil { + fmt.Fprintf(os.Stderr, "cq: payload diagnostics: %v (continuing without payload diagnostics)\n", err) + } else { + fmt.Fprintf(os.Stderr, "cq: payload diagnostics enabled — WARNING: log contains raw request bodies including prompts and message content\n") + defer func() { + if err := payloadDiag.Close(); err != nil { + fmt.Fprintf(os.Stderr, "cq: payload diagnostics: close: %v\n", err) + } + }() + } + } + srv := &proxy.Server{ Config: cfg, Selector: selector, @@ -350,6 +366,7 @@ func runProxyStart(opts proxyCommandOptions) error { CodexUpgradeTransport: codexUpgradeTransport, Headroom: headroom, Diag: diagnostics, + PayloadDiag: payloadDiag, HeadroomMode: resolvedMode, Catalog: catalog, Refresher: proxyRefresher, diff --git a/internal/proxy/codex_compact.go b/internal/proxy/codex_compact.go index 47e0ca8..d3fa470 100644 --- a/internal/proxy/codex_compact.go +++ b/internal/proxy/codex_compact.go @@ -74,6 +74,7 @@ func (s *Server) handleNativeCodexCompact(w http.ResponseWriter, r *http.Request Error: rec.diagnosticsError(), } event.applyRouteDiagnostics(routeDiag) + event.applySessionCorrelation(r.Header) s.emitDiagnostics(event) }() } @@ -98,6 +99,24 @@ func (s *Server) handleNativeCodexCompact(w http.ResponseWriter, r *http.Request model = extractModel(body) fmt.Fprintf(os.Stderr, "cq: route POST %s model=%q provider=codex (native compact)\n", requestPath, model) + // Emit payload diagnostics before forwarding. + if s.PayloadDiag != nil { + sessionKey, sessionSource := payloadSessionCorrelation(r.Header, body) + s.emitPayloadDiagnostics(PayloadEvent{ + Time: time.Now().UTC(), + Method: r.Method, + Path: r.URL.Path, + Provider: "codex", + RouteKind: "codex_compact", + Model: model, + ClientKind: clientKindFromUserAgent(r.Header.Get("User-Agent")), + SessionKey: sessionKey, + SessionSource: sessionSource, + BodyBytes: len(body), + Body: encodeBody(body), + }) + } + // Build upstream request targeting /responses/compact (no headroom applied). upstreamURL := s.Config.CodexUpstream + "/responses/compact" upReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body)) diff --git a/internal/proxy/codex_selector.go b/internal/proxy/codex_selector.go index f6a3539..bf6c4a4 100644 --- a/internal/proxy/codex_selector.go +++ b/internal/proxy/codex_selector.go @@ -58,26 +58,25 @@ func (s *codexSelector) Select(ctx context.Context, exclude ...string) (*codex.C } func (s *codexSelector) selectAccount(accounts []codex.CodexAccount, excludeSet map[string]bool, requestedModel string, requireCompatible bool) *codex.CodexAccount { + var best *codex.CodexAccount + bestRemaining := -1 + for i := range accounts { a := &accounts[i] if !s.isEligible(a, excludeSet, requestedModel, requireCompatible) { continue } - if a.IsActive { - result := *a - return &result + remaining := s.accountRemaining(a) + if s.betterCandidate(a, remaining, best, bestRemaining) { + best = a + bestRemaining = remaining } } - - for i := range accounts { - a := &accounts[i] - if !s.isEligible(a, excludeSet, requestedModel, requireCompatible) { - continue - } - result := *a - return &result + if best == nil { + return nil } - return nil + result := *best + return &result } func (s *codexSelector) isEligible(a *codex.CodexAccount, excludeSet map[string]bool, requestedModel string, requireCompatible bool) bool { @@ -90,6 +89,38 @@ func (s *codexSelector) isEligible(a *codex.CodexAccount, excludeSet map[string] return true } +func (s *codexSelector) accountRemaining(a *codex.CodexAccount) int { + if s.quota == nil { + return -1 + } + snap, ok := s.snapshot(a) + if !ok || time.Since(snap.FetchedAt) > transientQuotaMaxAge { + return -1 + } + return snap.Result.MinRemainingPct() +} + +func (s *codexSelector) betterCandidate(candidate *codex.CodexAccount, candidateRemaining int, current *codex.CodexAccount, currentRemaining int) bool { + if current == nil { + return true + } + if candidateRemaining >= 0 || currentRemaining >= 0 { + if candidateRemaining != currentRemaining { + if candidateRemaining == 0 && currentRemaining < 0 { + return false + } + if currentRemaining == 0 && candidateRemaining < 0 { + return true + } + return candidateRemaining > currentRemaining + } + } + if candidate.IsActive != current.IsActive { + return candidate.IsActive + } + return false +} + func codexRequestedModel(ctx context.Context) string { if ctx == nil { return "" @@ -116,10 +147,7 @@ func (s *codexSelector) hasQuota(a *codex.CodexAccount) bool { if s.quota == nil { return true } - snap, ok := s.quota.Snapshot(a.AccountID) - if !ok { - snap, ok = s.quota.Snapshot(a.Email) - } + snap, ok := s.snapshot(a) if !ok { return true } @@ -132,6 +160,14 @@ func (s *codexSelector) hasQuota(a *codex.CodexAccount) bool { return snap.Result.MinRemainingPct() != 0 } +func (s *codexSelector) snapshot(a *codex.CodexAccount) (QuotaSnapshot, bool) { + snap, ok := s.quota.Snapshot(a.AccountID) + if !ok { + snap, ok = s.quota.Snapshot(a.Email) + } + return snap, ok +} + func codexAcctExcluded(a *codex.CodexAccount, excludeSet map[string]bool) bool { return (a.Email != "" && excludeSet[a.Email]) || (a.AccountID != "" && excludeSet[a.AccountID]) || diff --git a/internal/proxy/codex_selector_test.go b/internal/proxy/codex_selector_test.go index 4aa85b6..c648269 100644 --- a/internal/proxy/codex_selector_test.go +++ b/internal/proxy/codex_selector_test.go @@ -176,6 +176,28 @@ func TestCodexSelector_SkipsExhaustedAccounts(t *testing.T) { } } +func TestCodexSelector_PrefersHigherQuotaOverActiveLowQuota(t *testing.T) { + now := time.Now() + quotaReader := stubQuotaReader{ + "low": {Result: quota.Result{Windows: map[quota.WindowName]quota.Window{quota.Window5Hour: {RemainingPct: 5}}}, FetchedAt: now}, + "high": {Result: quota.Result{Windows: map[quota.WindowName]quota.Window{quota.Window5Hour: {RemainingPct: 80}}}, FetchedAt: now}, + } + sel := NewCodexSelector(func() []codex.CodexAccount { + return []codex.CodexAccount{ + {AccountID: "low", Email: "low@test.com", AccessToken: "t1", IsActive: true}, + {AccountID: "high", Email: "high@test.com", AccessToken: "t2", IsActive: false}, + } + }, quotaReader) + + acct, err := sel.Select(context.Background()) + if err != nil { + t.Fatalf("Select error: %v", err) + } + if acct == nil || acct.Email != "high@test.com" { + t.Fatalf("got %+v, want high@test.com", acct) + } +} + func TestCodexSelector_DoesNotSwitchWhenAllAccountsExhausted(t *testing.T) { now := time.Now() quotaReader := stubQuotaReader{ diff --git a/internal/proxy/config.go b/internal/proxy/config.go index b9c78cd..6bff8da 100644 --- a/internal/proxy/config.go +++ b/internal/proxy/config.go @@ -37,6 +37,13 @@ type Config struct { // a specific account identified by email or AccountUUID. Omitted when empty. PinnedClaudeAccount string `json:"pinned_claude_account,omitempty"` DiagnosticsLog string `json:"diagnostics_log,omitempty"` + // PayloadDiagnosticsLog is the optional path to a JSONL file for payload + // diagnostics. When set, the proxy logs request body metadata (including raw + // request bodies) for every buffered request. Disabled by default. + // WARNING: this log contains raw request bodies including prompts, tool + // inputs, system prompts, compact summaries, and message content. Do not + // share without review. Requires a proxy restart to take effect. + PayloadDiagnosticsLog string `json:"payload_diagnostics_log,omitempty"` } // ResolvedHeadroomMode returns the effective HeadroomMode for this config. diff --git a/internal/proxy/config_test.go b/internal/proxy/config_test.go index 53bffd8..70b21c7 100644 --- a/internal/proxy/config_test.go +++ b/internal/proxy/config_test.go @@ -59,6 +59,58 @@ func TestConfigDiagnosticsLogDefaultDisabled(t *testing.T) { } } +func TestConfigPayloadDiagnosticsLogJSONRoundTrip(t *testing.T) { + cfg := Config{ + Port: DefaultPort, + ClaudeUpstream: DefaultUpstream, + CodexUpstream: DefaultCodexUpstream, + LocalToken: "tok", + PayloadDiagnosticsLog: "/tmp/cq-payloads.jsonl", + } + data, err := json.Marshal(cfg) + if err != nil { + t.Fatal(err) + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatal(err) + } + if string(raw["payload_diagnostics_log"]) != `"/tmp/cq-payloads.jsonl"` { + t.Fatalf("payload_diagnostics_log = %s, want configured path in %s", raw["payload_diagnostics_log"], data) + } + + var roundTrip Config + if err := json.Unmarshal(data, &roundTrip); err != nil { + t.Fatal(err) + } + if roundTrip.PayloadDiagnosticsLog != cfg.PayloadDiagnosticsLog { + t.Fatalf("PayloadDiagnosticsLog = %q, want %q", roundTrip.PayloadDiagnosticsLog, cfg.PayloadDiagnosticsLog) + } +} + +func TestConfigPayloadDiagnosticsLogDefaultDisabled(t *testing.T) { + var cfg Config + if err := json.Unmarshal([]byte(`{"port":19280,"local_token":"tok"}`), &cfg); err != nil { + t.Fatal(err) + } + if cfg.PayloadDiagnosticsLog != "" { + t.Fatalf("PayloadDiagnosticsLog = %q, want empty", cfg.PayloadDiagnosticsLog) + } + + data, err := json.Marshal(Config{Port: DefaultPort, LocalToken: "tok"}) + if err != nil { + t.Fatal(err) + } + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatal(err) + } + if _, ok := raw["payload_diagnostics_log"]; ok { + t.Fatalf("payload_diagnostics_log should be omitted when empty: %s", data) + } +} + func TestConfigDiagnosticsLogPersisted(t *testing.T) { configHome := t.TempDir() t.Setenv("XDG_CONFIG_HOME", configHome) diff --git a/internal/proxy/diag.go b/internal/proxy/diag.go index 6ba571a..eabb280 100644 --- a/internal/proxy/diag.go +++ b/internal/proxy/diag.go @@ -5,24 +5,29 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "net/http" "os" + "sort" + "strings" "sync" "time" ) type RouteEvent struct { - Time time.Time `json:"time"` - Method string `json:"method"` - Path string `json:"path"` - Provider string `json:"provider"` - RouteKind string `json:"route_kind,omitempty"` - Model string `json:"model,omitempty"` - AccountHint string `json:"account_hint,omitempty"` - PinActive bool `json:"pin_active,omitempty"` - Failover bool `json:"failover,omitempty"` - StatusCode int `json:"status_code,omitempty"` - LatencyMS int64 `json:"latency_ms,omitempty"` - Error string `json:"error,omitempty"` + Time time.Time `json:"time"` + Method string `json:"method"` + Path string `json:"path"` + Provider string `json:"provider"` + RouteKind string `json:"route_kind,omitempty"` + Model string `json:"model,omitempty"` + AccountHint string `json:"account_hint,omitempty"` + PinActive bool `json:"pin_active,omitempty"` + Failover bool `json:"failover,omitempty"` + StatusCode int `json:"status_code,omitempty"` + LatencyMS int64 `json:"latency_ms,omitempty"` + Error string `json:"error,omitempty"` + SessionKey string `json:"session_key,omitempty"` + SessionSource string `json:"session_source,omitempty"` } type routeDiagnosticsContextKey struct{} @@ -76,6 +81,13 @@ func (event *RouteEvent) applyRouteDiagnostics(diag *routeDiagnostics) { } } +func (event *RouteEvent) applySessionCorrelation(headers http.Header) { + if event == nil { + return + } + event.SessionKey, event.SessionSource = sessionCorrelation(headers) +} + func redactedAccountHint(prefix string, identifiers ...string) string { for _, identifier := range identifiers { if identifier == "" { @@ -87,12 +99,242 @@ func redactedAccountHint(prefix string, identifiers ...string) string { return "" } -type DiagnosticsWriter struct { +// sessionCorrelation derives a stable, non-secret session key and source label +// from request headers. It never exposes raw header values; all keys are +// truncated SHA-256 hashes. Authorization, cookies, API keys, local proxy +// tokens, emails, and account UUIDs are never used. +// +// Priority: +// 1. X-Claude-Code-Session-Id → "claude-session:<12 hex>" source "x-claude-code-session-id" +// 2. session_id / Session_id → "codex-session:<12 hex>" source "session_id" +// 3. X-Codex-Window-Id → "codex-window:<12 hex>" source "x-codex-window-id" +// 4. stable non-secret headers → "unknown-client:<12 hex>" source "unknown-client" +// 5. nothing → "" source "none" +func sessionCorrelation(headers http.Header) (key string, source string) { + return headerSessionCorrelation(headers, true) +} + +func payloadSessionCorrelation(headers http.Header, body []byte) (key string, source string) { + if key, source := headerSessionCorrelation(headers, false); key != "" { + return key, source + } + if key, source := bodySessionCorrelation(body); key != "" { + return key, source + } + return headerSessionCorrelation(headers, true) +} + +func headerSessionCorrelation(headers http.Header, includeUnknownClient bool) (key string, source string) { + // 1. Claude Code session ID + if v := headers.Get("X-Claude-Code-Session-Id"); v != "" { + return hashPrefix("claude-session", v), "x-claude-code-session-id" + } + + // 2. Codex session_id — http.CanonicalHeaderKey("session_id") = "Session_id". + // Both spellings canonicalize to "Session_id", so one .Get is sufficient. + if v := headers.Get("Session_id"); v != "" { + return hashPrefix("codex-session", v), "session_id" + } + + // 3. Codex window ID + if v := headers.Get("X-Codex-Window-Id"); v != "" { + return hashPrefix("codex-window", v), "x-codex-window-id" + } + + if !includeUnknownClient { + return "", "none" + } + + // 4. Stable non-secret client fingerprint from User-Agent + known safe headers. + // Deliberately excludes Authorization, Cookie, x-api-key, local token values. + var parts []string + if ua := headers.Get("User-Agent"); ua != "" { + parts = append(parts, ua) + } + for _, safe := range []string{"X-Stainless-Runtime", "X-Stainless-Runtime-Version", "X-Stainless-Lang"} { + if v := headers.Get(safe); v != "" { + parts = append(parts, safe+"="+v) + } + } + if len(parts) > 0 { + combined := strings.Join(parts, "|") + return hashPrefix("unknown-client", combined), "unknown-client" + } + + return "", "none" +} + +func bodySessionCorrelation(body []byte) (key string, source string) { + var value any + if err := json.Unmarshal(body, &value); err != nil { + return "", "none" + } + for _, field := range []string{"conversation_id", "thread_id", "session_id", "response_id", "previous_response_id", "parent_response_id"} { + if v := findStringField(value, field); v != "" { + return hashPrefix("body-session", field+":"+v), "body:" + field + } + } + return "", "none" +} + +func findStringField(value any, field string) string { + switch v := value.(type) { + case map[string]any: + if raw, ok := v[field]; ok { + if s, ok := raw.(string); ok && s != "" { + return s + } + } + keys := make([]string, 0, len(v)) + for key := range v { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + if key == field { + continue + } + if s := findStringField(v[key], field); s != "" { + return s + } + } + case []any: + for _, child := range v { + if s := findStringField(child, field); s != "" { + return s + } + } + } + return "" +} + +func codexWebSocketFrameCorrelation(headers http.Header, frame []byte) (key string, source string, signal string) { + if key, source := headerSessionCorrelation(headers, false); key != "" { + return key, source, codexWebSocketFrameSignal(frame) + } + if key, source := codexWebSocketFrameSession(frame); key != "" { + return key, source, codexWebSocketFrameSignal(frame) + } + key, source = headerSessionCorrelation(headers, true) + return key, source, codexWebSocketFrameSignal(frame) +} + +func codexWebSocketFrameSession(frame []byte) (key string, source string) { + var value any + if err := json.Unmarshal(frame, &value); err != nil { + return "", "none" + } + for _, field := range []string{"thread_id", "conversation_id", "session_id", "response_id", "previous_response_id", "parent_response_id"} { + if v := findStringField(value, field); v != "" { + return hashPrefix("ws-session", field+":"+v), "ws:" + field + } + } + return "", "none" +} + +func codexWebSocketFrameSignal(frame []byte) string { + var payload struct { + Method string `json:"method"` + Params any `json:"params"` + } + if err := json.Unmarshal(frame, &payload); err != nil { + return "unknown" + } + method := strings.ToLower(payload.Method) + switch { + case strings.Contains(method, "compact"): + return "compact_transition" + case strings.Contains(method, "clear") || strings.Contains(method, "reset"): + return "clear_transition" + case method == "thread/start" || strings.Contains(method, "start"): + return "new_session" + case findStringField(payload.Params, "previous_response_id") != "" || findStringField(payload.Params, "parent_response_id") != "": + return "continuation" + case countMessages(payload.Params) >= 10: + return "long_session" + default: + return "unknown" + } +} + +func countMessages(value any) int { + switch v := value.(type) { + case map[string]any: + if raw, ok := v["messages"]; ok { + if messages, ok := raw.([]any); ok { + return len(messages) + } + } + if raw, ok := v["input"]; ok { + if messages, ok := raw.([]any); ok { + return len(messages) + } + } + maxCount := 0 + for _, child := range v { + if count := countMessages(child); count > maxCount { + maxCount = count + } + } + return maxCount + case []any: + maxCount := 0 + for _, child := range v { + if count := countMessages(child); count > maxCount { + maxCount = count + } + } + return maxCount + default: + return 0 + } +} + +func hashPrefix(prefix, value string) string { + sum := sha256.Sum256([]byte(value)) + return prefix + ":" + hex.EncodeToString(sum[:])[:12] +} + +// PayloadEvent is a single payload diagnostics log entry. It records +// request-body metadata (and the body itself) for buffered requests. +// It never records headers, tokens, or credential values. +type PayloadEvent struct { + Time time.Time `json:"time"` + Method string `json:"method"` + Path string `json:"path"` + Provider string `json:"provider"` + RouteKind string `json:"route_kind,omitempty"` + Model string `json:"model,omitempty"` + ClientKind string `json:"client_kind,omitempty"` + SessionKey string `json:"session_key,omitempty"` + SessionSource string `json:"session_source,omitempty"` + SessionSignal string `json:"session_signal,omitempty"` + FrameIndex int `json:"frame_index,omitempty"` + BodyBytes int `json:"body_bytes"` + Body json.RawMessage `json:"body,omitempty"` +} + +// encodeBody returns raw as an embedded JSON value if it is valid JSON, +// or as a JSON string otherwise. This keeps the payload log valid JSONL +// regardless of whether the request body was JSON or binary. +func encodeBody(raw []byte) json.RawMessage { + if json.Valid(raw) { + return json.RawMessage(raw) + } + encoded, err := json.Marshal(string(raw)) + if err != nil { + return json.RawMessage(`""`) + } + return json.RawMessage(encoded) +} + +// jsonlWriter is a low-level JSONL file writer with a mutex for concurrent safety. +type jsonlWriter struct { mu sync.Mutex file *os.File } -func OpenDiagnosticsWriter(path string) (*DiagnosticsWriter, error) { +func openJSONLWriter(path string) (*jsonlWriter, error) { f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600) if err != nil { return nil, err @@ -101,25 +343,19 @@ func OpenDiagnosticsWriter(path string) (*DiagnosticsWriter, error) { _ = f.Close() return nil, err } - return &DiagnosticsWriter{file: f}, nil + return &jsonlWriter{file: f}, nil } -func (w *DiagnosticsWriter) Write(event RouteEvent) error { - if w == nil { - return nil - } +func (w *jsonlWriter) encode(v any) error { w.mu.Lock() defer w.mu.Unlock() if w.file == nil { return nil } - return json.NewEncoder(w.file).Encode(event) + return json.NewEncoder(w.file).Encode(v) } -func (w *DiagnosticsWriter) Close() error { - if w == nil { - return nil - } +func (w *jsonlWriter) close() error { w.mu.Lock() defer w.mu.Unlock() if w.file == nil { @@ -129,3 +365,60 @@ func (w *DiagnosticsWriter) Close() error { w.file = nil return err } + +// DiagnosticsWriter writes RouteEvents to a JSONL file. +type DiagnosticsWriter struct { + w *jsonlWriter +} + +func OpenDiagnosticsWriter(path string) (*DiagnosticsWriter, error) { + jw, err := openJSONLWriter(path) + if err != nil { + return nil, err + } + return &DiagnosticsWriter{w: jw}, nil +} + +func (w *DiagnosticsWriter) Write(event RouteEvent) error { + if w == nil || w.w == nil { + return nil + } + return w.w.encode(event) +} + +func (w *DiagnosticsWriter) Close() error { + if w == nil || w.w == nil { + return nil + } + return w.w.close() +} + +// PayloadWriter writes PayloadEvents to a JSONL file. +type PayloadWriter struct { + w *jsonlWriter +} + +// OpenPayloadWriter opens (or creates) a JSONL file for payload diagnostics. +func OpenPayloadWriter(path string) (*PayloadWriter, error) { + jw, err := openJSONLWriter(path) + if err != nil { + return nil, err + } + return &PayloadWriter{w: jw}, nil +} + +// Write appends a PayloadEvent. Nil-safe and zero-value-safe. +func (w *PayloadWriter) Write(event PayloadEvent) error { + if w == nil || w.w == nil { + return nil + } + return w.w.encode(event) +} + +// Close closes the underlying file. Nil-safe and idempotent. +func (w *PayloadWriter) Close() error { + if w == nil || w.w == nil { + return nil + } + return w.w.close() +} diff --git a/internal/proxy/diag_test.go b/internal/proxy/diag_test.go index 5c23af4..ce8fd02 100644 --- a/internal/proxy/diag_test.go +++ b/internal/proxy/diag_test.go @@ -3,9 +3,11 @@ package proxy import ( "bufio" "encoding/json" + "net/http" "os" "path/filepath" "runtime" + "strings" "sync" "testing" "time" @@ -149,6 +151,516 @@ func TestDiagnosticsWriterCloseSafeAndStopsWrites(t *testing.T) { } } +// ── sessionCorrelation tests ───────────────────────────────────────────────── + +func TestSessionCorrelationClaudeSessionId(t *testing.T) { + h := http.Header{} + h.Set("X-Claude-Code-Session-Id", "session-abc-123") + key, source := sessionCorrelation(h) + if source != "x-claude-code-session-id" { + t.Fatalf("source = %q, want x-claude-code-session-id", source) + } + if key == "" || key[:len("claude-session:")] != "claude-session:" { + t.Fatalf("key = %q, want claude-session:", key) + } + // Deterministic: same input → same key. + key2, _ := sessionCorrelation(h) + if key != key2 { + t.Fatalf("non-deterministic: key1=%q key2=%q", key, key2) + } + // Does not contain raw header value. + if key == "claude-session:session-abc-123" { + t.Fatal("key leaks raw header value") + } +} + +func TestSessionCorrelationCodexSessionId(t *testing.T) { + // http.CanonicalHeaderKey("session_id") = "Session_id" + // Use the canonical form so h.Get() finds it. + h := http.Header{} + h["Session_id"] = []string{"codex-sess-xyz"} + key, source := sessionCorrelation(h) + if source != "session_id" { + t.Fatalf("source = %q, want session_id", source) + } + if key == "" || key[:len("codex-session:")] != "codex-session:" { + t.Fatalf("key = %q, want codex-session:", key) + } +} + +func TestSessionCorrelationCodexWindowId(t *testing.T) { + h := http.Header{} + h.Set("X-Codex-Window-Id", "window-001") + key, source := sessionCorrelation(h) + if source != "x-codex-window-id" { + t.Fatalf("source = %q, want x-codex-window-id", source) + } + if key == "" || key[:len("codex-window:")] != "codex-window:" { + t.Fatalf("key = %q, want codex-window:", key) + } +} + +func TestSessionCorrelationUnknownClient(t *testing.T) { + h := http.Header{} + h.Set("User-Agent", "claude-code/1.2.3") + key, source := sessionCorrelation(h) + if source != "unknown-client" { + t.Fatalf("source = %q, want unknown-client", source) + } + if key == "" || key[:len("unknown-client:")] != "unknown-client:" { + t.Fatalf("key = %q, want unknown-client:", key) + } + // Raw User-Agent is not exposed. + if key == "unknown-client:claude-code/1.2.3" { + t.Fatal("key leaks raw User-Agent") + } +} + +func TestSessionCorrelationNone(t *testing.T) { + h := http.Header{} + key, source := sessionCorrelation(h) + if source != "none" { + t.Fatalf("source = %q, want none", source) + } + if key != "" { + t.Fatalf("key = %q, want empty", key) + } +} + +func TestSessionCorrelationPriority(t *testing.T) { + // Claude session ID takes priority over Codex headers. + h := http.Header{} + h.Set("X-Claude-Code-Session-Id", "claude-sess") + h.Set("X-Codex-Window-Id", "codex-win") + h["session_id"] = []string{"codex-sess"} + _, source := sessionCorrelation(h) + if source != "x-claude-code-session-id" { + t.Fatalf("source = %q, want x-claude-code-session-id (highest priority)", source) + } +} + +func TestSessionCorrelationDistinctSessions(t *testing.T) { + h1 := http.Header{} + h1.Set("X-Claude-Code-Session-Id", "session-alpha") + h2 := http.Header{} + h2.Set("X-Claude-Code-Session-Id", "session-beta") + key1, _ := sessionCorrelation(h1) + key2, _ := sessionCorrelation(h2) + if key1 == key2 { + t.Fatalf("distinct sessions produced same key: %q", key1) + } +} + +func TestSessionCorrelationNoCredentialHeaders(t *testing.T) { + // Authorization, cookies, x-api-key, etc. must never influence the key. + // We verify that a request with only credential headers produces source="none". + h := http.Header{} + h.Set("Authorization", "Bearer secret-token") + h.Set("Cookie", "session=abc") + h.Set("X-Api-Key", "api-key-secret") + key, source := sessionCorrelation(h) + if source != "none" { + t.Fatalf("source = %q, want none (credential headers must be ignored)", source) + } + if key != "" { + t.Fatalf("key = %q, want empty", key) + } +} + +func TestPayloadSessionCorrelationConversationIDOverridesUnknownClient(t *testing.T) { + h := http.Header{} + h.Set("User-Agent", "claude-code/1.0.0") + raw := []byte(`{"model":"claude-sonnet","conversation_id":"conv-alpha","messages":[]}`) + key, source := payloadSessionCorrelation(h, raw) + if source != "body:conversation_id" { + t.Fatalf("source = %q, want body:conversation_id", source) + } + if key == "" || !strings.HasPrefix(key, "body-session:") { + t.Fatalf("key = %q, want body-session:", key) + } + if strings.Contains(key, "conv-alpha") { + t.Fatalf("key leaks raw conversation ID: %q", key) + } +} + +func TestPayloadSessionCorrelationNestedThreadID(t *testing.T) { + h := http.Header{} + h.Set("User-Agent", "codex/1.0.0") + raw := []byte(`{"model":"gpt-5.5","metadata":{"thread_id":"thread-123"},"input":[]}`) + key, source := payloadSessionCorrelation(h, raw) + if source != "body:thread_id" { + t.Fatalf("source = %q, want body:thread_id", source) + } + if key == "" || !strings.HasPrefix(key, "body-session:") { + t.Fatalf("key = %q, want body-session:", key) + } +} + +func TestPayloadSessionCorrelationPreviousResponseID(t *testing.T) { + h := http.Header{} + h.Set("User-Agent", "codex/1.0.0") + raw := []byte(`{"model":"gpt-5.5","previous_response_id":"resp_123","input":[{"role":"user","content":"continue"}]}`) + key, source := payloadSessionCorrelation(h, raw) + if source != "body:previous_response_id" { + t.Fatalf("source = %q, want body:previous_response_id", source) + } + if key == "" || !strings.HasPrefix(key, "body-session:") { + t.Fatalf("key = %q, want body-session:", key) + } +} + +func TestPayloadSessionCorrelationHeaderSessionBeatsBody(t *testing.T) { + h := http.Header{} + h.Set("X-Claude-Code-Session-Id", "header-session") + raw := []byte(`{"conversation_id":"body-session"}`) + key, source := payloadSessionCorrelation(h, raw) + if source != "x-claude-code-session-id" { + t.Fatalf("source = %q, want x-claude-code-session-id", source) + } + if key == "" || !strings.HasPrefix(key, "claude-session:") { + t.Fatalf("key = %q, want claude-session:", key) + } +} + +func TestPayloadSessionCorrelationCodexHeaderBeatsBodySessionID(t *testing.T) { + h := http.Header{} + h["Session_id"] = []string{"header-session"} + raw := []byte(`{"session_id":"body-session"}`) + key, source := payloadSessionCorrelation(h, raw) + if source != "session_id" { + t.Fatalf("source = %q, want session_id", source) + } + if key == "" || !strings.HasPrefix(key, "codex-session:") { + t.Fatalf("key = %q, want codex-session:", key) + } +} + +func TestPayloadSessionCorrelationUnknownClientFallback(t *testing.T) { + h := http.Header{} + h.Set("User-Agent", "claude-code/1.0.0") + raw := []byte(`{"model":"claude-sonnet","messages":[]}`) + key, source := payloadSessionCorrelation(h, raw) + if source != "unknown-client" { + t.Fatalf("source = %q, want unknown-client", source) + } + if key == "" || !strings.HasPrefix(key, "unknown-client:") { + t.Fatalf("key = %q, want unknown-client:", key) + } +} + +func TestPayloadSessionCorrelationInvalidJSONFallback(t *testing.T) { + h := http.Header{} + h.Set("User-Agent", "claude-code/1.0.0") + raw := []byte(`not json conversation_id=abc`) + key, source := payloadSessionCorrelation(h, raw) + if source != "unknown-client" { + t.Fatalf("source = %q, want unknown-client", source) + } + if key == "" || !strings.HasPrefix(key, "unknown-client:") { + t.Fatalf("key = %q, want unknown-client:", key) + } +} + +func TestCodexWebSocketFrameSignalNewSession(t *testing.T) { + frame := []byte(`{"jsonrpc":"2.0","id":1,"method":"thread/start","params":{"model":"gpt-5.5","thread_id":"thread-new"}}`) + key, source, signal := codexWebSocketFrameCorrelation(nil, frame) + if signal != "new_session" { + t.Fatalf("signal = %q, want new_session", signal) + } + if source != "ws:thread_id" { + t.Fatalf("source = %q, want ws:thread_id", source) + } + if key == "" || !strings.HasPrefix(key, "ws-session:") { + t.Fatalf("key = %q, want ws-session:", key) + } +} + +func TestCodexWebSocketFrameSignalContinuation(t *testing.T) { + frame := []byte(`{"jsonrpc":"2.0","id":2,"method":"response/create","params":{"previous_response_id":"resp_prev"}}`) + _, source, signal := codexWebSocketFrameCorrelation(nil, frame) + if signal != "continuation" { + t.Fatalf("signal = %q, want continuation", signal) + } + if source != "ws:previous_response_id" { + t.Fatalf("source = %q, want ws:previous_response_id", source) + } +} + +func TestCodexWebSocketFrameSignalClearCompactLongUnknown(t *testing.T) { + tests := []struct { + name string + frame string + signal string + }{ + {"clear", `{"method":"thread/clear","params":{"thread_id":"t"}}`, "clear_transition"}, + {"compact", `{"method":"thread/compact","params":{"thread_id":"t"}}`, "compact_transition"}, + {"long", `{"method":"response/create","params":{"thread_id":"t","messages":[1,2,3,4,5,6,7,8,9,10,11]}}`, "long_session"}, + {"continuation beats long", `{"method":"response/create","params":{"previous_response_id":"resp_prev","messages":[1,2,3,4,5,6,7,8,9,10,11]}}`, "continuation"}, + {"unknown", `{"method":"ping","params":{}}`, "unknown"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, _, signal := codexWebSocketFrameCorrelation(nil, []byte(tc.frame)) + if signal != tc.signal { + t.Fatalf("signal = %q, want %q", signal, tc.signal) + } + }) + } +} + +// ── encodeBody tests ───────────────────────────────────────────────────────── + +func TestEncodeBodyValidJSON(t *testing.T) { + raw := []byte(`{"model":"claude-sonnet","messages":[]}`) + result := encodeBody(raw) + if !json.Valid(result) { + t.Fatalf("result is not valid JSON: %s", result) + } + // Should be embedded as-is (not double-encoded as a string). + if result[0] != '{' { + t.Fatalf("expected object literal, got: %s", result) + } +} + +func TestEncodeBodyInvalidFallback(t *testing.T) { + raw := []byte("not json at all \x00\x01") + result := encodeBody(raw) + if !json.Valid(result) { + t.Fatalf("result is not valid JSON: %s", result) + } + // Should be a JSON string. + var s string + if err := json.Unmarshal(result, &s); err != nil { + t.Fatalf("expected JSON string, got %s: %v", result, err) + } +} + +func TestEncodeBodyEmpty(t *testing.T) { + // Empty bytes are valid JSON (they're not, but we accept nil/empty gracefully). + result := encodeBody([]byte{}) + // Should produce a JSON string (empty bytes are not valid JSON). + if !json.Valid(result) { + t.Fatalf("result is not valid JSON: %s", result) + } +} + +// ── PayloadWriter tests ────────────────────────────────────────────────────── + +func TestPayloadWriterCreatesAndAppendsJSONL(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + + w, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + if err := w.Write(PayloadEvent{ + Time: time.Unix(1, 0).UTC(), + Method: "POST", + Path: "/v1/messages", + Provider: "claude", + RouteKind: "anthropic_messages", + Model: "claude-sonnet", + BodyBytes: 42, + Body: encodeBody([]byte(`{"model":"claude-sonnet"}`)), + }); err != nil { + t.Fatalf("Write: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + // Reopen and append. + w, err = OpenPayloadWriter(path) + if err != nil { + t.Fatalf("reopen PayloadWriter: %v", err) + } + if err := w.Write(PayloadEvent{ + Time: time.Unix(2, 0).UTC(), + Method: "POST", + Path: "/responses", + Provider: "codex", + RouteKind: "codex_native", + Model: "gpt-5.4", + BodyBytes: 22, + Body: encodeBody([]byte(`{"model":"gpt-5.4"}`)), + }); err != nil { + t.Fatalf("append Write: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("append Close: %v", err) + } + + events := readPayloadEvents(t, path) + if len(events) != 2 { + t.Fatalf("events = %d, want 2", len(events)) + } + if events[0].Path != "/v1/messages" || events[1].Path != "/responses" { + t.Fatalf("events paths = %q, %q", events[0].Path, events[1].Path) + } + if runtime.GOOS != "windows" { + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat payload log: %v", err) + } + if got := info.Mode().Perm(); got != 0o600 { + t.Fatalf("file mode = %#o, want 0600", got) + } + } +} + +func TestPayloadWriterNilSafe(t *testing.T) { + var w *PayloadWriter + if err := w.Write(PayloadEvent{Path: "/v1/messages"}); err != nil { + t.Fatalf("nil Write: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("nil Close: %v", err) + } +} + +func TestPayloadWriterZeroValueSafe(t *testing.T) { + var w PayloadWriter + if err := w.Write(PayloadEvent{Path: "/v1/messages"}); err != nil { + t.Fatalf("zero Write: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("zero Close: %v", err) + } +} + +func TestPayloadWriterConcurrentWritesProduceValidJSONLines(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + w, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + + const count = 64 + var wg sync.WaitGroup + wg.Add(count) + for i := 0; i < count; i++ { + i := i + go func() { + defer wg.Done() + if err := w.Write(PayloadEvent{ + Time: time.Unix(int64(i), 0).UTC(), + Method: "POST", + Path: "/v1/messages", + Provider: "claude", + BodyBytes: i, + Body: encodeBody([]byte(`{"model":"claude-sonnet"}`)), + }); err != nil { + t.Errorf("Write(%d): %v", i, err) + } + }() + } + wg.Wait() + + if err := w.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + events := readPayloadEvents(t, path) + if len(events) != count { + t.Fatalf("events = %d, want %d", len(events), count) + } +} + +func TestPayloadWriterCloseSafeAndStopsWrites(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + w, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + if err := w.Write(PayloadEvent{Time: time.Unix(1, 0).UTC(), Method: "POST", Path: "/v1/messages", Provider: "claude"}); err != nil { + t.Fatalf("Write: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("second Close: %v", err) + } + if err := w.Write(PayloadEvent{Time: time.Unix(2, 0).UTC(), Method: "POST", Path: "/v1/messages", Provider: "claude"}); err != nil { + t.Fatalf("Write after Close: %v", err) + } + + events := readPayloadEvents(t, path) + if len(events) != 1 { + t.Fatalf("events after closed write = %d, want 1", len(events)) + } +} + +func TestPayloadEventIncludesSessionCorrelation(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + w, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + + h := http.Header{} + h.Set("X-Claude-Code-Session-Id", "test-session-id") + sessionKey, sessionSource := sessionCorrelation(h) + + if err := w.Write(PayloadEvent{ + Time: time.Unix(1, 0).UTC(), + Method: "POST", + Path: "/v1/messages", + Provider: "claude", + RouteKind: "anthropic_messages", + Model: "claude-sonnet", + SessionKey: sessionKey, + SessionSource: sessionSource, + BodyBytes: 10, + Body: encodeBody([]byte(`{"model":"claude-sonnet"}`)), + }); err != nil { + t.Fatalf("Write: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + events := readPayloadEvents(t, path) + if len(events) != 1 { + t.Fatalf("events = %d, want 1", len(events)) + } + ev := events[0] + if ev.SessionSource != "x-claude-code-session-id" { + t.Fatalf("SessionSource = %q, want x-claude-code-session-id", ev.SessionSource) + } + if ev.SessionKey == "" || ev.SessionKey[:len("claude-session:")] != "claude-session:" { + t.Fatalf("SessionKey = %q, want claude-session:", ev.SessionKey) + } + // Verify raw session ID is not in the log. + raw, _ := os.ReadFile(path) + if strings.Contains(string(raw), "test-session-id") { + t.Fatalf("payload log leaked raw session ID: %s", raw) + } +} + +func readPayloadEvents(t *testing.T, path string) []PayloadEvent { + t.Helper() + f, err := os.Open(path) + if err != nil { + t.Fatalf("open payload log: %v", err) + } + defer f.Close() + + var events []PayloadEvent + scanner := bufio.NewScanner(f) + for scanner.Scan() { + var event PayloadEvent + if err := json.Unmarshal(scanner.Bytes(), &event); err != nil { + t.Fatalf("invalid payload JSON line %q: %v", scanner.Text(), err) + } + events = append(events, event) + } + if err := scanner.Err(); err != nil { + t.Fatalf("scan payload log: %v", err) + } + return events +} + func readDiagnosticsEvents(t *testing.T, path string) []RouteEvent { t.Helper() f, err := os.Open(path) diff --git a/internal/proxy/pinned_selector.go b/internal/proxy/pinned_selector.go index 96b1304..6b53474 100644 --- a/internal/proxy/pinned_selector.go +++ b/internal/proxy/pinned_selector.go @@ -122,7 +122,10 @@ func (s *PinnedClaudeSelector) pinExhausted(acct *keyring.ClaudeOAuth) bool { return false } snap, ok := s.quota.Snapshot(acctIdentifier(acct)) - return ok && snap.Result.MinRemainingPct() == 0 + if !ok || time.Since(snap.FetchedAt) > transientQuotaMaxAge { + return false + } + return snap.Result.MinRemainingPct() == 0 } func (s *PinnedClaudeSelector) expirePin(pin string) { diff --git a/internal/proxy/pinned_selector_test.go b/internal/proxy/pinned_selector_test.go index cf52897..270d8dd 100644 --- a/internal/proxy/pinned_selector_test.go +++ b/internal/proxy/pinned_selector_test.go @@ -149,6 +149,44 @@ func TestPinnedClaudeSelector_ExhaustedPinClearsAndDelegatesToInner(t *testing.T } } +func TestPinnedClaudeSelector_StaleExhaustedPinDoesNotClear(t *testing.T) { + future := time.Now().UnixMilli() + 3600_000 + accounts := []keyring.ClaudeOAuth{ + {Email: "pinned@test.com", AccountUUID: "uuid-pin", AccessToken: "tok-pin", ExpiresAt: future}, + {Email: "fallback@test.com", AccountUUID: "uuid-fallback", AccessToken: "tok-fb", ExpiresAt: future}, + } + inner := innerSelectorFunc(func(ctx context.Context, exclude ...string) (*keyring.ClaudeOAuth, error) { + return &keyring.ClaudeOAuth{Email: "fallback@test.com", AccessToken: "tok-fb", ExpiresAt: future}, nil + }) + sel := NewPinnedClaudeSelector(inner, func() []keyring.ClaudeOAuth { return accounts }, "pinned@test.com", stubQuotaReader{ + "uuid-pin": { + Result: quota.Result{ + Status: quota.StatusExhausted, + Windows: map[quota.WindowName]quota.Window{ + quota.Window5Hour: {RemainingPct: 0}, + }, + }, + FetchedAt: time.Now().Add(-10 * time.Minute), + }, + }) + expiredPin := "" + sel.SetPinExpireFunc(func(pin string) { expiredPin = pin }) + + acct, err := sel.Select(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if acct.Email != "pinned@test.com" { + t.Errorf("email = %q, want pinned@test.com", acct.Email) + } + if got := sel.Pin(); got != "pinned@test.com" { + t.Errorf("pin = %q, want pinned@test.com", got) + } + if expiredPin != "" { + t.Errorf("expired pin = %q, want empty", expiredPin) + } +} + func TestPinnedClaudeSelector_NotFoundReturnsError(t *testing.T) { future := time.Now().UnixMilli() + 3600_000 accounts := []keyring.ClaudeOAuth{ diff --git a/internal/proxy/selector.go b/internal/proxy/selector.go index 13a4623..c696b33 100644 --- a/internal/proxy/selector.go +++ b/internal/proxy/selector.go @@ -115,6 +115,12 @@ func (s *accountSelector) betterCandidate(candidate *keyring.ClaudeOAuth, candid // Quota-based comparison (when at least one has data). if candidateRemaining >= 0 || currentRemaining >= 0 { if candidateRemaining != currentRemaining { + if candidateRemaining == 0 && currentRemaining < 0 { + return false + } + if currentRemaining == 0 && candidateRemaining < 0 { + return true + } return candidateRemaining > currentRemaining } // Equal remaining — fall through to tiebreakers. diff --git a/internal/proxy/selector_test.go b/internal/proxy/selector_test.go index 45456e0..6a1e979 100644 --- a/internal/proxy/selector_test.go +++ b/internal/proxy/selector_test.go @@ -326,6 +326,27 @@ func TestAccountSelector_Select_QuotaAware(t *testing.T) { } }) + t.Run("unknown quota beats known exhausted", func(t *testing.T) { + quotaReader := buildQuotaReader(map[string]QuotaSnapshot{ + "uuid-exhausted": {Result: quotaResult("uuid-exhausted", "exhausted@test.com", 0)}, + // uuid-unknown has no snapshot + }) + sel := NewAccountSelector(func() []keyring.ClaudeOAuth { + return []keyring.ClaudeOAuth{ + {Email: "unknown@test.com", AccountUUID: "uuid-unknown", AccessToken: "t1", ExpiresAt: future}, + {Email: "exhausted@test.com", AccountUUID: "uuid-exhausted", AccessToken: "t2", ExpiresAt: future + 5000}, + } + }, nil, quotaReader) + + acct, err := sel.Select(context.Background()) + if err != nil { + t.Fatal(err) + } + if acct.Email != "unknown@test.com" { + t.Errorf("got %q, want unknown@test.com (unknown beats confirmed exhausted)", acct.Email) + } + }) + t.Run("nil quota falls back to existing logic", func(t *testing.T) { sel := NewAccountSelector(func() []keyring.ClaudeOAuth { return []keyring.ClaudeOAuth{ diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 1ce4869..27b7b2c 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -57,6 +57,7 @@ type Server struct { CodexUpgradeTransport http.RoundTripper // HTTP/1.1-only transport for WebSocket upgrades Headroom *HeadroomBridge Diag *DiagnosticsWriter + PayloadDiag *PayloadWriter // HeadroomMode is the resolved compression mode. Only meaningful when // Headroom is non-nil. Reported in the /health response. HeadroomMode HeadroomMode @@ -321,7 +322,7 @@ func (s *Server) handleCodexAppServerRoute(w http.ResponseWriter, r *http.Reques message := fmt.Sprintf("%s requires websocket upgrade", codexAppServerPath) w.Header().Set("Upgrade", "websocket") writeError(w, http.StatusUpgradeRequired, "invalid_request_error", message) - s.emitDiagnostics(RouteEvent{ + event := RouteEvent{ Time: start.UTC(), Method: r.Method, Path: r.URL.Path, @@ -330,7 +331,9 @@ func (s *Server) handleCodexAppServerRoute(w http.ResponseWriter, r *http.Reques StatusCode: http.StatusUpgradeRequired, LatencyMS: time.Since(start).Milliseconds(), Error: diagnosticsErrorCode("invalid_request_error", message), - }) + } + event.applySessionCorrelation(r.Header) + s.emitDiagnostics(event) return } s.proxyCodexAppServer(w, r) @@ -370,6 +373,7 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { }, "diagnostics": map[string]bool{ "enabled": s.Diag != nil, + "payload": s.PayloadDiag != nil, }, } if s.Headroom != nil { @@ -429,6 +433,7 @@ func (s *Server) handleNativeCodex(w http.ResponseWriter, r *http.Request) { Error: rec.diagnosticsError(), } event.applyRouteDiagnostics(routeDiag) + event.applySessionCorrelation(r.Header) s.emitDiagnostics(event) }() } @@ -453,6 +458,24 @@ func (s *Server) handleNativeCodex(w http.ResponseWriter, r *http.Request) { model = extractModel(body) fmt.Fprintf(os.Stderr, "cq: route POST /responses model=%q provider=codex (native)\n", model) + // Emit payload diagnostics before any body rewrite. + if s.PayloadDiag != nil { + sessionKey, sessionSource := payloadSessionCorrelation(r.Header, body) + s.emitPayloadDiagnostics(PayloadEvent{ + Time: time.Now().UTC(), + Method: r.Method, + Path: r.URL.Path, + Provider: "codex", + RouteKind: "codex_native", + Model: model, + ClientKind: clientKindFromUserAgent(r.Header.Get("User-Agent")), + SessionKey: sessionKey, + SessionSource: sessionSource, + BodyBytes: len(body), + Body: encodeBody(body), + }) + } + // Compress Responses API input via headroom bridge if available. // Fail-open: on error, log and continue with original body. if s.Headroom != nil { @@ -541,57 +564,78 @@ func (s *Server) handleNativeCodex(w http.ResponseWriter, r *http.Request) { // binary/text frames are not buffered by this proxy. func (s *Server) proxyCodexUpgrade(w http.ResponseWriter, r *http.Request) { start := time.Now() - var rec *diagnosticsResponseWriter + statusCode := 0 + diagError := "" ctx, routeDiag := withRouteDiagnostics(r.Context()) r = r.WithContext(ctx) - if wrapped, recorder := s.wrapDiagnosticsResponseWriter(w); recorder != nil { - w = wrapped - rec = recorder - defer func() { - status := rec.statusCode() - if rec.status == 0 { - status = http.StatusSwitchingProtocols - } - event := RouteEvent{ - Time: start.UTC(), - Method: r.Method, - Path: r.URL.Path, - Provider: "codex", - RouteKind: "codex_legacy_websocket", - StatusCode: status, - LatencyMS: time.Since(start).Milliseconds(), - Error: rec.diagnosticsError(), - } - event.applyRouteDiagnostics(routeDiag) - s.emitDiagnostics(event) - }() - } + defer func() { + event := RouteEvent{ + Time: start.UTC(), + Method: r.Method, + Path: r.URL.Path, + Provider: "codex", + RouteKind: "codex_legacy_websocket", + StatusCode: statusCode, + LatencyMS: time.Since(start).Milliseconds(), + Error: diagError, + } + event.applyRouteDiagnostics(routeDiag) + event.applySessionCorrelation(r.Header) + s.emitDiagnostics(event) + }() - codexUpstream, err := url.Parse(s.Config.CodexUpstream) + transport, err := s.codexAppServerTransport() if err != nil { + statusCode = http.StatusServiceUnavailable + diagError = diagnosticsErrorCode("api_error", err.Error()) + writeError(w, http.StatusServiceUnavailable, "api_error", err.Error()) + return + } + upstreamURL, err := codexAppServerWebSocketURL(s.Config.CodexUpstream) + if err != nil { + statusCode = http.StatusInternalServerError + diagError = diagnosticsErrorCode("api_error", "invalid codex upstream URL") writeError(w, http.StatusInternalServerError, "api_error", "invalid codex upstream URL") return } fmt.Fprintf(os.Stderr, "cq: route %s /responses (websocket upgrade) provider=codex (native)\n", r.Method) - transport := s.CodexUpgradeTransport - if transport == nil { - transport = s.CodexTransport + upgrader := websocket.Upgrader{ + CheckOrigin: func(_ *http.Request) bool { return true }, + Subprotocols: websocket.Subprotocols(r), + } + clientConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return } + statusCode = http.StatusSwitchingProtocols + defer clientConn.Close() + clientConn.SetReadLimit(maxRequestBody) - rp := &httputil.ReverseProxy{ - Rewrite: func(pr *httputil.ProxyRequest) { - pr.SetURL(codexUpstream) - pr.Out.URL.Path = codexUpstream.Path + "/responses" - pr.Out.Host = codexUpstream.Host - }, - Transport: transport, - ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) { - writeError(w, http.StatusBadGateway, "api_error", "codex upstream error: "+err.Error()) - }, + messageType, message, err := clientConn.ReadMessage() + if err != nil { + return } - rp.ServeHTTP(w, r) + requestedModel := "" + if messageType == websocket.TextMessage { + requestedModel = extractCodexWebSocketFrameModel(message) + s.emitCodexWebSocketPayloadDiagnostics(r, legacyCodexResponsesPath, requestedModel, message, 1) + } + upstreamConn, _, err := s.dialCodexAppServer(r.Context(), transport, upstreamURL, r.Header, requestedModel) + if err != nil { + _ = clientConn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "upstream error"), time.Now().Add(time.Second)) + return + } + defer upstreamConn.Close() + upstreamConn.SetReadLimit(maxRequestBody) + if err := upstreamConn.WriteMessage(messageType, message); err != nil { + return + } + errCh := make(chan error, 2) + go func() { errCh <- relayWebSocketMessages(clientConn, upstreamConn) }() + go func() { errCh <- relayWebSocketMessages(upstreamConn, clientConn) }() + <-errCh } // proxyCodexAppServer handles the Codex /app-server websocket path. Unlike the @@ -619,6 +663,7 @@ func (s *Server) proxyCodexAppServer(w http.ResponseWriter, r *http.Request) { Error: diagError, } event.applyRouteDiagnostics(routeDiag) + event.applySessionCorrelation(r.Header) s.emitDiagnostics(event) }() @@ -656,6 +701,7 @@ func (s *Server) proxyCodexAppServer(w http.ResponseWriter, r *http.Request) { } if messageType == websocket.TextMessage { requestedModel = extractCodexAppServerThreadStartModel(message) + s.emitCodexWebSocketPayloadDiagnostics(r, codexAppServerPath, requestedModel, message, 1) } fmt.Fprintf(os.Stderr, "cq: route %s %s model=%q provider=codex (native)\n", r.Method, codexAppServerPath, requestedModel) @@ -802,6 +848,22 @@ func extractCodexAppServerThreadStartModel(message []byte) string { return payload.Params.Model } +func extractCodexWebSocketFrameModel(message []byte) string { + var payload struct { + Model string `json:"model"` + Params struct { + Model string `json:"model"` + } `json:"params"` + } + if json.Unmarshal(message, &payload) != nil { + return "" + } + if payload.Model != "" { + return payload.Model + } + return payload.Params.Model +} + func rewriteCodexAppServerThreadStartMessage(message []byte, acct *codex.CodexAccount) []byte { var payload map[string]json.RawMessage if json.Unmarshal(message, &payload) != nil { @@ -902,6 +964,7 @@ func (s *Server) proxyHandler(upstream *url.URL) http.HandlerFunc { Error: rec.diagnosticsError(), } event.applyRouteDiagnostics(routeDiag) + event.applySessionCorrelation(r.Header) s.emitDiagnostics(event) }() } @@ -937,6 +1000,26 @@ func (s *Server) proxyHandler(upstream *url.URL) http.HandlerFunc { // Route based on the original endpoint and model before any body rewriting. routeModel = extractModel(buf) + // Emit payload diagnostics before any body rewrite, while buf still holds + // the original request body. Only emitted for buffered Anthropic endpoints. + if diagnosticsAnthropicRouteKind(r.URL.Path) != "" && s.PayloadDiag != nil { + sessionKey, sessionSource := payloadSessionCorrelation(r.Header, buf) + routeProvider = RouteRequestWithCatalog(r.Method, r.URL.Path, routeModel, s.Catalog) + s.emitPayloadDiagnostics(PayloadEvent{ + Time: start.UTC(), + Method: r.Method, + Path: r.URL.Path, + Provider: providerName(routeProvider), + RouteKind: diagnosticsAnthropicRouteKind(r.URL.Path), + Model: routeModel, + ClientKind: clientKindFromUserAgent(r.Header.Get("User-Agent")), + SessionKey: sessionKey, + SessionSource: sessionSource, + BodyBytes: len(buf), + Body: encodeBody(buf), + }) + } + // Compress messages via headroom bridge if available. // Dispatch to the correct path based on the resolved headroom mode. if s.Headroom != nil && len(buf) > 0 { @@ -1045,6 +1128,40 @@ func (s *Server) emitDiagnostics(event RouteEvent) { } } +func (s *Server) emitPayloadDiagnostics(event PayloadEvent) { + if s == nil || s.PayloadDiag == nil { + return + } + if event.Time.IsZero() { + event.Time = time.Now().UTC() + } + if err := s.PayloadDiag.Write(event); err != nil { + fmt.Fprintf(os.Stderr, "cq: payload diagnostics: write: %v\n", err) + } +} + +func (s *Server) emitCodexWebSocketPayloadDiagnostics(r *http.Request, path, model string, frame []byte, frameIndex int) { + if s == nil || s.PayloadDiag == nil || r == nil { + return + } + sessionKey, sessionSource, signal := codexWebSocketFrameCorrelation(r.Header, frame) + s.emitPayloadDiagnostics(PayloadEvent{ + Time: time.Now().UTC(), + Method: r.Method, + Path: path, + Provider: "codex", + RouteKind: "codex_websocket_frame", + Model: model, + ClientKind: clientKindFromUserAgent(r.Header.Get("User-Agent")), + SessionKey: sessionKey, + SessionSource: sessionSource, + SessionSignal: signal, + FrameIndex: frameIndex, + BodyBytes: len(frame), + Body: encodeBody(frame), + }) +} + func (s *Server) claudePinActive() bool { if s == nil { return false @@ -1066,6 +1183,24 @@ func diagnosticsAnthropicRouteKind(path string) string { } } +// clientKindFromUserAgent classifies the client type from a User-Agent string. +// Returns a short lowercase label suitable for diagnostics. +func clientKindFromUserAgent(ua string) string { + lower := strings.ToLower(ua) + switch { + case strings.Contains(lower, "claude-code"): + return "claude-code" + case strings.Contains(lower, "codex"): + return "codex" + case strings.Contains(lower, "anthropic"): + return "anthropic-sdk" + case ua == "": + return "" + default: + return "other" + } +} + func providerName(provider Provider) string { switch provider { case ProviderCodex: diff --git a/internal/proxy/server_test.go b/internal/proxy/server_test.go index 1801c7f..c6a0f82 100644 --- a/internal/proxy/server_test.go +++ b/internal/proxy/server_test.go @@ -11,6 +11,7 @@ import ( "os" "path/filepath" "strings" + "sync" "testing" "time" @@ -193,6 +194,7 @@ func TestServerDiagnosticsClaudeRouteEmitsEvent(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{"model":"claude-sonnet","messages":[]}`)) req.Header.Set("Authorization", "Bearer local-tok") req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Claude-Code-Session-Id", "raw-session-secret") handler(w, req) if w.Code != http.StatusOK { @@ -230,6 +232,13 @@ func TestServerDiagnosticsClaudeRouteEmitsEvent(t *testing.T) { if ev.Time.IsZero() { t.Fatal("Time is zero") } + if ev.SessionSource != "x-claude-code-session-id" { + t.Fatalf("SessionSource = %q, want x-claude-code-session-id", ev.SessionSource) + } + if ev.SessionKey == "" || !strings.HasPrefix(ev.SessionKey, "claude-session:") { + t.Fatalf("SessionKey = %q, want claude-session:", ev.SessionKey) + } + assertDiagnosticsLogDoesNotContain(t, path, "raw-session-secret") assertDiagnosticsLogDoesNotContain(t, path, "local-tok") assertDiagnosticsLogDoesNotContain(t, path, "user@test.com") assertDiagnosticsLogDoesNotContain(t, path, "account-uuid-secret") @@ -905,6 +914,188 @@ func TestServerDiagnosticsLegacyCodexWebsocketRouteEmitsEvent(t *testing.T) { assertDiagnosticsLogDoesNotContain(t, path, "codex-tok") } +func TestServerPayloadDiagnosticsLegacyCodexWebSocketFrameEmitsEvent(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + payloadDiag, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + defer payloadDiag.Close() + + upgrader := websocket.Upgrader{CheckOrigin: func(_ *http.Request) bool { return true }} + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + t.Errorf("upstream path = %q, want /responses", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer codex-tok" { + t.Errorf("upstream auth = %q, want Bearer codex-tok", got) + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upstream upgrade error = %v", err) + return + } + defer conn.Close() + messageType, message, err := conn.ReadMessage() + if err != nil { + t.Errorf("upstream read error = %v", err) + return + } + if !strings.Contains(string(message), "response/create") { + t.Errorf("upstream message = %q, want response/create frame", message) + } + if err := conn.WriteMessage(messageType, []byte(`{"jsonrpc":"2.0","id":1,"result":{}}`)); err != nil { + t.Errorf("upstream write error = %v", err) + } + })) + defer upstream.Close() + + srv := &Server{ + Config: &Config{ClaudeUpstream: "https://api.anthropic.com", CodexUpstream: upstream.URL, LocalToken: "tok"}, + CodexUpgradeTransport: &CodexTokenTransport{ + Selector: &fakeCodexSelector{account: &codex.CodexAccount{AccessToken: "codex-tok", AccountID: "acct-codex"}}, + Inner: http.DefaultTransport, + }, + PayloadDiag: payloadDiag, + } + handler, err := srv.handler() + if err != nil { + t.Fatalf("handler() error = %v", err) + } + proxy := httptest.NewServer(handler) + defer proxy.Close() + + wsURL := "ws" + strings.TrimPrefix(proxy.URL, "http") + legacyCodexResponsesPath + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + if resp != nil { + defer resp.Body.Close() + } + t.Fatalf("Dial() error = %v", err) + } + defer conn.Close() + frame := []byte(`{"jsonrpc":"2.0","id":1,"method":"response/create","params":{"model":"gpt-5.5","previous_response_id":"resp_prev"}}`) + if err := conn.WriteMessage(websocket.TextMessage, frame); err != nil { + t.Fatalf("WriteMessage() error = %v", err) + } + if _, _, err := conn.ReadMessage(); err != nil { + t.Fatalf("ReadMessage() error = %v", err) + } + if err := payloadDiag.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + events := readPayloadEvents(t, path) + if len(events) != 1 { + t.Fatalf("events = %d, want 1", len(events)) + } + ev := events[0] + if ev.Path != legacyCodexResponsesPath || ev.RouteKind != "codex_websocket_frame" || ev.Provider != "codex" { + t.Fatalf("event route = %+v", ev) + } + if ev.Model != "gpt-5.5" { + t.Fatalf("Model = %q, want gpt-5.5", ev.Model) + } + if ev.SessionSource != "ws:previous_response_id" || ev.SessionSignal != "continuation" { + t.Fatalf("source/signal = %q/%q, want ws:previous_response_id/continuation", ev.SessionSource, ev.SessionSignal) + } + if ev.SessionKey == "" || !strings.HasPrefix(ev.SessionKey, "ws-session:") { + t.Fatalf("SessionKey = %q, want ws-session:", ev.SessionKey) + } + assertPayloadLogDoesNotContain(t, path, "codex-tok") + assertPayloadLogDoesNotContain(t, path, "acct-codex") +} + +func TestServerPayloadDiagnosticsCodexAppServerWebSocketFrameEmitsEvent(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + payloadDiag, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + defer payloadDiag.Close() + + upgrader := websocket.Upgrader{CheckOrigin: func(_ *http.Request) bool { return true }} + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upstream upgrade error = %v", err) + return + } + defer conn.Close() + messageType, message, err := conn.ReadMessage() + if err != nil { + t.Errorf("upstream read error = %v", err) + return + } + if !strings.Contains(string(message), "thread/start") { + t.Errorf("upstream message = %q, want thread/start frame", message) + } + if err := conn.WriteMessage(messageType, []byte(`{"jsonrpc":"2.0","id":1,"result":{}}`)); err != nil { + t.Errorf("upstream write error = %v", err) + } + })) + defer upstream.Close() + + srv := &Server{ + Config: &Config{ClaudeUpstream: "https://api.anthropic.com", CodexUpstream: upstream.URL, LocalToken: "tok"}, + CodexUpgradeTransport: &CodexTokenTransport{ + Selector: &fakeCodexSelector{account: &codex.CodexAccount{AccessToken: "codex-tok", AccountID: "acct-codex"}}, + Inner: http.DefaultTransport, + }, + PayloadDiag: payloadDiag, + } + handler, err := srv.handler() + if err != nil { + t.Fatalf("handler() error = %v", err) + } + proxy := httptest.NewServer(handler) + defer proxy.Close() + + wsURL := "ws" + strings.TrimPrefix(proxy.URL, "http") + codexAppServerPath + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + if resp != nil { + defer resp.Body.Close() + } + t.Fatalf("Dial() error = %v", err) + } + defer conn.Close() + frame := []byte(`{"jsonrpc":"2.0","id":1,"method":"thread/start","params":{"model":"gpt-5.5","thread_id":"thread-ws-1"}}`) + if err := conn.WriteMessage(websocket.TextMessage, frame); err != nil { + t.Fatalf("WriteMessage() error = %v", err) + } + if _, _, err := conn.ReadMessage(); err != nil { + t.Fatalf("ReadMessage() error = %v", err) + } + if err := payloadDiag.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + events := readPayloadEvents(t, path) + if len(events) != 1 { + t.Fatalf("events = %d, want 1", len(events)) + } + ev := events[0] + if ev.Path != codexAppServerPath || ev.RouteKind != "codex_websocket_frame" || ev.Provider != "codex" { + t.Fatalf("event route = %+v", ev) + } + if ev.Model != "gpt-5.5" { + t.Fatalf("Model = %q, want gpt-5.5", ev.Model) + } + if ev.SessionSource != "ws:thread_id" || ev.SessionSignal != "new_session" { + t.Fatalf("source/signal = %q/%q, want ws:thread_id/new_session", ev.SessionSource, ev.SessionSignal) + } + if ev.SessionKey == "" || !strings.HasPrefix(ev.SessionKey, "ws-session:") { + t.Fatalf("SessionKey = %q, want ws-session:", ev.SessionKey) + } + if ev.FrameIndex != 1 { + t.Fatalf("FrameIndex = %d, want 1", ev.FrameIndex) + } + if string(ev.Body) != string(frame) { + t.Fatalf("Body = %s, want raw frame %s", ev.Body, frame) + } + assertPayloadLogDoesNotContain(t, path, "codex-tok") + assertPayloadLogDoesNotContain(t, path, "acct-codex") +} + func TestServerDiagnosticsCompactRoutesEmitEvents(t *testing.T) { for _, tc := range []struct { name string @@ -2755,6 +2946,540 @@ func TestServer_ProxyHandler_TokenModeUsesCompress(t *testing.T) { } } +// ── Payload diagnostics tests ──────────────────────────────────────────────── + +func TestServerPayloadDiagnosticsClaudeRouteEmitsEvent(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + payloadDiag, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + defer payloadDiag.Close() + + future := time.Now().UnixMilli() + 3600_000 + sel := &fakeSelector{accounts: []keyring.ClaudeOAuth{ + {Email: "user@test.com", AccessToken: "real-token", ExpiresAt: future}, + }} + srv := &Server{ + Config: &Config{ + ClaudeUpstream: "https://api.anthropic.com", + LocalToken: "local-tok", + }, + Transport: &TokenTransport{ + Selector: sel, + Inner: roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return makeResponse(http.StatusOK, `{"id":"msg_123"}`), nil + }), + }, + PayloadDiag: payloadDiag, + } + + handler := srv.proxyHandler(mustParseURL(srv.Config.ClaudeUpstream)) + w := httptest.NewRecorder() + reqBody := `{"model":"claude-sonnet","messages":[{"role":"user","content":"hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer local-tok") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "claude-code/1.0.0") + req.Header.Set("X-Claude-Code-Session-Id", "test-session-abc") + handler(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200, body: %s", w.Code, w.Body.String()) + } + if err := payloadDiag.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + events := readPayloadEvents(t, path) + if len(events) != 1 { + t.Fatalf("events = %d, want 1", len(events)) + } + ev := events[0] + if ev.Method != http.MethodPost || ev.Path != "/v1/messages" || ev.Provider != "claude" { + t.Fatalf("event route = %+v", ev) + } + if ev.RouteKind != "anthropic_messages" { + t.Fatalf("RouteKind = %q, want anthropic_messages", ev.RouteKind) + } + if ev.Model != "claude-sonnet" { + t.Fatalf("Model = %q, want claude-sonnet", ev.Model) + } + if ev.ClientKind != "claude-code" { + t.Fatalf("ClientKind = %q, want claude-code", ev.ClientKind) + } + if ev.SessionSource != "x-claude-code-session-id" { + t.Fatalf("SessionSource = %q, want x-claude-code-session-id", ev.SessionSource) + } + if ev.SessionKey == "" { + t.Fatal("SessionKey is empty, want non-empty") + } + if ev.BodyBytes != len(reqBody) { + t.Fatalf("BodyBytes = %d, want %d", ev.BodyBytes, len(reqBody)) + } + if ev.Body == nil { + t.Fatal("Body is nil, want non-nil") + } + // Must not leak credentials. + assertPayloadLogDoesNotContain(t, path, "local-tok") + assertPayloadLogDoesNotContain(t, path, "real-token") + assertPayloadLogDoesNotContain(t, path, "user@test.com") + // Must not leak raw session ID. + assertPayloadLogDoesNotContain(t, path, "test-session-abc") +} + +func TestServerPayloadDiagnosticsCountTokensEmitsEvent(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + payloadDiag, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + defer payloadDiag.Close() + + future := time.Now().UnixMilli() + 3600_000 + sel := &fakeSelector{accounts: []keyring.ClaudeOAuth{ + {Email: "user@test.com", AccessToken: "real-token", ExpiresAt: future}, + }} + srv := &Server{ + Config: &Config{ + ClaudeUpstream: "https://api.anthropic.com", + LocalToken: "local-tok", + }, + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"input_tokens":42}`)), + }, nil + }), + Discover: func() []keyring.ClaudeOAuth { + return []keyring.ClaudeOAuth{{Email: "user@test.com", AccessToken: "real-token", ExpiresAt: future}} + }, + PayloadDiag: payloadDiag, + } + _ = sel + + handler := srv.proxyHandler(mustParseURL(srv.Config.ClaudeUpstream)) + w := httptest.NewRecorder() + reqBody := `{"model":"claude-sonnet","messages":[{"role":"user","content":"hi"}]}` + req := httptest.NewRequest(http.MethodPost, countTokensPath, strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer local-tok") + req.Header.Set("Content-Type", "application/json") + handler(w, req) + + if err := payloadDiag.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + events := readPayloadEvents(t, path) + if len(events) != 1 { + t.Fatalf("events = %d, want 1", len(events)) + } + ev := events[0] + if ev.Path != countTokensPath { + t.Fatalf("Path = %q, want %q", ev.Path, countTokensPath) + } + if ev.RouteKind != "anthropic_count_tokens" { + t.Fatalf("RouteKind = %q, want anthropic_count_tokens", ev.RouteKind) + } +} + +func TestServerPayloadDiagnosticsNativeCodexEmitsEvent(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + payloadDiag, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + defer payloadDiag.Close() + + srv := &Server{ + Config: &Config{ + CodexUpstream: "https://chatgpt.com/backend-api/codex", + LocalToken: "tok", + }, + CodexTransport: &CodexTokenTransport{ + Selector: &fakeCodexSelector{account: &codex.CodexAccount{AccessToken: "codex-tok"}}, + Inner: roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_codex"}`)), + }, nil + }), + }, + PayloadDiag: payloadDiag, + } + + w := httptest.NewRecorder() + reqBody := `{"model":"gpt-5.4","input":"tell me about go"}` + req := httptest.NewRequest(http.MethodPost, "/responses", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "codex-cli/1.0") + srv.handleNativeCodex(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200, body: %s", w.Code, w.Body.String()) + } + if err := payloadDiag.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + events := readPayloadEvents(t, path) + if len(events) != 1 { + t.Fatalf("events = %d, want 1", len(events)) + } + ev := events[0] + if ev.Provider != "codex" || ev.RouteKind != "codex_native" { + t.Fatalf("event = %+v, want codex/codex_native", ev) + } + if ev.Model != "gpt-5.4" { + t.Fatalf("Model = %q, want gpt-5.4", ev.Model) + } + if ev.ClientKind != "codex" { + t.Fatalf("ClientKind = %q, want codex", ev.ClientKind) + } + if ev.BodyBytes != len(reqBody) { + t.Fatalf("BodyBytes = %d, want %d", ev.BodyBytes, len(reqBody)) + } + assertPayloadLogDoesNotContain(t, path, "codex-tok") +} + +func TestServerPayloadDiagnosticsCompactEmitsEvent(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + payloadDiag, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + defer payloadDiag.Close() + + srv := &Server{ + Config: &Config{ + CodexUpstream: "https://chatgpt.com/backend-api/codex", + LocalToken: "tok", + }, + CodexTransport: &CodexTokenTransport{ + Selector: &fakeCodexSelector{account: &codex.CodexAccount{AccessToken: "codex-tok"}}, + Inner: roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"object":"response.compact"}`)), + }, nil + }), + }, + PayloadDiag: payloadDiag, + } + + handler, err := srv.handler() + if err != nil { + t.Fatalf("handler() error = %v", err) + } + w := httptest.NewRecorder() + reqBody := `{"model":"gpt-5.4","previous_response_id":"resp_abc"}` + req := httptest.NewRequest(http.MethodPost, codexCompactResponsesPath, strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200, body: %s", w.Code, w.Body.String()) + } + if err := payloadDiag.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + events := readPayloadEvents(t, path) + if len(events) != 1 { + t.Fatalf("events = %d, want 1", len(events)) + } + ev := events[0] + if ev.Provider != "codex" || ev.RouteKind != "codex_compact" { + t.Fatalf("event = %+v, want codex/codex_compact", ev) + } + if ev.Model != "gpt-5.4" { + t.Fatalf("Model = %q, want gpt-5.4", ev.Model) + } + assertPayloadLogDoesNotContain(t, path, "codex-tok") +} + +func TestServerPayloadDiagnosticsNoEventForBinaryWebSocketFrame(t *testing.T) { + // Binary WebSocket frames are not captured by payload diagnostics. + path := filepath.Join(t.TempDir(), "payloads.jsonl") + payloadDiag, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + defer payloadDiag.Close() + + upgrader := websocket.Upgrader{CheckOrigin: func(_ *http.Request) bool { return true }} + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + _, msg, _ := conn.ReadMessage() + _ = conn.WriteMessage(websocket.TextMessage, msg) + })) + defer upstream.Close() + + srv := &Server{ + Config: &Config{ + ClaudeUpstream: "https://api.anthropic.com", + CodexUpstream: upstream.URL, + LocalToken: "local-tok", + }, + CodexUpgradeTransport: &CodexTokenTransport{ + Selector: &fakeCodexSelector{account: &codex.CodexAccount{AccessToken: "codex-tok"}}, + Inner: http.DefaultTransport, + }, + PayloadDiag: payloadDiag, + } + + handler, err := srv.handler() + if err != nil { + t.Fatalf("handler() error = %v", err) + } + proxy := httptest.NewServer(handler) + defer proxy.Close() + + wsURL := "ws" + strings.TrimPrefix(proxy.URL, "http") + legacyCodexResponsesPath + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + if resp != nil { + defer resp.Body.Close() + } + t.Fatalf("Dial() error = %v", err) + } + _ = conn.WriteMessage(websocket.BinaryMessage, []byte("ping")) + conn.ReadMessage() + _ = conn.Close() + + // Allow brief time for any async writes. + time.Sleep(50 * time.Millisecond) + if err := payloadDiag.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + // No payload event should be emitted for binary WebSocket frames. + if _, err := os.Stat(path); err == nil { + events := readPayloadEvents(t, path) + if len(events) != 0 { + t.Fatalf("binary websocket emitted %d payload events, want 0", len(events)) + } + } +} + +func TestServerPayloadDiagnosticsDistinctParallelSessions(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + payloadDiag, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + defer payloadDiag.Close() + + future := time.Now().UnixMilli() + 3600_000 + sel := &fakeSelector{accounts: []keyring.ClaudeOAuth{ + {Email: "user@test.com", AccessToken: "real-token", ExpiresAt: future}, + }} + srv := &Server{ + Config: &Config{ + ClaudeUpstream: "https://api.anthropic.com", + LocalToken: "local-tok", + }, + Transport: &TokenTransport{ + Selector: sel, + Inner: roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return makeResponse(http.StatusOK, `{"id":"msg"}`), nil + }), + }, + PayloadDiag: payloadDiag, + } + + sessions := []string{"session-alpha", "session-beta", "session-gamma"} + var wg sync.WaitGroup + for _, sess := range sessions { + sess := sess + wg.Add(1) + go func() { + defer wg.Done() + handler := srv.proxyHandler(mustParseURL(srv.Config.ClaudeUpstream)) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{"model":"claude-sonnet","messages":[]}`)) + req.Header.Set("Authorization", "Bearer local-tok") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Claude-Code-Session-Id", sess) + handler(w, req) + }() + } + wg.Wait() + + if err := payloadDiag.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + events := readPayloadEvents(t, path) + if len(events) != len(sessions) { + t.Fatalf("events = %d, want %d", len(events), len(sessions)) + } + + // All session keys must be distinct. + seen := make(map[string]bool) + for _, ev := range events { + if seen[ev.SessionKey] { + t.Fatalf("duplicate session key %q across parallel sessions", ev.SessionKey) + } + seen[ev.SessionKey] = true + } +} + +func TestServerPayloadDiagnosticsDistinctParallelBodySessionsWithIdenticalHeaders(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + payloadDiag, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + defer payloadDiag.Close() + + future := time.Now().UnixMilli() + 3600_000 + sel := &fakeSelector{accounts: []keyring.ClaudeOAuth{ + {Email: "user@test.com", AccessToken: "real-token", ExpiresAt: future}, + }} + srv := &Server{ + Config: &Config{ + ClaudeUpstream: "https://api.anthropic.com", + LocalToken: "local-tok", + }, + Transport: &TokenTransport{ + Selector: sel, + Inner: roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return makeResponse(http.StatusOK, `{"id":"msg"}`), nil + }), + }, + PayloadDiag: payloadDiag, + } + + conversationIDs := []string{"conv-alpha", "conv-beta", "conv-gamma"} + var wg sync.WaitGroup + for _, conversationID := range conversationIDs { + conversationID := conversationID + wg.Add(1) + go func() { + defer wg.Done() + handler := srv.proxyHandler(mustParseURL(srv.Config.ClaudeUpstream)) + w := httptest.NewRecorder() + body := fmt.Sprintf(`{"model":"claude-sonnet","conversation_id":%q,"messages":[]}`, conversationID) + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer local-tok") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "claude-code/1.0.0") + handler(w, req) + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200, body: %s", w.Code, w.Body.String()) + } + }() + } + wg.Wait() + + if err := payloadDiag.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + events := readPayloadEvents(t, path) + if len(events) != len(conversationIDs) { + t.Fatalf("events = %d, want %d", len(events), len(conversationIDs)) + } + seen := make(map[string]bool) + for _, ev := range events { + if ev.SessionSource != "body:conversation_id" { + t.Fatalf("SessionSource = %q, want body:conversation_id", ev.SessionSource) + } + if ev.SessionKey == "" || !strings.HasPrefix(ev.SessionKey, "body-session:") { + t.Fatalf("SessionKey = %q, want body-session:", ev.SessionKey) + } + if seen[ev.SessionKey] { + t.Fatalf("duplicate session key %q across body-distinguished sessions", ev.SessionKey) + } + seen[ev.SessionKey] = true + } + assertPayloadLogDoesNotContain(t, path, "local-tok") + assertPayloadLogDoesNotContain(t, path, "real-token") +} + +func TestServerPayloadDiagnosticsDisabledNoFile(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + future := time.Now().UnixMilli() + 3600_000 + sel := &fakeSelector{accounts: []keyring.ClaudeOAuth{ + {Email: "user@test.com", AccessToken: "real-token", ExpiresAt: future}, + }} + srv := &Server{ + Config: &Config{ + ClaudeUpstream: "https://api.anthropic.com", + LocalToken: "local-tok", + }, + Transport: &TokenTransport{ + Selector: sel, + Inner: roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return makeResponse(http.StatusOK, `{"id":"msg"}`), nil + }), + }, + // PayloadDiag intentionally nil. + } + + handler := srv.proxyHandler(mustParseURL(srv.Config.ClaudeUpstream)) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{"model":"claude-sonnet","messages":[]}`)) + req.Header.Set("Authorization", "Bearer local-tok") + handler(w, req) + + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Fatalf("payload file should not exist when PayloadDiag is nil: %v", err) + } +} + +func TestServerHealthReportsPayloadEnabled(t *testing.T) { + for _, tc := range []struct { + name string + enabled bool + }{ + {name: "disabled", enabled: false}, + {name: "enabled", enabled: true}, + } { + t.Run(tc.name, func(t *testing.T) { + path := filepath.Join(t.TempDir(), "payloads.jsonl") + srv := &Server{} + if tc.enabled { + pw, err := OpenPayloadWriter(path) + if err != nil { + t.Fatalf("OpenPayloadWriter: %v", err) + } + defer pw.Close() + srv.PayloadDiag = pw + } + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/health", nil) + srv.handleHealth(w, req) + + var resp struct { + Diagnostics struct { + Payload bool `json:"payload"` + } `json:"diagnostics"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp.Diagnostics.Payload != tc.enabled { + t.Fatalf("diagnostics.payload = %v, want %v", resp.Diagnostics.Payload, tc.enabled) + } + }) + } +} + +func assertPayloadLogDoesNotContain(t *testing.T, path, needle string) { + t.Helper() + raw, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read payload log: %v", err) + } + if strings.Contains(string(raw), needle) { + t.Fatalf("payload log leaked %q: %s", needle, raw) + } +} + // TestServer_NativeCodex_HeadroomNil_NoCompression verifies that when Headroom // is nil, no compression is attempted and the original body is forwarded. func TestServer_NativeCodex_HeadroomNil_NoCompression(t *testing.T) { diff --git a/internal/proxy/session_affinity.go b/internal/proxy/session_affinity.go new file mode 100644 index 0000000..bc4661b --- /dev/null +++ b/internal/proxy/session_affinity.go @@ -0,0 +1,124 @@ +package proxy + +import ( + "context" + "net/http" + "sync" + "time" + + "github.com/jacobcxdev/cq/internal/keyring" +) + +type sessionAffinityStore struct { + mu sync.RWMutex + accounts map[string]string +} + +func newSessionAffinityStore() *sessionAffinityStore { + return &sessionAffinityStore{accounts: make(map[string]string)} +} + +func (s *sessionAffinityStore) remember(sessionKey string, acct *keyring.ClaudeOAuth) { + if s == nil || sessionKey == "" || acct == nil { + return + } + identifier := acctIdentifier(acct) + if identifier == "" { + return + } + s.mu.Lock() + s.accounts[sessionKey] = identifier + s.mu.Unlock() +} + +func (s *sessionAffinityStore) lookup(sessionKey string) (string, bool) { + if s == nil || sessionKey == "" { + return "", false + } + s.mu.RLock() + identifier, ok := s.accounts[sessionKey] + s.mu.RUnlock() + return identifier, ok +} + +type SessionAffinitySelector struct { + inner ClaudeSelector + discover ClaudeDiscoverer + quota QuotaReader + store *sessionAffinityStore +} + +func NewSessionAffinitySelector(inner ClaudeSelector, discover ClaudeDiscoverer, quota QuotaReader) *SessionAffinitySelector { + return &SessionAffinitySelector{ + inner: inner, + discover: discover, + quota: quota, + store: newSessionAffinityStore(), + } +} + +func (s *SessionAffinitySelector) Select(ctx context.Context, exclude ...string) (*keyring.ClaudeOAuth, error) { + if sessionKey, _ := sessionCorrelation(headersFromContext(ctx)); sessionKey != "" { + if acct := s.affinityAccount(sessionKey, exclude); acct != nil { + return acct, nil + } + } + return s.inner.Select(ctx, exclude...) +} + +func (s *SessionAffinitySelector) affinityAccount(sessionKey string, exclude []string) *keyring.ClaudeOAuth { + identifier, ok := s.store.lookup(sessionKey) + if !ok { + return nil + } + excludeSet := make(map[string]bool, len(exclude)) + for _, key := range exclude { + excludeSet[key] = true + } + for _, acct := range s.discover() { + if acctIdentifier(&acct) != identifier || !affinityAccountUsable(&acct, s.quota, excludeSet) { + continue + } + result := acct + return &result + } + return nil +} + +func (s *SessionAffinitySelector) Remember(sessionKey string, acct *keyring.ClaudeOAuth) { + if s == nil { + return + } + s.store.remember(sessionKey, acct) +} + +func affinityAccountUsable(acct *keyring.ClaudeOAuth, quota QuotaReader, excludeSet map[string]bool) bool { + if acct == nil || isExcluded(acct, excludeSet) || acct.AccessToken == "" { + return false + } + if acct.ExpiresAt != 0 && acct.ExpiresAt <= time.Now().UnixMilli() && acct.RefreshToken == "" { + return false + } + if quota == nil { + return true + } + snap, ok := quota.Snapshot(acctIdentifier(acct)) + if !ok || time.Since(snap.FetchedAt) > transientQuotaMaxAge { + return true + } + return snap.Result.MinRemainingPct() != 0 +} + +type requestHeadersContextKey struct{} + +func contextWithRequestHeaders(ctx context.Context, headers http.Header) context.Context { + return context.WithValue(ctx, requestHeadersContextKey{}, headers) +} + +func headersFromContext(ctx context.Context) http.Header { + if ctx == nil { + return nil + } + headers, _ := ctx.Value(requestHeadersContextKey{}).(http.Header) + return headers +} diff --git a/internal/proxy/transport.go b/internal/proxy/transport.go index fa03128..805e91e 100644 --- a/internal/proxy/transport.go +++ b/internal/proxy/transport.go @@ -55,10 +55,12 @@ func (t *TokenTransport) inner() http.RoundTripper { // RoundTrip implements http.RoundTripper. func (t *TokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { - acct, err := t.Selector.Select(req.Context()) + ctx := contextWithRequestHeaders(req.Context(), req.Header) + acct, err := t.Selector.Select(ctx) if err != nil { return nil, err } + req = req.WithContext(ctx) noteRouteAccount(req.Context(), claudeAccountHint(acct), false) // Refresh upfront if token is already expired. @@ -83,6 +85,9 @@ func (t *TokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { case http.StatusTooManyRequests: return t.handle429(req, resp, acct) default: + if resp.StatusCode < 400 { + t.rememberRequestSessionAccount(req, acct) + } t.clearFailoverSuppression() return resp, nil } @@ -187,6 +192,7 @@ func (t *TokenTransport) handle429(req *http.Request, resp *http.Response, faile if t.isConfirmedExhausted(req.Context(), failedAcct) { t.persistSwitch(req.Context(), alt) } + t.rememberRequestSessionAccount(req, alt) last429Resp.Body.Close() if fallbackResp != nil { fallbackResp.Body.Close() @@ -220,6 +226,7 @@ func (t *TokenTransport) handle429(req *http.Request, resp *http.Response, faile if t.isConfirmedExhausted(req.Context(), failedAcct) { t.persistSwitch(req.Context(), alt) } + t.rememberRequestSessionAccount(req, alt) last429Resp.Body.Close() if fallbackResp != nil { fallbackResp.Body.Close() @@ -304,6 +311,20 @@ func (t *TokenTransport) persistSwitch(ctx context.Context, alt *keyring.ClaudeO } } +func (t *TokenTransport) rememberRequestSessionAccount(req *http.Request, acct *keyring.ClaudeOAuth) { + sessionKey, _ := sessionCorrelation(req.Header) + rememberSessionAccount(t.Selector, sessionKey, acct) +} + +func rememberSessionAccount(selector ClaudeSelector, sessionKey string, acct *keyring.ClaudeOAuth) { + switch s := selector.(type) { + case *SessionAffinitySelector: + s.Remember(sessionKey, acct) + case *PinnedClaudeSelector: + rememberSessionAccount(s.inner, sessionKey, acct) + } +} + // transientQuotaMaxAge is the maximum age of a quota snapshot to trust // for exhaustion detection. const transientQuotaMaxAge = 5 * time.Minute diff --git a/internal/proxy/transport_test.go b/internal/proxy/transport_test.go index 02571f0..6472d4c 100644 --- a/internal/proxy/transport_test.go +++ b/internal/proxy/transport_test.go @@ -101,6 +101,95 @@ func TestTokenTransport_HappyPath(t *testing.T) { } } +func TestTokenTransport_SessionAffinityReusesWarmAccount(t *testing.T) { + future := time.Now().UnixMilli() + 3600_000 + accounts := []keyring.ClaudeOAuth{ + {Email: "low@test.com", AccountUUID: "uuid-low", AccessToken: "tok-low", ExpiresAt: future}, + {Email: "high@test.com", AccountUUID: "uuid-high", AccessToken: "tok-high", ExpiresAt: future}, + } + quotaReader := stubQuotaReader{ + "uuid-low": {Result: quotaResult("uuid-low", "low@test.com", 20), FetchedAt: time.Now()}, + "uuid-high": {Result: quotaResult("uuid-high", "high@test.com", 80), FetchedAt: time.Now()}, + } + baseSelector := NewAccountSelector(func() []keyring.ClaudeOAuth { return accounts }, nil, quotaReader) + affinitySelector := NewSessionAffinitySelector(baseSelector, func() []keyring.ClaudeOAuth { return accounts }, quotaReader) + transport := &TokenTransport{ + Selector: affinitySelector, + Inner: roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return makeResponse(200, `{"ok":true}`), nil + }), + } + req := makeRequest(`{"msg":"hello"}`) + req.Header.Set("X-Claude-Code-Session-Id", "session-a") + sessionKey, _ := sessionCorrelation(req.Header) + affinitySelector.Remember(sessionKey, &accounts[0]) + ctx, diag := withRouteDiagnostics(req.Context()) + req = req.WithContext(ctx) + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 200 { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + accountHint, _ := diag.fields() + wantHint := claudeAccountHint(&accounts[0]) + if accountHint != wantHint { + t.Fatalf("account hint = %q, want warm account hint %q", accountHint, wantHint) + } +} + +func TestTokenTransport_SessionAffinityRecordsThroughPinnedSelector(t *testing.T) { + future := time.Now().UnixMilli() + 3600_000 + accounts := []keyring.ClaudeOAuth{ + {Email: "low@test.com", AccountUUID: "uuid-low", AccessToken: "tok-low", ExpiresAt: future}, + {Email: "high@test.com", AccountUUID: "uuid-high", AccessToken: "tok-high", ExpiresAt: future}, + } + quotaReader := stubQuotaReader{ + "uuid-low": {Result: quotaResult("uuid-low", "low@test.com", 20), FetchedAt: time.Now()}, + "uuid-high": {Result: quotaResult("uuid-high", "high@test.com", 80), FetchedAt: time.Now()}, + } + baseSelector := NewAccountSelector(func() []keyring.ClaudeOAuth { return accounts }, nil, quotaReader) + affinitySelector := NewSessionAffinitySelector(baseSelector, func() []keyring.ClaudeOAuth { return accounts }, quotaReader) + pinnedSelector := NewPinnedClaudeSelector(affinitySelector, func() []keyring.ClaudeOAuth { return accounts }, "", quotaReader) + transport := &TokenTransport{ + Selector: pinnedSelector, + Inner: roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return makeResponse(200, `{"ok":true}`), nil + }), + } + + firstReq := makeRequest(`{"msg":"first"}`) + firstReq.Header.Set("X-Claude-Code-Session-Id", "session-a") + if resp, err := transport.RoundTrip(firstReq); err != nil || resp.StatusCode != 200 { + t.Fatalf("first RoundTrip status=%v err=%v", statusCode(resp), err) + } + quotaReader["uuid-low"] = QuotaSnapshot{Result: quotaResult("uuid-low", "low@test.com", 90), FetchedAt: time.Now()} + quotaReader["uuid-high"] = QuotaSnapshot{Result: quotaResult("uuid-high", "high@test.com", 10), FetchedAt: time.Now()} + + secondReq := makeRequest(`{"msg":"second"}`) + secondReq.Header.Set("X-Claude-Code-Session-Id", "session-a") + ctx, diag := withRouteDiagnostics(secondReq.Context()) + secondReq = secondReq.WithContext(ctx) + if resp, err := transport.RoundTrip(secondReq); err != nil || resp.StatusCode != 200 { + t.Fatalf("second RoundTrip status=%v err=%v", statusCode(resp), err) + } + + accountHint, _ := diag.fields() + wantHint := claudeAccountHint(&accounts[1]) + if accountHint != wantHint { + t.Fatalf("account hint = %q, want first successful account hint %q", accountHint, wantHint) + } +} + +func statusCode(resp *http.Response) int { + if resp == nil { + return 0 + } + return resp.StatusCode +} + func TestTokenTransport_AppendsBeta(t *testing.T) { sel := &fakeSelector{accounts: []keyring.ClaudeOAuth{ {Email: "a@test.com", AccessToken: "tok", ExpiresAt: time.Now().UnixMilli() + 3600_000},