diff --git a/config.example.yaml b/config.example.yaml index dfd7454bd..5df0ab3a3 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -51,6 +51,29 @@ request-retry: 3 # Maximum wait time in seconds for a cooled-down credential before triggering a retry. max-retry-interval: 30 +# Rate limiting configuration for API endpoints (per-process, not distributed). +# Protects upstream from client floods by limiting requests per IP, auth key, and model. +# rate-limit: +# enabled: true +# messages: +# per-ip: +# capacity: 60 # Max burst size per IP +# refill-per-second: 1 # Tokens added per second +# per-auth: +# capacity: 120 # Max burst size per auth/key +# refill-per-second: 2 # Tokens added per second +# per-model: +# capacity: 120 # Max burst size per model +# refill-per-second: 2 # Tokens added per second + +# Circuit breaker configuration for persistent upstream errors. +# Prevents repeated retries on hard 403 errors (CONSUMER_INVALID, SERVICE_DISABLED). +# circuit-breaker: +# enabled: true +# hard-403-cooldown-seconds: 600 # 10 minutes cooldown for CONSUMER_INVALID/SERVICE_DISABLED +# soft-403-cooldown-seconds: 1800 # 30 minutes cooldown for other 403 errors +# hard-403-retry: 0 # No retries for hard 403 errors + # Quota exceeded behavior quota-exceeded: switch-project: true # Whether to automatically switch to another project when a quota is exceeded diff --git a/internal/api/middleware/rate_limit.go b/internal/api/middleware/rate_limit.go new file mode 100644 index 000000000..9897a5839 --- /dev/null +++ b/internal/api/middleware/rate_limit.go @@ -0,0 +1,370 @@ +// Package middleware provides HTTP middleware components for the CLI Proxy API server. +// This file contains the rate limiting middleware that protects upstream services +// from client floods by limiting requests per IP, auth key, and model. 
package middleware

import (
	"bytes"
	"encoding/json"
	"io"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
	log "github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
)

// LimiterStore defines the interface for rate limiter storage.
// This allows for future implementations using Redis or other distributed stores.
type LimiterStore interface {
	// TryConsume attempts to consume a token from the bucket identified by key.
	// Returns true if successful, false if rate limited.
	// Also returns the time until the next token is available.
	TryConsume(key string, capacity int, refillPerSecond float64) (allowed bool, retryAfter time.Duration)
}

// TokenBucket represents a single token bucket for rate limiting.
// All fields are guarded by mu.
type TokenBucket struct {
	// tokens is the current balance; it is fractional because refill is
	// computed as elapsed seconds multiplied by the refill rate.
	tokens float64
	// lastRefill is the instant up to which refill has been credited.
	lastRefill time.Time
	// lastAccessed is the time of the most recent TryConsume on this bucket;
	// the cleanup loop compares it against its staleness cutoff.
	lastAccessed time.Time
	mu           sync.Mutex
}

// InMemoryLimiterStore implements LimiterStore using in-memory token buckets.
// This is suitable for single-instance deployments.
type InMemoryLimiterStore struct {
	buckets sync.Map // map[string]*TokenBucket
	// stopCleanup is closed by Stop to terminate the cleanup goroutine.
	stopCleanup chan struct{}
}

// NewInMemoryLimiterStore creates a new in-memory rate limiter store.
// It starts a background goroutine to clean up stale buckets; that goroutine
// runs until Stop is called.
func NewInMemoryLimiterStore() *InMemoryLimiterStore {
	s := &InMemoryLimiterStore{
		stopCleanup: make(chan struct{}),
	}
	go s.cleanupLoop()
	return s
}

// cleanupLoop periodically removes stale buckets that haven't been accessed recently.
// Every 5 minutes it evicts buckets idle for more than 10 minutes, and it
// returns once stopCleanup is closed.
func (s *InMemoryLimiterStore) cleanupLoop() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			s.evictStaleBuckets(10 * time.Minute)
		case <-s.stopCleanup:
			return
		}
	}
}

// evictStaleBuckets removes buckets that haven't been accessed within the given duration.
+func (s *InMemoryLimiterStore) evictStaleBuckets(maxAge time.Duration) { + cutoff := time.Now().Add(-maxAge) + s.buckets.Range(func(key, value any) bool { + bucket := value.(*TokenBucket) + bucket.mu.Lock() + lastAccessed := bucket.lastAccessed + bucket.mu.Unlock() + if lastAccessed.Before(cutoff) { + s.buckets.Delete(key) + } + return true + }) +} + +// Stop stops the cleanup goroutine. +func (s *InMemoryLimiterStore) Stop() { + close(s.stopCleanup) +} + +// TryConsume attempts to consume a token from the bucket. +func (s *InMemoryLimiterStore) TryConsume(key string, capacity int, refillPerSecond float64) (bool, time.Duration) { + if capacity <= 0 || refillPerSecond <= 0 { + return true, 0 + } + + now := time.Now() + bucketI, _ := s.buckets.LoadOrStore(key, &TokenBucket{ + tokens: float64(capacity), + lastRefill: now, + lastAccessed: now, + }) + bucket := bucketI.(*TokenBucket) + + bucket.mu.Lock() + defer bucket.mu.Unlock() + + bucket.lastAccessed = now + elapsed := now.Sub(bucket.lastRefill).Seconds() + bucket.tokens += elapsed * refillPerSecond + if bucket.tokens > float64(capacity) { + bucket.tokens = float64(capacity) + } + bucket.lastRefill = now + + if bucket.tokens >= 1 { + bucket.tokens-- + return true, 0 + } + + tokensNeeded := 1 - bucket.tokens + retryAfter := time.Duration(tokensNeeded/refillPerSecond*1000) * time.Millisecond + return false, retryAfter +} + +// RateLimitDimension indicates which rate limit dimension was exceeded. +type RateLimitDimension string + +const ( + DimensionIP RateLimitDimension = "ip" + DimensionAuth RateLimitDimension = "auth" + DimensionModel RateLimitDimension = "model" +) + +// RateLimitError represents a rate limit exceeded error response. +type RateLimitError struct { + Error string `json:"error"` + Message string `json:"message"` + Dimension RateLimitDimension `json:"dimension"` + RetryAfter float64 `json:"retry_after_seconds"` +} + +// RateLimiter holds the rate limiter configuration and store. 
+type RateLimiter struct { + store LimiterStore + cfg *config.RateLimitConfig + enabled bool + mu sync.RWMutex +} + +// NewRateLimiter creates a new rate limiter with the given configuration. +func NewRateLimiter(cfg *config.RateLimitConfig) *RateLimiter { + rl := &RateLimiter{ + store: NewInMemoryLimiterStore(), + } + if cfg != nil { + rl.cfg = cfg + rl.enabled = cfg.Enabled + } + return rl +} + +// UpdateConfig updates the rate limiter configuration. +func (rl *RateLimiter) UpdateConfig(cfg *config.RateLimitConfig) { + rl.mu.Lock() + defer rl.mu.Unlock() + if cfg != nil { + rl.cfg = cfg + rl.enabled = cfg.Enabled + } else { + rl.enabled = false + } +} + +// IsEnabled returns whether rate limiting is enabled. +func (rl *RateLimiter) IsEnabled() bool { + rl.mu.RLock() + defer rl.mu.RUnlock() + return rl.enabled +} + +// getConfig returns a copy of the current configuration. +func (rl *RateLimiter) getConfig() *config.RateLimitConfig { + rl.mu.RLock() + defer rl.mu.RUnlock() + return rl.cfg +} + +// RateLimitMiddleware creates a Gin middleware that enforces rate limits. +// It checks per-IP, per-auth, and per-model limits based on configuration. 
+func RateLimitMiddleware(rl *RateLimiter) gin.HandlerFunc { + return func(c *gin.Context) { + if rl == nil || !rl.IsEnabled() { + c.Next() + return + } + + cfg := rl.getConfig() + if cfg == nil { + c.Next() + return + } + + ip := c.ClientIP() + authID := extractAuthID(c) + model := extractModelFromRequest(c) + + msgCfg := cfg.Messages + + if msgCfg.PerIP.Capacity > 0 && msgCfg.PerIP.RefillPerSecond > 0 { + key := "ip:" + ip + allowed, retryAfter := rl.store.TryConsume(key, msgCfg.PerIP.Capacity, msgCfg.PerIP.RefillPerSecond) + if !allowed { + logRateLimitBlock(ip, authID, model, DimensionIP) + rejectWithRateLimit(c, DimensionIP, retryAfter) + return + } + } + + if authID != "" && msgCfg.PerAuth.Capacity > 0 && msgCfg.PerAuth.RefillPerSecond > 0 { + key := "auth:" + authID + allowed, retryAfter := rl.store.TryConsume(key, msgCfg.PerAuth.Capacity, msgCfg.PerAuth.RefillPerSecond) + if !allowed { + logRateLimitBlock(ip, authID, model, DimensionAuth) + rejectWithRateLimit(c, DimensionAuth, retryAfter) + return + } + } + + if model != "" && msgCfg.PerModel.Capacity > 0 && msgCfg.PerModel.RefillPerSecond > 0 { + key := "model:" + model + allowed, retryAfter := rl.store.TryConsume(key, msgCfg.PerModel.Capacity, msgCfg.PerModel.RefillPerSecond) + if !allowed { + logRateLimitBlock(ip, authID, model, DimensionModel) + rejectWithRateLimit(c, DimensionModel, retryAfter) + return + } + } + + c.Next() + } +} + +// extractAuthID extracts the auth identifier from the Gin context. +// It looks for auth_id set by AuthMiddleware, or falls back to API key hash. 
+func extractAuthID(c *gin.Context) string { + if authID, exists := c.Get("auth_id"); exists { + if id, ok := authID.(string); ok && id != "" { + return id + } + } + + if authKey := c.GetHeader("Authorization"); authKey != "" { + authKey = strings.TrimPrefix(authKey, "Bearer ") + if len(authKey) > 8 { + return "key:" + authKey[:8] + } + return "key:" + authKey + } + + if apiKey := c.GetHeader("x-api-key"); apiKey != "" { + if len(apiKey) > 8 { + return "key:" + apiKey[:8] + } + return "key:" + apiKey + } + + return "" +} + +// extractModelFromRequest extracts the model name from the request body. +// For Claude/Anthropic format, it looks for the "model" field. +func extractModelFromRequest(c *gin.Context) string { + if c.Request.Body == nil { + return "unknown" + } + + bodyBytes, exists := c.Get("request_body") + if !exists { + return "unknown" + } + + body, ok := bodyBytes.([]byte) + if !ok || len(body) == 0 { + return "unknown" + } + + model := gjson.GetBytes(body, "model").String() + if model == "" { + return "unknown" + } + return model +} + +// rejectWithRateLimit sends a 429 response with rate limit details. +func rejectWithRateLimit(c *gin.Context, dimension RateLimitDimension, retryAfter time.Duration) { + retrySeconds := retryAfter.Seconds() + if retrySeconds < 1 { + retrySeconds = 1 + } + + c.Header("Retry-After", formatRetryAfter(retryAfter)) + + errResp := RateLimitError{ + Error: "rate_limit_exceeded", + Message: "Rate limit exceeded for " + string(dimension), + Dimension: dimension, + RetryAfter: retrySeconds, + } + + c.AbortWithStatusJSON(http.StatusTooManyRequests, errResp) +} + +// formatRetryAfter formats the retry-after duration as seconds. +func formatRetryAfter(d time.Duration) string { + secs := int(d.Seconds()) + if secs < 1 { + secs = 1 + } + return strconv.Itoa(secs) +} + +// logRateLimitBlock logs a rate limit block event. 
+func logRateLimitBlock(ip, authID, model string, dimension RateLimitDimension) { + log.WithFields(log.Fields{ + "ip": ip, + "auth_id": authID, + "model": model, + "dimension": string(dimension), + }).Warn("rate limit exceeded") +} + +// RequestBodyCaptureMiddleware captures the request body for later use by rate limiting. +// This must be applied before RateLimitMiddleware to extract the model from the body. +func RequestBodyCaptureMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if c.Request.Body == nil { + c.Next() + return + } + + if c.Request.Method != http.MethodPost { + c.Next() + return + } + + body, err := c.GetRawData() + if err != nil { + log.WithError(err).Debug("failed to read request body for rate limiting") + c.Next() + return + } + + c.Set("request_body", body) + c.Request.Body = io.NopCloser(bytes.NewReader(body)) + c.Next() + } +} + +// MarshalJSON implements json.Marshaler for RateLimitError. +func (e RateLimitError) MarshalJSON() ([]byte, error) { + type Alias RateLimitError + return json.Marshal(struct { + Type string `json:"type"` + Alias + }{ + Type: "error", + Alias: Alias(e), + }) +} diff --git a/internal/api/middleware/rate_limit_test.go b/internal/api/middleware/rate_limit_test.go new file mode 100644 index 000000000..32f0d61e3 --- /dev/null +++ b/internal/api/middleware/rate_limit_test.go @@ -0,0 +1,270 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func TestInMemoryLimiterStore_TryConsume(t *testing.T) { + store := NewInMemoryLimiterStore() + + t.Run("allows requests under capacity", func(t *testing.T) { + for i := 0; i < 5; i++ { + allowed, _ := store.TryConsume("test-key-1", 10, 1.0) + if !allowed { + t.Errorf("request %d should be allowed", i) + } + } + }) + + t.Run("blocks requests over capacity", func(t *testing.T) { + 
key := "test-key-2" + for i := 0; i < 3; i++ { + store.TryConsume(key, 3, 0.1) + } + allowed, retryAfter := store.TryConsume(key, 3, 0.1) + if allowed { + t.Error("request should be blocked") + } + if retryAfter <= 0 { + t.Error("retryAfter should be positive") + } + }) + + t.Run("refills tokens over time", func(t *testing.T) { + key := "test-key-3" + for i := 0; i < 2; i++ { + store.TryConsume(key, 2, 10.0) // High refill rate + } + time.Sleep(150 * time.Millisecond) + allowed, _ := store.TryConsume(key, 2, 10.0) + if !allowed { + t.Error("request should be allowed after refill") + } + }) + + t.Run("zero capacity allows all", func(t *testing.T) { + allowed, _ := store.TryConsume("test-key-4", 0, 1.0) + if !allowed { + t.Error("zero capacity should allow all requests") + } + }) + + t.Run("zero refill rate allows all", func(t *testing.T) { + allowed, _ := store.TryConsume("test-key-5", 10, 0) + if !allowed { + t.Error("zero refill rate should allow all requests") + } + }) +} + +func TestRateLimiter_IsEnabled(t *testing.T) { + t.Run("disabled by default", func(t *testing.T) { + rl := NewRateLimiter(nil) + if rl.IsEnabled() { + t.Error("should be disabled when config is nil") + } + }) + + t.Run("enabled when config says so", func(t *testing.T) { + cfg := &config.RateLimitConfig{Enabled: true} + rl := NewRateLimiter(cfg) + if !rl.IsEnabled() { + t.Error("should be enabled when config.Enabled is true") + } + }) + + t.Run("can update config", func(t *testing.T) { + rl := NewRateLimiter(nil) + cfg := &config.RateLimitConfig{Enabled: true} + rl.UpdateConfig(cfg) + if !rl.IsEnabled() { + t.Error("should be enabled after UpdateConfig") + } + }) +} + +func TestRateLimitMiddleware(t *testing.T) { + cfg := &config.RateLimitConfig{ + Enabled: true, + Messages: config.MessagesRateLimitConfig{ + PerIP: config.TokenBucketConfig{ + Capacity: 2, + RefillPerSecond: 0.1, + }, + }, + } + rl := NewRateLimiter(cfg) + + handler := func(c *gin.Context) { + c.String(http.StatusOK, "OK") + } 
+ + t.Run("allows requests under limit", func(t *testing.T) { + router := gin.New() + router.POST("/v1/messages", RateLimitMiddleware(rl), handler) + + for i := 0; i < 2; i++ { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + req.RemoteAddr = "192.168.1.100:12345" + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("request %d: expected 200, got %d", i, w.Code) + } + } + }) + + t.Run("blocks requests over limit", func(t *testing.T) { + rl2 := NewRateLimiter(&config.RateLimitConfig{ + Enabled: true, + Messages: config.MessagesRateLimitConfig{ + PerIP: config.TokenBucketConfig{ + Capacity: 1, + RefillPerSecond: 0.001, + }, + }, + }) + + router := gin.New() + router.POST("/v1/messages", RateLimitMiddleware(rl2), handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + req.RemoteAddr = "192.168.1.200:12345" + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("first request: expected 200, got %d", w.Code) + } + + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + req.RemoteAddr = "192.168.1.200:12345" + router.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("second request: expected 429, got %d", w.Code) + } + + if w.Header().Get("Retry-After") == "" { + t.Error("Retry-After header should be set") + } + }) + + t.Run("disabled middleware passes through", func(t *testing.T) { + disabledRL := NewRateLimiter(&config.RateLimitConfig{Enabled: false}) + + router := gin.New() + router.POST("/v1/messages", RateLimitMiddleware(disabledRL), handler) + + for i := 0; i < 10; i++ { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("disabled limiter should allow all requests, got %d", w.Code) + } + } + }) +} + +func TestExtractAuthID(t *testing.T) { + 
t.Run("extracts from Authorization header", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Request.Header.Set("Authorization", "Bearer sk-test-1234567890") + + authID := extractAuthID(c) + if !strings.HasPrefix(authID, "key:") { + t.Errorf("expected key prefix, got %s", authID) + } + }) + + t.Run("extracts from x-api-key header", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Request.Header.Set("x-api-key", "test-api-key-12345") + + authID := extractAuthID(c) + if !strings.HasPrefix(authID, "key:") { + t.Errorf("expected key prefix, got %s", authID) + } + }) + + t.Run("extracts from context auth_id", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Set("auth_id", "custom-auth-id") + + authID := extractAuthID(c) + if authID != "custom-auth-id" { + t.Errorf("expected custom-auth-id, got %s", authID) + } + }) + + t.Run("returns empty for no auth", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + authID := extractAuthID(c) + if authID != "" { + t.Errorf("expected empty, got %s", authID) + } + }) +} + +func TestExtractModelFromRequest(t *testing.T) { + t.Run("extracts model from body", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + body := `{"model": "claude-3-opus"}` + c.Set("request_body", []byte(body)) + + model := extractModelFromRequest(c) + if model != "claude-3-opus" { + t.Errorf("expected claude-3-opus, got %s", model) + } + }) + + t.Run("returns unknown for missing body", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) 
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + model := extractModelFromRequest(c) + if model != "unknown" { + t.Errorf("expected unknown, got %s", model) + } + }) + + t.Run("returns unknown for missing model field", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + body := `{"messages": []}` + c.Set("request_body", []byte(body)) + + model := extractModelFromRequest(c) + if model != "unknown" { + t.Errorf("expected unknown, got %s", model) + } + }) +} diff --git a/internal/api/server.go b/internal/api/server.go index 79dcf12a4..2ea48d113 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -168,6 +168,9 @@ type Server struct { keepAliveOnTimeout func() keepAliveHeartbeat chan struct{} keepAliveStop chan struct{} + + // rateLimiter handles rate limiting for API endpoints. + rateLimiter *middleware.RateLimiter } // NewServer creates and initializes a new API server instance. 
@@ -245,6 +248,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk currentPath: wd, envManagementSecret: envManagementSecret, wsRoutes: make(map[string]struct{}), + rateLimiter: middleware.NewRateLimiter(&cfg.RateLimit), } s.wsAuthEnabled.Store(cfg.WebsocketAuth) // Save initial YAML snapshot @@ -255,6 +259,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk } managementasset.SetCurrentConfig(cfg) auth.SetQuotaCooldownDisabled(cfg.DisableCooling) + // Initialize circuit breaker configuration + auth.SetCircuitBreakerConfig( + cfg.CircuitBreaker.CircuitBreakerEnabled(), + cfg.CircuitBreaker.GetHard403CooldownSeconds(), + cfg.CircuitBreaker.GetSoft403CooldownSeconds(), + cfg.CircuitBreaker.Hard403Retry, + ) // Initialize management handler s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager) if optionState.localPassword != "" { @@ -324,7 +335,7 @@ func (s *Server) setupRoutes() { v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers)) v1.POST("/chat/completions", openaiHandlers.ChatCompletions) v1.POST("/completions", openaiHandlers.Completions) - v1.POST("/messages", claudeCodeHandlers.ClaudeMessages) + v1.POST("/messages", middleware.RequestBodyCaptureMiddleware(), middleware.RateLimitMiddleware(s.rateLimiter), claudeCodeHandlers.ClaudeMessages) v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) v1.POST("/responses", openaiResponsesHandlers.Responses) } @@ -866,6 +877,19 @@ func (s *Server) UpdateClients(cfg *config.Config) { s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second) } + // Update rate limiter configuration + if s.rateLimiter != nil { + s.rateLimiter.UpdateConfig(&cfg.RateLimit) + } + + // Update circuit breaker configuration + auth.SetCircuitBreakerConfig( + cfg.CircuitBreaker.CircuitBreakerEnabled(), + cfg.CircuitBreaker.GetHard403CooldownSeconds(), + 
cfg.CircuitBreaker.GetSoft403CooldownSeconds(), + cfg.CircuitBreaker.Hard403Retry, + ) + // Update log level dynamically when debug flag changes if oldCfg == nil || oldCfg.Debug != cfg.Debug { util.SetLogLevel(cfg) diff --git a/internal/config/config.go b/internal/config/config.go index 5af74b1b4..9427c8145 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -83,6 +83,12 @@ type Config struct { // Payload defines default and override rules for provider payload parameters. Payload PayloadConfig `yaml:"payload" json:"payload"` + // RateLimit configures request rate limiting for API endpoints. + RateLimit RateLimitConfig `yaml:"rate-limit" json:"rate-limit"` + + // CircuitBreaker configures circuit breaker behavior for persistent upstream errors. + CircuitBreaker CircuitBreakerConfig `yaml:"circuit-breaker" json:"circuit-breaker"` + legacyMigrationPending bool `yaml:"-" json:"-"` } @@ -176,6 +182,77 @@ type PayloadModelRule struct { Protocol string `yaml:"protocol" json:"protocol"` } +// RateLimitConfig configures request rate limiting for API endpoints. +type RateLimitConfig struct { + // Enabled toggles rate limiting on/off. Default: false. + Enabled bool `yaml:"enabled" json:"enabled"` + + // Messages configures rate limits specifically for /v1/messages endpoint. + Messages MessagesRateLimitConfig `yaml:"messages" json:"messages"` +} + +// MessagesRateLimitConfig defines rate limit buckets for the /v1/messages endpoint. +type MessagesRateLimitConfig struct { + // PerIP defines the token bucket for per-IP rate limiting. + PerIP TokenBucketConfig `yaml:"per-ip" json:"per-ip"` + + // PerAuth defines the token bucket for per-auth/key rate limiting. + PerAuth TokenBucketConfig `yaml:"per-auth" json:"per-auth"` + + // PerModel defines the token bucket for per-model rate limiting. + PerModel TokenBucketConfig `yaml:"per-model" json:"per-model"` +} + +// TokenBucketConfig defines a token bucket rate limiter configuration. 
type TokenBucketConfig struct {
	// Capacity is the bucket size, i.e. the largest burst it can absorb.
	Capacity int `yaml:"capacity" json:"capacity"`

	// RefillPerSecond is how many tokens are credited per second.
	RefillPerSecond float64 `yaml:"refill-per-second" json:"refill-per-second"`
}

// CircuitBreakerConfig holds the circuit breaker settings applied to
// persistent upstream errors.
type CircuitBreakerConfig struct {
	// Enabled switches the circuit breaker on or off. A nil value (key not
	// set) is treated as true by CircuitBreakerEnabled.
	Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`

	// Hard403CooldownSeconds is the cooldown after CONSUMER_INVALID or
	// SERVICE_DISABLED errors. Non-positive values fall back to 600
	// (10 minutes) in GetHard403CooldownSeconds.
	Hard403CooldownSeconds int `yaml:"hard-403-cooldown-seconds" json:"hard-403-cooldown-seconds"`

	// Soft403CooldownSeconds is the cooldown after other 403 errors.
	// Default: 1800 (30 minutes).
	Soft403CooldownSeconds int `yaml:"soft-403-cooldown-seconds" json:"soft-403-cooldown-seconds"`

	// Hard403Retry is the maximum number of retries for hard 403 errors.
	// Default: 0 (no retries).
	Hard403Retry int `yaml:"hard-403-retry" json:"hard-403-retry"`
}

// CircuitBreakerEnabled reports whether the circuit breaker is on; an unset
// Enabled field counts as enabled.
func (c *CircuitBreakerConfig) CircuitBreakerEnabled() bool {
	return c.Enabled == nil || *c.Enabled
}

// GetHard403CooldownSeconds returns the configured hard 403 cooldown, or the
// 10-minute default (600) when the configured value is zero or negative.
func (c *CircuitBreakerConfig) GetHard403CooldownSeconds() int {
	if secs := c.Hard403CooldownSeconds; secs > 0 {
		return secs
	}
	return 600
}

// GetSoft403CooldownSeconds returns the soft 403 cooldown with default fallback.
+func (c *CircuitBreakerConfig) GetSoft403CooldownSeconds() int { + if c.Soft403CooldownSeconds <= 0 { + return 1800 // 30 minutes default + } + return c.Soft403CooldownSeconds +} + // ClaudeKey represents the configuration for a Claude API key, // including the API key itself and an optional base URL for the API endpoint. type ClaudeKey struct { diff --git a/sdk/cliproxy/auth/circuit_breaker.go b/sdk/cliproxy/auth/circuit_breaker.go new file mode 100644 index 000000000..1b5a91b9a --- /dev/null +++ b/sdk/cliproxy/auth/circuit_breaker.go @@ -0,0 +1,226 @@ +package auth + +import ( + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +// Default circuit breaker cooldown durations. +const ( + DefaultHard403CooldownSeconds = 600 // 10 minutes + DefaultSoft403CooldownSeconds = 1800 // 30 minutes +) + +// circuitBreakerConfig holds runtime circuit breaker configuration. +type circuitBreakerConfig struct { + mu sync.RWMutex + enabled bool + hard403CooldownSeconds int + soft403CooldownSeconds int + hard403Retry int +} + +var globalCircuitBreakerConfig = &circuitBreakerConfig{ + enabled: true, + hard403CooldownSeconds: DefaultHard403CooldownSeconds, + soft403CooldownSeconds: DefaultSoft403CooldownSeconds, + hard403Retry: 0, +} + +// SetCircuitBreakerConfig updates the global circuit breaker configuration. 
+func SetCircuitBreakerConfig(enabled bool, hard403CooldownSecs, soft403CooldownSecs, hard403Retry int) { + if hard403CooldownSecs <= 0 { + hard403CooldownSecs = DefaultHard403CooldownSeconds + } + if soft403CooldownSecs <= 0 { + soft403CooldownSecs = DefaultSoft403CooldownSeconds + } + if hard403Retry < 0 { + hard403Retry = 0 + } + globalCircuitBreakerConfig.mu.Lock() + globalCircuitBreakerConfig.enabled = enabled + globalCircuitBreakerConfig.hard403CooldownSeconds = hard403CooldownSecs + globalCircuitBreakerConfig.soft403CooldownSeconds = soft403CooldownSecs + globalCircuitBreakerConfig.hard403Retry = hard403Retry + globalCircuitBreakerConfig.mu.Unlock() +} + +// ClassifyHard403 analyzes an error to determine if it's a hard 403 that should trigger circuit breaker. +// It parses the error message/code looking for CONSUMER_INVALID, SERVICE_DISABLED, or PERMISSION_DENIED. +func ClassifyHard403(err *Error) Hard403Type { + if err == nil { + return Hard403None + } + + if err.HTTPStatus != 403 { + return Hard403None + } + + msg := strings.ToUpper(err.Message) + code := strings.ToUpper(err.Code) + + if strings.Contains(msg, "CONSUMER_INVALID") || strings.Contains(code, "CONSUMER_INVALID") { + return Hard403ConsumerInvalid + } + + if strings.Contains(msg, "SERVICE_DISABLED") || strings.Contains(code, "SERVICE_DISABLED") || + strings.Contains(msg, "HAS NOT BEEN USED IN PROJECT") || + strings.Contains(msg, "IT IS DISABLED") { + return Hard403ServiceDisabled + } + + if strings.Contains(msg, "PERMISSION_DENIED") || strings.Contains(code, "PERMISSION_DENIED") || + strings.Contains(msg, "PERMISSION DENIED ON RESOURCE PROJECT") { + return Hard403PermissionDenied + } + + return Hard403None +} + +// IsHard403 returns true if the error is classified as a hard 403. +func IsHard403(err *Error) bool { + return ClassifyHard403(err) != Hard403None +} + +// OpenCircuitBreaker opens the circuit breaker for an auth credential. 
+func OpenCircuitBreaker(auth *Auth, reason Hard403Type, now time.Time) { + if auth == nil || reason == Hard403None { + return + } + + globalCircuitBreakerConfig.mu.RLock() + enabled := globalCircuitBreakerConfig.enabled + cooldownSecs := globalCircuitBreakerConfig.hard403CooldownSeconds + globalCircuitBreakerConfig.mu.RUnlock() + + if !enabled { + return + } + + cooldownDuration := time.Duration(cooldownSecs) * time.Second + + auth.CircuitBreaker.Open = true + auth.CircuitBreaker.Reason = reason + auth.CircuitBreaker.CooldownUntil = now.Add(cooldownDuration) + auth.CircuitBreaker.OpenedAt = now + auth.CircuitBreaker.FailureCount++ + + log.WithFields(log.Fields{ + "auth_id": auth.ID, + "provider": auth.Provider, + "reason": string(reason), + "cooldown_until": auth.CircuitBreaker.CooldownUntil.Format(time.RFC3339), + "failure_count": auth.CircuitBreaker.FailureCount, + }).Warn("circuit breaker opened for hard 403") +} + +// CloseCircuitBreaker closes the circuit breaker for an auth credential. +func CloseCircuitBreaker(auth *Auth) { + if auth == nil { + return + } + + if auth.CircuitBreaker.Open { + log.WithFields(log.Fields{ + "auth_id": auth.ID, + "provider": auth.Provider, + "reason": string(auth.CircuitBreaker.Reason), + }).Info("circuit breaker closed") + } + + auth.CircuitBreaker.Open = false + auth.CircuitBreaker.Reason = Hard403None + auth.CircuitBreaker.CooldownUntil = time.Time{} + auth.CircuitBreaker.FailureCount = 0 +} + +// IsCircuitBreakerOpen returns true if the circuit breaker is open and cooldown has not expired. +// Note: This function only reads auth state and does NOT auto-close expired circuit breakers. +// Use CheckAndCloseExpiredCircuitBreaker for auto-close behavior when holding a write lock. 
+func IsCircuitBreakerOpen(auth *Auth, now time.Time) bool { + if auth == nil { + return false + } + + globalCircuitBreakerConfig.mu.RLock() + enabled := globalCircuitBreakerConfig.enabled + globalCircuitBreakerConfig.mu.RUnlock() + + if !enabled { + return false + } + + if !auth.CircuitBreaker.Open { + return false + } + + if now.After(auth.CircuitBreaker.CooldownUntil) { + return false + } + + return true +} + +// CheckAndCloseExpiredCircuitBreaker checks if the circuit breaker cooldown expired and closes it. +// Returns true if the circuit breaker is still open, false if closed or expired. +// Call this only when holding a write lock on the auth. +func CheckAndCloseExpiredCircuitBreaker(auth *Auth, now time.Time) bool { + if auth == nil { + return false + } + + globalCircuitBreakerConfig.mu.RLock() + enabled := globalCircuitBreakerConfig.enabled + globalCircuitBreakerConfig.mu.RUnlock() + + if !enabled { + return false + } + + if !auth.CircuitBreaker.Open { + return false + } + + if now.After(auth.CircuitBreaker.CooldownUntil) { + CloseCircuitBreaker(auth) + return false + } + + return true +} + +// ShouldRetryHard403 returns true if hard 403 retries are allowed. +func ShouldRetryHard403() bool { + globalCircuitBreakerConfig.mu.RLock() + retry := globalCircuitBreakerConfig.hard403Retry + globalCircuitBreakerConfig.mu.RUnlock() + return retry > 0 +} + +// GetHard403MaxRetries returns the maximum number of retries for hard 403 errors. +func GetHard403MaxRetries() int { + globalCircuitBreakerConfig.mu.RLock() + retry := globalCircuitBreakerConfig.hard403Retry + globalCircuitBreakerConfig.mu.RUnlock() + return retry +} + +// GetSoft403Cooldown returns the cooldown duration for soft 403 errors. 
+func GetSoft403Cooldown() time.Duration {
+	// Copy the configured seconds under the read lock; convert after releasing it.
+	globalCircuitBreakerConfig.mu.RLock()
+	secs := globalCircuitBreakerConfig.soft403CooldownSeconds
+	globalCircuitBreakerConfig.mu.RUnlock()
+	return time.Duration(secs) * time.Second
+}
+
+// GetHard403Cooldown returns the cooldown duration for hard 403 errors.
+// The value is read from the package-level circuit breaker configuration as a
+// number of seconds and converted to a time.Duration.
+func GetHard403Cooldown() time.Duration {
+	// Copy the configured seconds under the read lock; convert after releasing it.
+	globalCircuitBreakerConfig.mu.RLock()
+	secs := globalCircuitBreakerConfig.hard403CooldownSeconds
+	globalCircuitBreakerConfig.mu.RUnlock()
+	return time.Duration(secs) * time.Second
+}
diff --git a/sdk/cliproxy/auth/circuit_breaker_test.go b/sdk/cliproxy/auth/circuit_breaker_test.go
new file mode 100644
index 000000000..b967a5870
--- /dev/null
+++ b/sdk/cliproxy/auth/circuit_breaker_test.go
@@ -0,0 +1,300 @@
+package auth
+
+import (
+	"testing"
+	"time"
+)
+
+// NOTE(review): several tests below call SetCircuitBreakerConfig, which
+// presumably mutates package-level state shared by every test in this package;
+// they therefore look unsafe to run with t.Parallel — confirm before adding it.
+// Judging by the assertions in TestGetCooldownDurations and
+// TestGetHard403MaxRetries, the arguments are (enabled, hard-403 cooldown
+// seconds, soft-403 cooldown seconds, hard-403 retries).
+
+// TestClassifyHard403 is a table-driven check of ClassifyHard403. A nil error,
+// a non-403 status and a generic 403 must classify as Hard403None, while the
+// CONSUMER_INVALID, SERVICE_DISABLED and PERMISSION_DENIED markers (in Message
+// or Code, including the free-text phrasings listed) must map to their types.
+func TestClassifyHard403(t *testing.T) {
+	tests := []struct {
+		name     string
+		err      *Error
+		expected Hard403Type
+	}{
+		{
+			name:     "nil error",
+			err:      nil,
+			expected: Hard403None,
+		},
+		{
+			name:     "non-403 error",
+			err:      &Error{HTTPStatus: 401, Message: "unauthorized"},
+			expected: Hard403None,
+		},
+		{
+			name:     "generic 403",
+			err:      &Error{HTTPStatus: 403, Message: "access denied"},
+			expected: Hard403None,
+		},
+		{
+			name:     "CONSUMER_INVALID in message",
+			err:      &Error{HTTPStatus: 403, Message: "CONSUMER_INVALID: project not valid"},
+			expected: Hard403ConsumerInvalid,
+		},
+		{
+			name:     "CONSUMER_INVALID in code",
+			err:      &Error{HTTPStatus: 403, Code: "CONSUMER_INVALID"},
+			expected: Hard403ConsumerInvalid,
+		},
+		{
+			name:     "SERVICE_DISABLED in message",
+			err:      &Error{HTTPStatus: 403, Message: "SERVICE_DISABLED: API not enabled"},
+			expected: Hard403ServiceDisabled,
+		},
+		{
+			name:     "service not used in project",
+			err:      &Error{HTTPStatus: 403, Message: "Gemini API has not been used in project before or it is disabled"},
+			expected: Hard403ServiceDisabled,
+		},
+		{
+			name:     "PERMISSION_DENIED in message",
+			err:      &Error{HTTPStatus: 403, Message: "Permission denied on resource project abc-123"},
+			expected: Hard403PermissionDenied,
+		},
+		{
+			name:     "PERMISSION_DENIED in code",
+			err:      &Error{HTTPStatus: 403, Code: "PERMISSION_DENIED"},
+			expected: Hard403PermissionDenied,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := ClassifyHard403(tt.err)
+			if result != tt.expected {
+				t.Errorf("expected %v, got %v", tt.expected, result)
+			}
+		})
+	}
+}
+
+// TestIsHard403 checks that IsHard403 is true for a 403 carrying
+// CONSUMER_INVALID and false for a 403 with an unrecognised message.
+func TestIsHard403(t *testing.T) {
+	t.Run("returns true for hard 403", func(t *testing.T) {
+		err := &Error{HTTPStatus: 403, Message: "CONSUMER_INVALID"}
+		if !IsHard403(err) {
+			t.Error("expected true for CONSUMER_INVALID")
+		}
+	})
+
+	t.Run("returns false for soft 403", func(t *testing.T) {
+		err := &Error{HTTPStatus: 403, Message: "access denied"}
+		if IsHard403(err) {
+			t.Error("expected false for generic 403")
+		}
+	})
+}
+
+// TestOpenCircuitBreaker checks OpenCircuitBreaker: it must set Open, Reason
+// and FailureCount and place CooldownUntil 600 seconds after the supplied
+// time, increment FailureCount on a repeated call, and leave the breaker
+// closed for Hard403None or when circuit breaking is disabled.
+func TestOpenCircuitBreaker(t *testing.T) {
+	SetCircuitBreakerConfig(true, 600, 1800, 0)
+
+	t.Run("opens circuit breaker", func(t *testing.T) {
+		auth := &Auth{ID: "test-auth", Provider: "gemini"}
+		now := time.Now()
+
+		OpenCircuitBreaker(auth, Hard403ConsumerInvalid, now)
+
+		if !auth.CircuitBreaker.Open {
+			t.Error("circuit breaker should be open")
+		}
+		if auth.CircuitBreaker.Reason != Hard403ConsumerInvalid {
+			t.Errorf("expected CONSUMER_INVALID, got %v", auth.CircuitBreaker.Reason)
+		}
+		if auth.CircuitBreaker.FailureCount != 1 {
+			t.Errorf("expected failure count 1, got %d", auth.CircuitBreaker.FailureCount)
+		}
+		expectedCooldown := now.Add(600 * time.Second)
+		if !auth.CircuitBreaker.CooldownUntil.Equal(expectedCooldown) {
+			t.Errorf("expected cooldown until %v, got %v", expectedCooldown, auth.CircuitBreaker.CooldownUntil)
+		}
+	})
+
+	t.Run("increments failure count", func(t *testing.T) {
+		auth := &Auth{ID: "test-auth-2", Provider: "gemini"}
+		now := time.Now()
+
+		OpenCircuitBreaker(auth, Hard403ConsumerInvalid, now)
+		OpenCircuitBreaker(auth, Hard403ConsumerInvalid, now.Add(time.Minute))
+
+		if auth.CircuitBreaker.FailureCount != 2 {
+			t.Errorf("expected failure count 2, got %d", auth.CircuitBreaker.FailureCount)
+		}
+	})
+
+	t.Run("does nothing for Hard403None", func(t *testing.T) {
+		auth := &Auth{ID: "test-auth-3", Provider: "gemini"}
+		now := time.Now()
+
+		OpenCircuitBreaker(auth, Hard403None, now)
+
+		if auth.CircuitBreaker.Open {
+			t.Error("circuit breaker should not open for Hard403None")
+		}
+	})
+
+	t.Run("does nothing when disabled", func(t *testing.T) {
+		// Disable for this subtest only; the deferred call restores the enabled config.
+		SetCircuitBreakerConfig(false, 600, 1800, 0)
+		defer SetCircuitBreakerConfig(true, 600, 1800, 0)
+
+		auth := &Auth{ID: "test-auth-4", Provider: "gemini"}
+		now := time.Now()
+
+		OpenCircuitBreaker(auth, Hard403ConsumerInvalid, now)
+
+		if auth.CircuitBreaker.Open {
+			t.Error("circuit breaker should not open when disabled")
+		}
+	})
+}
+
+// TestCloseCircuitBreaker checks that CloseCircuitBreaker resets Open, Reason
+// and FailureCount on an auth whose breaker is currently open.
+func TestCloseCircuitBreaker(t *testing.T) {
+	t.Run("closes circuit breaker", func(t *testing.T) {
+		auth := &Auth{
+			ID:       "test-auth",
+			Provider: "gemini",
+			CircuitBreaker: CircuitBreakerState{
+				Open:          true,
+				Reason:        Hard403ConsumerInvalid,
+				CooldownUntil: time.Now().Add(time.Hour),
+				FailureCount:  3,
+			},
+		}
+
+		CloseCircuitBreaker(auth)
+
+		if auth.CircuitBreaker.Open {
+			t.Error("circuit breaker should be closed")
+		}
+		if auth.CircuitBreaker.Reason != Hard403None {
+			t.Errorf("expected Hard403None, got %v", auth.CircuitBreaker.Reason)
+		}
+		if auth.CircuitBreaker.FailureCount != 0 {
+			t.Errorf("expected failure count 0, got %d", auth.CircuitBreaker.FailureCount)
+		}
+	})
+}
+
+// TestIsCircuitBreakerOpen covers the read-only check IsCircuitBreakerOpen
+// (open and unexpired, expired, not open, feature disabled) and the mutating
+// variant CheckAndCloseExpiredCircuitBreaker, which must close an expired breaker.
+func TestIsCircuitBreakerOpen(t *testing.T) {
+	SetCircuitBreakerConfig(true, 600, 1800, 0)
+
+	t.Run("returns true when open and not expired", func(t *testing.T) {
+		auth := &Auth{
+			ID:       "test-auth",
+			Provider: "gemini",
+			CircuitBreaker: CircuitBreakerState{
+				Open:          true,
+				CooldownUntil: time.Now().Add(time.Hour),
+			},
+		}
+
+		if !IsCircuitBreakerOpen(auth, time.Now()) {
+			t.Error("expected circuit breaker to be open")
+		}
+	})
+
+	t.Run("returns false when expired", func(t *testing.T) {
+		auth := &Auth{
+			ID:       "test-auth-2",
+			Provider: "gemini",
+			CircuitBreaker: CircuitBreakerState{
+				Open:          true,
+				CooldownUntil: time.Now().Add(-time.Minute),
+			},
+		}
+
+		if IsCircuitBreakerOpen(auth, time.Now()) {
+			t.Error("expected circuit breaker to report as not open after expiry")
+		}
+		// IsCircuitBreakerOpen no longer auto-closes; it just returns false for expired breakers
+	})
+
+	t.Run("CheckAndCloseExpiredCircuitBreaker closes expired breaker", func(t *testing.T) {
+		auth := &Auth{
+			ID:       "test-auth-close",
+			Provider: "gemini",
+			CircuitBreaker: CircuitBreakerState{
+				Open:          true,
+				Reason:        Hard403ConsumerInvalid,
+				CooldownUntil: time.Now().Add(-time.Minute),
+				FailureCount:  1,
+			},
+		}
+
+		if CheckAndCloseExpiredCircuitBreaker(auth, time.Now()) {
+			t.Error("expected CheckAndCloseExpiredCircuitBreaker to return false for expired breaker")
+		}
+		if auth.CircuitBreaker.Open {
+			t.Error("circuit breaker should have been closed")
+		}
+	})
+
+	t.Run("returns false when not open", func(t *testing.T) {
+		auth := &Auth{
+			ID:       "test-auth-3",
+			Provider: "gemini",
+			CircuitBreaker: CircuitBreakerState{
+				Open: false,
+			},
+		}
+
+		if IsCircuitBreakerOpen(auth, time.Now()) {
+			t.Error("expected circuit breaker to be closed")
+		}
+	})
+
+	t.Run("returns false when disabled", func(t *testing.T) {
+		// Disable for this subtest only; the deferred call restores the enabled config.
+		SetCircuitBreakerConfig(false, 600, 1800, 0)
+		defer SetCircuitBreakerConfig(true, 600, 1800, 0)
+
+		auth := &Auth{
+			ID:       "test-auth-4",
+			Provider: "gemini",
+			CircuitBreaker: CircuitBreakerState{
+				Open:          true,
+				CooldownUntil: time.Now().Add(time.Hour),
+			},
+		}
+
+		if IsCircuitBreakerOpen(auth, time.Now()) {
+			t.Error("expected circuit breaker check to return false when disabled")
+		}
+	})
+}
+
+// TestShouldRetryHard403 checks that ShouldRetryHard403 is false when the
+// configured hard-403 retry count is 0 and true when it is 1.
+func TestShouldRetryHard403(t *testing.T) {
+	t.Run("returns false when hard403Retry is 0", func(t *testing.T) {
+		SetCircuitBreakerConfig(true, 600, 1800, 0)
+		if ShouldRetryHard403() {
+			t.Error("expected false when hard403Retry is 0")
+		}
+	})
+
+	t.Run("returns true when hard403Retry > 0", func(t *testing.T) {
+		SetCircuitBreakerConfig(true, 600, 1800, 1)
+		defer SetCircuitBreakerConfig(true, 600, 1800, 0)
+
+		if !ShouldRetryHard403() {
+			t.Error("expected true when hard403Retry > 0")
+		}
+	})
+}
+
+// TestGetHard403MaxRetries checks that GetHard403MaxRetries returns the retry
+// count most recently passed to SetCircuitBreakerConfig.
+func TestGetHard403MaxRetries(t *testing.T) {
+	SetCircuitBreakerConfig(true, 600, 1800, 3)
+	defer SetCircuitBreakerConfig(true, 600, 1800, 0)
+
+	if GetHard403MaxRetries() != 3 {
+		t.Errorf("expected 3, got %d", GetHard403MaxRetries())
+	}
+}
+
+// TestGetCooldownDurations checks that the hard and soft cooldown getters
+// return the configured second counts (300 and 900) as time.Duration values.
+func TestGetCooldownDurations(t *testing.T) {
+	SetCircuitBreakerConfig(true, 300, 900, 0)
+	defer SetCircuitBreakerConfig(true, 600, 1800, 0)
+
+	if GetHard403Cooldown() != 300*time.Second {
+		t.Errorf("expected 300s, got %v", GetHard403Cooldown())
+	}
+	if GetSoft403Cooldown() != 900*time.Second {
+		t.Errorf("expected 900s, got %v", GetSoft403Cooldown())
+	}
+}
diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go
index dc7887e73..eaadc886a 100644
--- a/sdk/cliproxy/auth/manager.go
+++ b/sdk/cliproxy/auth/manager.go
@@ -632,9 +632,24 @@ func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, pro
 	if maxWait <= 0 {
 		return 0, false
 	}
-	if status := statusCodeFromError(err); status == http.StatusOK {
+	status := statusCodeFromError(err)
+	if status == http.StatusOK {
 		return 0, false
 	}
+
+	// Check if this is a hard 403 error - don't retry unless configured
+	if status == http.StatusForbidden {
+		if authErr := errorToAuthError(err); authErr != nil && IsHard403(authErr) {
+			if !ShouldRetryHard403() {
+				return 0, false
+			}
+			// Limit retries for hard 403s
+			if attempt >= GetHard403MaxRetries() {
+				return 0, false
+			}
+		}
+	}
+
 	wait, found := m.closestCooldownWait(providers, model)
 	if !found || wait > maxWait {
 		return 0, false
@@ -642,6 +657,30 @@ func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, pro
 	return wait, true
 }
 
+// errorToAuthError attempts to extract an *Error from a generic error.
+func errorToAuthError(err error) *Error {
+	if err == nil {
+		return nil
+	}
+	// Plain type assertion: this matches only when err itself is an *Error; an
+	// *Error wrapped inside another error is not unwrapped here.
+	if authErr, ok := err.(*Error); ok {
+		return authErr
+	}
+	// Check if the error has StatusCode method
+	type statusCoder interface {
+		StatusCode() int
+	}
+	// NOTE(review): messager is declared but never referenced in this function.
+	type messager interface {
+		Error() string
+	}
+	// Synthesize a new *Error carrying only the status code and err.Error() text;
+	// Code is left empty.
+	if sc, ok := err.(statusCoder); ok {
+		return &Error{
+			HTTPStatus: sc.StatusCode(),
+			Message:    err.Error(),
+		}
+	}
+	// Neither an *Error nor a status-carrying error: nothing to classify.
+	return nil
+}
+
 func waitForCooldown(ctx context.Context, wait time.Duration) error {
 	if wait <= 0 {
 		return nil
@@ -721,6 +760,8 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
 				auth.UpdatedAt = now
 				shouldResumeModel = true
 				clearModelQuota = true
+				// Close circuit breaker on successful model request
+				CloseCircuitBreaker(auth)
 			} else {
 				clearAuthStateOnSuccess(auth, now)
 			}
@@ -738,54 +779,76 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
 				}
 
 				statusCode := statusCodeFromResult(result.Error)
-				switch statusCode {
-				case 401:
-					next := now.Add(30 * time.Minute)
-					state.NextRetryAfter = next
-					suspendReason = "unauthorized"
-					shouldSuspendModel = true
-				case 402, 403:
-					next := now.Add(30 * time.Minute)
-					state.NextRetryAfter = next
-					suspendReason = "payment_required"
-					shouldSuspendModel = true
-				case 404:
-					next := now.Add(12 * time.Hour)
-					state.NextRetryAfter = next
-					suspendReason = "not_found"
-					shouldSuspendModel = true
-				case 429:
-					var next time.Time
-					backoffLevel := state.Quota.BackoffLevel
-					if result.RetryAfter != nil {
-						next = now.Add(*result.RetryAfter)
-					} else {
-						cooldown, nextLevel := nextQuotaCooldown(backoffLevel)
-						if cooldown > 0 {
-							next = now.Add(cooldown)
-						}
-						backoffLevel = nextLevel
-					}
-					state.NextRetryAfter = next
-					state.Quota = QuotaState{
-						Exceeded:      true,
-						Reason:        "quota",
-						NextRecoverAt: next,
-						BackoffLevel:  backoffLevel,
+				handledHard403 := false
+
+				// Check for hard 403 errors and open circuit breaker at auth level
+				if statusCode == 403 {
+					hard403Type := ClassifyHard403(result.Error)
+					if hard403Type != Hard403None {
+						OpenCircuitBreaker(auth, hard403Type, now)
+						auth.StatusMessage = string(hard403Type)
+						auth.NextRetryAfter = auth.CircuitBreaker.CooldownUntil
+						state.NextRetryAfter = auth.CircuitBreaker.CooldownUntil
+						suspendReason = string(hard403Type)
+						shouldSuspendModel = true
+						auth.Status = StatusError
+						auth.UpdatedAt = now
+						updateAggregatedAvailability(auth, now)
+						handledHard403 = true
 					}
-					suspendReason = "quota"
-					shouldSuspendModel = true
-					setModelQuota = true
-				case 408, 500, 502, 503, 504:
-					next := now.Add(1 * time.Minute)
-					state.NextRetryAfter = next
-				default:
-					state.NextRetryAfter = time.Time{}
 				}
 
-				auth.Status = StatusError
-				auth.UpdatedAt = now
-				updateAggregatedAvailability(auth, now)
+				if !handledHard403 {
+					switch statusCode {
+					case 401:
+						next := now.Add(30 * time.Minute)
+						state.NextRetryAfter = next
+						suspendReason = "unauthorized"
+						shouldSuspendModel = true
+					case 402, 403:
+						// Soft 403 (not classified as hard 403 above)
+						next := now.Add(GetSoft403Cooldown())
+						state.NextRetryAfter = next
+						suspendReason = "payment_required"
+						shouldSuspendModel = true
+					case 404:
+						next := now.Add(12 * time.Hour)
+						state.NextRetryAfter = next
+						suspendReason = "not_found"
+						shouldSuspendModel = true
+					case 429:
+						var next time.Time
+						backoffLevel := state.Quota.BackoffLevel
+						if result.RetryAfter != nil {
+							next = now.Add(*result.RetryAfter)
+						} else {
+							cooldown, nextLevel := nextQuotaCooldown(backoffLevel)
+							if cooldown > 0 {
+								next = now.Add(cooldown)
+							}
+							backoffLevel = nextLevel
+						}
+						state.NextRetryAfter = next
+						state.Quota = QuotaState{
+							Exceeded:      true,
+							Reason:        "quota",
+							NextRecoverAt: next,
+							BackoffLevel:  backoffLevel,
+						}
+						suspendReason = "quota"
+						shouldSuspendModel = true
+						setModelQuota = true
+					case 408, 500, 502, 503, 504:
+						next := now.Add(1 * time.Minute)
+						state.NextRetryAfter = next
+					default:
+						state.NextRetryAfter = time.Time{}
+					}
+
+					auth.Status = StatusError
+					auth.UpdatedAt = now
+					updateAggregatedAvailability(auth, now)
+				}
 			} else {
 				applyAuthFailureState(auth, result.Error, result.RetryAfter, now)
 			}
@@ -933,6 +996,7 @@ func clearAuthStateOnSuccess(auth *Auth, now time.Time) {
 	auth.LastError = nil
 	auth.NextRetryAfter = time.Time{}
 	auth.UpdatedAt = now
+	CloseCircuitBreaker(auth)
 }
 
 func cloneError(err *Error) *Error {
@@ -1001,13 +1065,25 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
 		}
 	}
 	statusCode := statusCodeFromResult(resultErr)
+
+	// Check for hard 403 errors and open circuit breaker
+	if statusCode == 403 {
+		hard403Type := ClassifyHard403(resultErr)
+		if hard403Type != Hard403None {
+			OpenCircuitBreaker(auth, hard403Type, now)
+			auth.StatusMessage = string(hard403Type)
+			auth.NextRetryAfter = auth.CircuitBreaker.CooldownUntil
+			return
+		}
+	}
+
 	switch statusCode {
 	case 401:
 		auth.StatusMessage = "unauthorized"
 		auth.NextRetryAfter = now.Add(30 * time.Minute)
 	case 402, 403:
 		auth.StatusMessage = "payment_required"
-		auth.NextRetryAfter = now.Add(30 * time.Minute)
+		auth.NextRetryAfter = now.Add(GetSoft403Cooldown())
 	case 404:
 		auth.StatusMessage = "not_found"
 		auth.NextRetryAfter = now.Add(12 * time.Hour)
@@ -1091,6 +1167,8 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
 	candidates := make([]*Auth, 0, len(m.auths))
 	modelKey := strings.TrimSpace(model)
 	registryRef := registry.GetGlobalRegistry()
+	now := time.Now()
+	skippedCircuitBreaker := 0
 	for _, candidate := range m.auths {
 		if candidate.Provider != provider || candidate.Disabled {
 			continue
@@ -1098,6 +1176,11 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
 		if _, used := tried[candidate.ID]; used {
 			continue
 		}
+		// Skip auths with open circuit breakers
+		if IsCircuitBreakerOpen(candidate, now) {
+			skippedCircuitBreaker++
+			continue
+		}
 		if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
 			continue
 		}
@@ -1105,6 +1188,9 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
 	}
 	if len(candidates) == 0 {
 		m.mu.RUnlock()
+		if skippedCircuitBreaker > 0 {
+			return nil, nil, &Error{Code: "auth_unavailable", Message: "all auths have open circuit breakers", HTTPStatus: 503}
+		}
 		return nil, nil, &Error{Code: "auth_not_found", Message: "no auth available"}
 	}
 	selected, errPick := m.selector.Pick(ctx, provider, model, opts, candidates)
diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go
index 25e88b96e..2a57ed1d8 100644
--- a/sdk/cliproxy/auth/types.go
+++ b/sdk/cliproxy/auth/types.go
@@ -56,6 +56,9 @@ type Auth struct {
 	// ModelStates tracks per-model runtime availability data.
 	ModelStates map[string]*ModelState `json:"model_states,omitempty"`
 
+	// CircuitBreaker tracks circuit breaker state for persistent 403 errors.
+	// NOTE(review): encoding/json's omitempty never omits a struct-valued field,
+	// so this key is presumably always emitted — confirm that is intended.
+	CircuitBreaker CircuitBreakerState `json:"circuit_breaker,omitempty"`
+
 	// Runtime carries non-serialisable data used during execution (in-memory only).
 	Runtime any `json:"-"`
 
@@ -74,6 +77,34 @@ type QuotaState struct {
 	BackoffLevel int `json:"backoff_level,omitempty"`
 }
 
+// Hard403Type classifies persistent 403 errors for circuit breaker logic.
+// Its string value is also what MarkResult and applyAuthFailureState copy into
+// Auth.StatusMessage when a hard 403 is detected.
+type Hard403Type string
+
+const (
+	// Hard403None indicates the error is not a hard 403.
+	Hard403None Hard403Type = ""
+	// Hard403ConsumerInvalid indicates CONSUMER_INVALID error (bad key/project).
+	Hard403ConsumerInvalid Hard403Type = "CONSUMER_INVALID"
+	// Hard403ServiceDisabled indicates SERVICE_DISABLED error (API not enabled).
+	Hard403ServiceDisabled Hard403Type = "SERVICE_DISABLED"
+	// Hard403PermissionDenied indicates PERMISSION_DENIED error.
+	Hard403PermissionDenied Hard403Type = "PERMISSION_DENIED"
+)
+
+// CircuitBreakerState tracks circuit breaker status for an auth credential.
+// It is stored by value in Auth.CircuitBreaker and serialised with the auth.
+type CircuitBreakerState struct {
+	// Open indicates the circuit breaker is open (blocking requests).
+	Open bool `json:"open"`
+	// Reason describes why the circuit breaker opened.
+	Reason Hard403Type `json:"reason,omitempty"`
+	// CooldownUntil is when the circuit breaker may close.
+	CooldownUntil time.Time `json:"cooldown_until"`
+	// OpenedAt records when the circuit breaker opened.
+	OpenedAt time.Time `json:"opened_at"`
+	// FailureCount tracks consecutive hard 403 failures.
+	FailureCount int `json:"failure_count"`
+}
+
 // ModelState captures the execution state for a specific model under an auth entry.
 type ModelState struct {
 	// Status reflects the lifecycle status for this model.