diff --git a/sdk/go/ai/README.md b/sdk/go/ai/README.md index dc80d364..5f8d8bcb 100644 --- a/sdk/go/ai/README.md +++ b/sdk/go/ai/README.md @@ -10,6 +10,7 @@ This package provides AI/LLM capabilities for the AgentField Go SDK, supporting - ✅ **Type-Safe**: Automatic conversion from Go structs to JSON schemas - ✅ **Functional Options**: Clean, idiomatic Go API with functional options pattern - ✅ **Automatic Configuration**: Reads from environment variables by default +- ✅ **Rate Limiting**: Built-in exponential backoff and circuit breaker for production resilience (see Rate Limiting section below) ## Quick Start @@ -233,12 +234,83 @@ if err := response.Into(&result); err != nil { } ``` +## Rate Limiting + +The Go SDK includes built-in rate limiting with exponential backoff and circuit breaker patterns for production resilience. + +### Configuration + +Rate limiting is **enabled by default** with sensible defaults: + +```go +config := ai.DefaultConfig() +// Uses default rate limiting: +// - MaxRetries: 5 +// - BaseDelay: 1 second +// - MaxDelay: 30 seconds +// - JitterFactor: 0.1 +// - CircuitBreakerThreshold: 5 consecutive failures +// - CircuitBreakerTimeout: 60 seconds +``` + +### Custom Configuration + +```go +config := &ai.Config{ + APIKey: os.Getenv("OPENAI_API_KEY"), + Model: "gpt-4o", + + // Custom rate limiting + RateLimitMaxRetries: 10, + RateLimitBaseDelay: 500 * time.Millisecond, + RateLimitMaxDelay: 60 * time.Second, + RateLimitJitterFactor: 0.2, + CircuitBreakerThreshold: 3, + CircuitBreakerTimeout: 30 * time.Second, +} +``` + +### Disable Rate Limiting + +```go +config := ai.DefaultConfig() +config.DisableRateLimiter = true // Disable rate limiting completely +``` + +### How It Works + +**Exponential Backoff**: Delays increase exponentially (1s → 2s → 4s → 8s...) 
+**Jitter**: Adds randomness to prevent thundering herd +**Circuit Breaker**: Opens after N consecutive failures, prevents cascade +**Automatic Detection**: Identifies rate limit errors from status codes and error messages + +### Thread Safety + +The AI client and rate limiter are safe for concurrent use by multiple goroutines: + +```go +agent, _ := agent.New(config) + +// Safe to call from multiple goroutines +var wg sync.WaitGroup +for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + response, err := agent.AI(ctx, "Hello") + // Process response + }() +} +wg.Wait() +``` + ## Performance Considerations 1. **Connection Pooling**: The HTTP client uses connection pooling for efficient requests 2. **Context Cancellation**: Always use contexts with timeouts for AI calls 3. **Streaming**: Use streaming for long responses to improve perceived latency 4. **Model Selection**: Choose appropriate models for your use case (faster models = lower latency) +5. **Rate Limiting**: Built-in rate limiting handles API throttling automatically ## Examples diff --git a/sdk/go/ai/client.go b/sdk/go/ai/client.go index 076dfad3..8a668e5f 100644 --- a/sdk/go/ai/client.go +++ b/sdk/go/ai/client.go @@ -12,8 +12,9 @@ import ( // Client provides AI/LLM capabilities using OpenAI or OpenRouter API. type Client struct { - config *Config - httpClient *http.Client + config *Config + httpClient *http.Client + rateLimiter *RateLimiter } // NewClient creates a new AI client with the given configuration. 
@@ -26,8 +27,22 @@ func NewClient(config *Config) (*Client, error) { return nil, fmt.Errorf("invalid config: %w", err) } + // Initialize rate limiter if not disabled + var rateLimiter *RateLimiter + if !config.DisableRateLimiter { + rateLimiter = NewRateLimiter(RateLimiterConfig{ + MaxRetries: config.RateLimitMaxRetries, + BaseDelay: config.RateLimitBaseDelay, + MaxDelay: config.RateLimitMaxDelay, + JitterFactor: config.RateLimitJitterFactor, + CircuitBreakerThreshold: config.CircuitBreakerThreshold, + CircuitBreakerTimeout: config.CircuitBreakerTimeout, + }) + } + return &Client{ - config: config, + config: config, + rateLimiter: rateLimiter, httpClient: &http.Client{ Timeout: config.Timeout, }, @@ -53,7 +68,13 @@ func (c *Client) Complete(ctx context.Context, prompt string, opts ...Option) (* } } - // Make HTTP request + // Make HTTP request with rate limiting + if c.rateLimiter != nil { + return c.rateLimiter.ExecuteWithRetry(ctx, func() (*Response, error) { + return c.doRequest(ctx, req) + }) + } + return c.doRequest(ctx, req) } @@ -73,6 +94,13 @@ func (c *Client) CompleteWithMessages(ctx context.Context, messages []Message, o } } + // Make HTTP request with rate limiting + if c.rateLimiter != nil { + return c.rateLimiter.ExecuteWithRetry(ctx, func() (*Response, error) { + return c.doRequest(ctx, req) + }) + } + return c.doRequest(ctx, req) } @@ -144,6 +172,43 @@ func (c *Client) doRequest(ctx context.Context, req *Request) (*Response, error) // StreamComplete makes a streaming chat completion request. // Returns a channel of response chunks. 
func (c *Client) StreamComplete(ctx context.Context, prompt string, opts ...Option) (<-chan StreamChunk, <-chan error) { + // Build request with streaming enabled + opts = append(opts, WithStream()) + req := &Request{ + Messages: []Message{ + {Role: "user", Content: prompt}, + }, + Model: c.config.Model, + Temperature: &c.config.Temperature, + MaxTokens: &c.config.MaxTokens, + Stream: true, + } + + // Apply options + for _, opt := range opts { + // If option application fails, return error channels immediately + if err := opt(req); err != nil { + chunkCh := make(chan StreamChunk) + errCh := make(chan error, 1) + close(chunkCh) + errCh <- fmt.Errorf("apply option: %w", err) + close(errCh) + return chunkCh, errCh + } + } + + // Use rate limiter if enabled + if c.rateLimiter != nil { + return c.rateLimiter.ExecuteStreamWithRetry(ctx, func() (<-chan StreamChunk, <-chan error) { + return c.doStreamRequest(ctx, req) + }) + } + + return c.doStreamRequest(ctx, req) +} + +// doStreamRequest executes the streaming HTTP request. 
+func (c *Client) doStreamRequest(ctx context.Context, req *Request) (<-chan StreamChunk, <-chan error) { chunkCh := make(chan StreamChunk) errCh := make(chan error, 1) @@ -151,26 +216,6 @@ func (c *Client) StreamComplete(ctx context.Context, prompt string, opts ...Opti defer close(chunkCh) defer close(errCh) - // Build request with streaming enabled - opts = append(opts, WithStream()) - req := &Request{ - Messages: []Message{ - {Role: "user", Content: prompt}, - }, - Model: c.config.Model, - Temperature: &c.config.Temperature, - MaxTokens: &c.config.MaxTokens, - Stream: true, - } - - // Apply options - for _, opt := range opts { - if err := opt(req); err != nil { - errCh <- fmt.Errorf("apply option: %w", err) - return - } - } - // Marshal request body, err := json.Marshal(req) if err != nil { diff --git a/sdk/go/ai/config.go b/sdk/go/ai/config.go index 73328fb3..29587c15 100644 --- a/sdk/go/ai/config.go +++ b/sdk/go/ai/config.go @@ -33,6 +33,15 @@ type Config struct { // Optional: Site name for OpenRouter rankings SiteName string + + // Rate Limiter Configuration + RateLimitMaxRetries int // Maximum number of retry attempts (default: 5) + RateLimitBaseDelay time.Duration // Base delay for exponential backoff (default: 1s) + RateLimitMaxDelay time.Duration // Maximum delay between retries (default: 30s) + RateLimitJitterFactor float64 // Jitter factor 0.0-1.0 (default: 0.1) + CircuitBreakerThreshold int // Consecutive failures before opening circuit (default: 5) + CircuitBreakerTimeout time.Duration // Time before attempting to close circuit (default: 60s) + DisableRateLimiter bool // Disable rate limiting completely (default: false) } // DefaultConfig returns a Config with sensible defaults. 
@@ -67,6 +76,15 @@ func DefaultConfig() *Config { Temperature: 0.7, MaxTokens: 4096, Timeout: 30 * time.Second, + + // Rate Limiter Defaults + RateLimitMaxRetries: 5, + RateLimitBaseDelay: time.Second, + RateLimitMaxDelay: 30 * time.Second, + RateLimitJitterFactor: 0.1, + CircuitBreakerThreshold: 5, + CircuitBreakerTimeout: 60 * time.Second, + DisableRateLimiter: false, } } diff --git a/sdk/go/ai/rate_limiter.go b/sdk/go/ai/rate_limiter.go new file mode 100644 index 00000000..450a3de9 --- /dev/null +++ b/sdk/go/ai/rate_limiter.go @@ -0,0 +1,422 @@ +package ai + +import ( + "context" + "crypto/md5" + "encoding/hex" + "errors" + "fmt" + "math" + "math/rand" + "os" + "strconv" + "strings" + "sync" + "time" +) + +// RateLimitError represents an error due to rate limiting or circuit breaker. +var ErrRateLimitExceeded = errors.New("rate limit retries exhausted") +var ErrCircuitOpen = errors.New("circuit breaker is open") + +// CircuitState represents the state of the circuit breaker. +type CircuitState int + +const ( + // CircuitClosed means requests are allowed. + CircuitClosed CircuitState = iota + // CircuitOpen means requests are blocked. + CircuitOpen + // CircuitHalfOpen means a test request is allowed. + CircuitHalfOpen +) + +// String returns the string representation of CircuitState. +func (s CircuitState) String() string { + switch s { + case CircuitClosed: + return "Closed" + case CircuitOpen: + return "Open" + case CircuitHalfOpen: + return "HalfOpen" + default: + return "Unknown" + } +} + +// RateLimiter provides exponential backoff retry logic with circuit breaker pattern. +// It is safe for concurrent use by multiple goroutines. 
+type RateLimiter struct {
+ maxRetries int
+ baseDelay time.Duration
+ maxDelay time.Duration
+ jitterFactor float64
+ circuitBreakerThreshold int
+ circuitBreakerTimeout time.Duration
+
+ // Callbacks for observability
+ onCircuitOpen func()
+ onCircuitClose func()
+
+ // Circuit breaker state (protected by mu)
+ mu sync.Mutex
+ consecutiveFailures int
+ circuitOpenTime *time.Time
+ containerSeed int64
+}
+
+// NewRateLimiter creates a new RateLimiter with the given configuration.
+// Configuration values are used as-is without applying defaults; callers
+// that want defaults should populate the config (e.g. via DefaultConfig)
+// before constructing the limiter.
+func NewRateLimiter(config RateLimiterConfig) *RateLimiter {
+ return &RateLimiter{
+ maxRetries: config.MaxRetries,
+ baseDelay: config.BaseDelay,
+ maxDelay: config.MaxDelay,
+ jitterFactor: config.JitterFactor,
+ circuitBreakerThreshold: config.CircuitBreakerThreshold,
+ circuitBreakerTimeout: config.CircuitBreakerTimeout,
+ onCircuitOpen: config.OnCircuitOpen,
+ onCircuitClose: config.OnCircuitClose,
+ containerSeed: getContainerSeed(),
+ }
+}
+
+// RateLimiterConfig holds configuration for the rate limiter.
+type RateLimiterConfig struct {
+ MaxRetries int // Maximum number of retry attempts
+ BaseDelay time.Duration // Base delay for exponential backoff
+ MaxDelay time.Duration // Maximum delay between retries
+ JitterFactor float64 // Jitter factor (0.0-1.0) to prevent thundering herd
+ CircuitBreakerThreshold int // Number of consecutive failures before opening circuit
+ CircuitBreakerTimeout time.Duration // Time to wait before attempting to close circuit
+ OnCircuitOpen func() // Callback when circuit opens (optional)
+ OnCircuitClose func() // Callback when circuit closes (optional)
+}
+
+// getContainerSeed generates a container-specific seed for consistent jitter distribution. 
+func getContainerSeed() int64 { + hostname := os.Getenv("HOSTNAME") + if hostname == "" { + hostname = "localhost" + } + pid := os.Getpid() + identifier := fmt.Sprintf("%s-%d", hostname, pid) + + hash := md5.Sum([]byte(identifier)) + hexStr := hex.EncodeToString(hash[:]) + seed, _ := strconv.ParseInt(hexStr[:8], 16, 64) + return seed +} + +// isRateLimitError checks if an error is a rate limit error. +func isRateLimitError(err error) bool { + if err == nil { + return false + } + + errMsg := strings.ToLower(err.Error()) + + // Check for common rate limit keywords + keywords := []string{ + "rate limit", + "rate-limit", + "rate_limit", + "too many requests", + "quota exceeded", + "temporarily rate-limited", + "rate limited", + "requests per", + "rpm exceeded", + "tpm exceeded", + "usage limit", + "throttled", + "throttling", + "429", // HTTP 429 status + "503", // HTTP 503 status (service unavailable, often due to rate limits) + } + + for _, keyword := range keywords { + if strings.Contains(errMsg, keyword) { + return true + } + } + + return false +} + +// calculateBackoffDelay calculates the delay with exponential backoff and jitter. 
+func (rl *RateLimiter) calculateBackoffDelay(attempt int) time.Duration { + // Exponential backoff: baseDelay * (2 ^ attempt) + exponent := math.Pow(2, float64(attempt)) + backoffDelay := time.Duration(float64(rl.baseDelay) * exponent) + + // Cap at max delay + if backoffDelay > rl.maxDelay { + backoffDelay = rl.maxDelay + } + + // Add jitter to distribute load + // Use time-based randomness combined with container seed for true randomness + // while maintaining some distribution across containers + rng := rand.New(rand.NewSource(rl.containerSeed + int64(attempt) + time.Now().UnixNano())) + jitterRange := float64(backoffDelay) * rl.jitterFactor + jitter := (rng.Float64()*2 - 1) * jitterRange // Random value between -jitterRange and +jitterRange + + delay := time.Duration(float64(backoffDelay) + jitter) + + // Ensure minimum delay + if delay < 100*time.Millisecond { + delay = 100 * time.Millisecond + } + + return delay +} + +// checkCircuitBreaker checks if the circuit breaker is open. +// Must be called with mu held. +func (rl *RateLimiter) checkCircuitBreaker() CircuitState { + if rl.circuitOpenTime == nil { + return CircuitClosed + } + + // Check if circuit breaker timeout has passed + if time.Since(*rl.circuitOpenTime) > rl.circuitBreakerTimeout { + // Timeout passed - enter half-open state + return CircuitHalfOpen + } + + return CircuitOpen +} + +// updateCircuitBreaker updates the circuit breaker state based on operation result. 
+func (rl *RateLimiter) updateCircuitBreaker(success bool) { + rl.mu.Lock() + wasOpen := rl.circuitOpenTime != nil + + if success { + // Reset on success + rl.consecutiveFailures = 0 + if rl.circuitOpenTime != nil { + rl.circuitOpenTime = nil + rl.mu.Unlock() + // Trigger callback outside the lock + if rl.onCircuitClose != nil { + rl.onCircuitClose() + } + return + } + } else { + // Increment failures + rl.consecutiveFailures++ + + // Open circuit if threshold reached + if rl.consecutiveFailures >= rl.circuitBreakerThreshold && !wasOpen { + now := time.Now() + rl.circuitOpenTime = &now + rl.mu.Unlock() + // Trigger callback outside the lock + if rl.onCircuitOpen != nil { + rl.onCircuitOpen() + } + return + } + } + + rl.mu.Unlock() +} + +// GetCircuitState returns the current state of the circuit breaker. +func (rl *RateLimiter) GetCircuitState() CircuitState { + rl.mu.Lock() + defer rl.mu.Unlock() + return rl.checkCircuitBreaker() +} + +// GetConsecutiveFailures returns the current count of consecutive failures. +func (rl *RateLimiter) GetConsecutiveFailures() int { + rl.mu.Lock() + defer rl.mu.Unlock() + return rl.consecutiveFailures +} + +// ExecuteWithRetry executes a function with rate limit retry logic. 
+func (rl *RateLimiter) ExecuteWithRetry(ctx context.Context, fn func() (*Response, error)) (*Response, error) { + // Check circuit breaker + rl.mu.Lock() + circuitState := rl.checkCircuitBreaker() + rl.mu.Unlock() + + if circuitState == CircuitOpen { + return nil, fmt.Errorf("%w: too many consecutive rate limit failures, will retry after %v", + ErrCircuitOpen, rl.circuitBreakerTimeout) + } + + var lastErr error + + for attempt := 0; attempt <= rl.maxRetries; attempt++ { + // Check context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Execute the function + result, err := fn() + + if err == nil { + // Success - update circuit breaker and return + rl.updateCircuitBreaker(true) + return result, nil + } + + lastErr = err + + // Check if this is a rate limit error + if !isRateLimitError(err) { + // Not a rate limit error - return immediately + return nil, err + } + + // Update circuit breaker for rate limit failure + rl.updateCircuitBreaker(false) + + // Check if we've exceeded max retries + if attempt >= rl.maxRetries { + break + } + + // Calculate backoff delay + delay := rl.calculateBackoffDelay(attempt) + + // Wait before retry + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + // Continue to next attempt + } + } + + // All retries exhausted + return nil, fmt.Errorf("%w: after %d attempts: %v", ErrRateLimitExceeded, rl.maxRetries+1, lastErr) +} + +// ExecuteStreamWithRetry executes a streaming function with rate limit retry logic. 
+func (rl *RateLimiter) ExecuteStreamWithRetry(ctx context.Context, fn func() (<-chan StreamChunk, <-chan error)) (<-chan StreamChunk, <-chan error) { + chunkCh := make(chan StreamChunk) + errCh := make(chan error, 1) + + go func() { + defer close(chunkCh) + defer close(errCh) + + // Check circuit breaker + rl.mu.Lock() + circuitState := rl.checkCircuitBreaker() + rl.mu.Unlock() + + if circuitState == CircuitOpen { + errCh <- fmt.Errorf("%w: too many consecutive rate limit failures, will retry after %v", + ErrCircuitOpen, rl.circuitBreakerTimeout) + return + } + + var lastErr error + sentAnyChunks := false + + for attempt := 0; attempt <= rl.maxRetries; attempt++ { + // Check context cancellation + select { + case <-ctx.Done(): + errCh <- ctx.Err() + return + default: + } + + // Execute the streaming function + resultChunkCh, resultErrCh := fn() + + // Forward chunks and capture errors - consume both channels until closed + streamErr := error(nil) + ch := resultChunkCh + ech := resultErrCh + + for ch != nil || ech != nil { + select { + case <-ctx.Done(): + errCh <- ctx.Err() + return + case chunk, ok := <-ch: + if !ok { + ch = nil + continue + } + sentAnyChunks = true + chunkCh <- chunk + case err, ok := <-ech: + if !ok { + ech = nil + continue + } + if err != nil { + streamErr = err + } + } + } + + // Now check if there was an error after both channels are closed + if streamErr != nil { + lastErr = streamErr + + // Check if this is a rate limit error + if !isRateLimitError(streamErr) { + // Not a rate limit error - forward and return + errCh <- streamErr + return + } + + // If we've already sent chunks, we cannot safely retry as it would duplicate content + if sentAnyChunks { + errCh <- fmt.Errorf("%w: partial content already sent, cannot retry: %v", ErrRateLimitExceeded, streamErr) + return + } + + // Update circuit breaker for rate limit failure + rl.updateCircuitBreaker(false) + + // Break inner loop to retry + goto retry + } + + // Stream completed 
successfully + rl.updateCircuitBreaker(true) + return + + retry: + // Check if we've exceeded max retries + if attempt >= rl.maxRetries { + break + } + + // Calculate backoff delay + delay := rl.calculateBackoffDelay(attempt) + + // Wait before retry + select { + case <-ctx.Done(): + errCh <- ctx.Err() + return + case <-time.After(delay): + // Continue to next attempt + } + } + + // All retries exhausted + errCh <- fmt.Errorf("%w: after %d attempts: %v", ErrRateLimitExceeded, rl.maxRetries+1, lastErr) + }() + + return chunkCh, errCh +} diff --git a/sdk/go/ai/rate_limiter_test.go b/sdk/go/ai/rate_limiter_test.go new file mode 100644 index 00000000..d5473460 --- /dev/null +++ b/sdk/go/ai/rate_limiter_test.go @@ -0,0 +1,981 @@ +package ai + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "testing" + "time" +) + +// Test error types +var errRateLimit = errors.New("rate limit exceeded: 429 Too Many Requests") +var errNotRateLimit = errors.New("some other error") +var errQuotaExceeded = errors.New("quota exceeded for this month") +var errThrottling = errors.New("throttling: requests per minute exceeded") + +func TestNewRateLimiter(t *testing.T) { + tests := []struct { + name string + config RateLimiterConfig + check func(*testing.T, *RateLimiter) + }{ + { + name: "zero values are respected", + config: RateLimiterConfig{ + MaxRetries: 0, + BaseDelay: 0, + MaxDelay: 0, + JitterFactor: 0, + CircuitBreakerThreshold: 0, + CircuitBreakerTimeout: 0, + }, + check: func(t *testing.T, rl *RateLimiter) { + if rl.maxRetries != 0 { + t.Errorf("Expected maxRetries=0, got %d", rl.maxRetries) + } + if rl.baseDelay != 0 { + t.Errorf("Expected baseDelay=0, got %v", rl.baseDelay) + } + if rl.maxDelay != 0 { + t.Errorf("Expected maxDelay=0, got %v", rl.maxDelay) + } + if rl.jitterFactor != 0 { + t.Errorf("Expected jitterFactor=0, got %f", rl.jitterFactor) + } + if rl.circuitBreakerThreshold != 0 { + t.Errorf("Expected circuitBreakerThreshold=0, got %d", 
rl.circuitBreakerThreshold) + } + if rl.circuitBreakerTimeout != 0 { + t.Errorf("Expected circuitBreakerTimeout=0, got %v", rl.circuitBreakerTimeout) + } + }, + }, + { + name: "custom values", + config: RateLimiterConfig{ + MaxRetries: 10, + BaseDelay: 500 * time.Millisecond, + MaxDelay: 10 * time.Second, + JitterFactor: 0.2, + CircuitBreakerThreshold: 3, + CircuitBreakerTimeout: 30 * time.Second, + }, + check: func(t *testing.T, rl *RateLimiter) { + if rl.maxRetries != 10 { + t.Errorf("Expected maxRetries=10, got %d", rl.maxRetries) + } + if rl.baseDelay != 500*time.Millisecond { + t.Errorf("Expected baseDelay=500ms, got %v", rl.baseDelay) + } + if rl.maxDelay != 10*time.Second { + t.Errorf("Expected maxDelay=10s, got %v", rl.maxDelay) + } + if rl.jitterFactor != 0.2 { + t.Errorf("Expected jitterFactor=0.2, got %f", rl.jitterFactor) + } + if rl.circuitBreakerThreshold != 3 { + t.Errorf("Expected circuitBreakerThreshold=3, got %d", rl.circuitBreakerThreshold) + } + if rl.circuitBreakerTimeout != 30*time.Second { + t.Errorf("Expected circuitBreakerTimeout=30s, got %v", rl.circuitBreakerTimeout) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rl := NewRateLimiter(tt.config) + tt.check(t, rl) + }) + } +} + +func TestIsRateLimitError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "rate limit error with 429", + err: errRateLimit, + expected: true, + }, + { + name: "quota exceeded error", + err: errQuotaExceeded, + expected: true, + }, + { + name: "throttling error", + err: errThrottling, + expected: true, + }, + { + name: "non-rate-limit error", + err: errNotRateLimit, + expected: false, + }, + { + name: "error with 'too many requests'", + err: errors.New("too many requests please try again"), + expected: true, + }, + { + name: "error with 'rate limited'", + err: errors.New("you have been rate limited"), + expected: 
true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isRateLimitError(tt.err) + if result != tt.expected { + t.Errorf("Expected %v, got %v for error: %v", tt.expected, result, tt.err) + } + }) + } +} + +func TestCalculateBackoffDelay(t *testing.T) { + rl := NewRateLimiter(RateLimiterConfig{ + BaseDelay: time.Second, + MaxDelay: 30 * time.Second, + JitterFactor: 0.1, + }) + + tests := []struct { + attempt int + minExpected time.Duration + maxExpected time.Duration + }{ + {attempt: 0, minExpected: 800 * time.Millisecond, maxExpected: 1200 * time.Millisecond}, // ~1s ± 10% + {attempt: 1, minExpected: 1800 * time.Millisecond, maxExpected: 2200 * time.Millisecond}, // ~2s ± 10% + {attempt: 2, minExpected: 3600 * time.Millisecond, maxExpected: 4400 * time.Millisecond}, // ~4s ± 10% + {attempt: 3, minExpected: 7200 * time.Millisecond, maxExpected: 8800 * time.Millisecond}, // ~8s ± 10% + {attempt: 10, minExpected: 27 * time.Second, maxExpected: 33 * time.Second}, // Capped at 30s ± 10% (with jitter) + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("attempt_%d", tt.attempt), func(t *testing.T) { + delay := rl.calculateBackoffDelay(tt.attempt) + if delay < tt.minExpected || delay > tt.maxExpected { + t.Errorf("Attempt %d: expected delay between %v and %v, got %v", + tt.attempt, tt.minExpected, tt.maxExpected, delay) + } + }) + } +} + +func TestCircuitBreakerStates(t *testing.T) { + t.Run("initially closed", func(t *testing.T) { + rl := NewRateLimiter(RateLimiterConfig{ + CircuitBreakerThreshold: 3, + }) + + state := rl.checkCircuitBreaker() + if state != CircuitClosed { + t.Errorf("Expected CircuitClosed, got %v", state) + } + }) + + t.Run("opens after threshold failures", func(t *testing.T) { + rl := NewRateLimiter(RateLimiterConfig{ + CircuitBreakerThreshold: 3, + }) + + // Simulate failures + for i := 0; i < 3; i++ { + rl.updateCircuitBreaker(false) + } + + state := rl.checkCircuitBreaker() + if state != CircuitOpen { + 
t.Errorf("Expected CircuitOpen after %d failures, got %v", 3, state) + } + }) + + t.Run("resets on success", func(t *testing.T) { + rl := NewRateLimiter(RateLimiterConfig{ + CircuitBreakerThreshold: 3, + }) + + // Simulate failures + rl.updateCircuitBreaker(false) + rl.updateCircuitBreaker(false) + + // Success should reset + rl.updateCircuitBreaker(true) + + if rl.consecutiveFailures != 0 { + t.Errorf("Expected consecutiveFailures=0 after success, got %d", rl.consecutiveFailures) + } + }) + + t.Run("enters half-open after timeout", func(t *testing.T) { + rl := NewRateLimiter(RateLimiterConfig{ + CircuitBreakerThreshold: 2, + CircuitBreakerTimeout: 100 * time.Millisecond, + }) + + // Open the circuit + rl.updateCircuitBreaker(false) + rl.updateCircuitBreaker(false) + + if rl.checkCircuitBreaker() != CircuitOpen { + t.Error("Circuit should be open") + } + + // Wait for timeout + time.Sleep(150 * time.Millisecond) + + state := rl.checkCircuitBreaker() + if state != CircuitHalfOpen { + t.Errorf("Expected CircuitHalfOpen after timeout, got %v", state) + } + }) +} + +func TestExecuteWithRetry_Success(t *testing.T) { + rl := NewRateLimiter(RateLimiterConfig{ + MaxRetries: 3, + BaseDelay: 10 * time.Millisecond, + }) + + ctx := context.Background() + callCount := 0 + + result, err := rl.ExecuteWithRetry(ctx, func() (*Response, error) { + callCount++ + return &Response{}, nil + }) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if result == nil { + t.Error("Expected result, got nil") + } + if callCount != 1 { + t.Errorf("Expected 1 call, got %d", callCount) + } +} + +func TestExecuteWithRetry_NonRateLimitError(t *testing.T) { + rl := NewRateLimiter(RateLimiterConfig{ + MaxRetries: 3, + BaseDelay: 10 * time.Millisecond, + }) + + ctx := context.Background() + callCount := 0 + + _, err := rl.ExecuteWithRetry(ctx, func() (*Response, error) { + callCount++ + return nil, errNotRateLimit + }) + + if err == nil { + t.Error("Expected error, got nil") + } + if 
!errors.Is(err, errNotRateLimit) { + t.Errorf("Expected errNotRateLimit, got %v", err) + } + if callCount != 1 { + t.Errorf("Expected 1 call (no retry for non-rate-limit error), got %d", callCount) + } +} + +func TestExecuteWithRetry_RateLimitThenSuccess(t *testing.T) { + rl := NewRateLimiter(RateLimiterConfig{ + MaxRetries: 3, + BaseDelay: 10 * time.Millisecond, + }) + + ctx := context.Background() + callCount := 0 + + result, err := rl.ExecuteWithRetry(ctx, func() (*Response, error) { + callCount++ + if callCount < 3 { + return nil, errRateLimit + } + return &Response{}, nil + }) + + if err != nil { + t.Errorf("Expected no error after retries, got %v", err) + } + if result == nil { + t.Error("Expected result, got nil") + } + if callCount != 3 { + t.Errorf("Expected 3 calls, got %d", callCount) + } +} + +func TestExecuteWithRetry_MaxRetriesExceeded(t *testing.T) { + rl := NewRateLimiter(RateLimiterConfig{ + MaxRetries: 2, + BaseDelay: 10 * time.Millisecond, + }) + + ctx := context.Background() + callCount := 0 + + _, err := rl.ExecuteWithRetry(ctx, func() (*Response, error) { + callCount++ + return nil, errRateLimit + }) + + if err == nil { + t.Error("Expected error, got nil") + } + if !errors.Is(err, ErrRateLimitExceeded) { + t.Errorf("Expected ErrRateLimitExceeded, got %v", err) + } + if callCount != 3 { // maxRetries=2 means 3 total attempts (initial + 2 retries) + t.Errorf("Expected 3 calls, got %d", callCount) + } +} + +func TestExecuteWithRetry_CircuitBreakerOpen(t *testing.T) { + rl := NewRateLimiter(RateLimiterConfig{ + MaxRetries: 2, + BaseDelay: 10 * time.Millisecond, + CircuitBreakerThreshold: 2, + CircuitBreakerTimeout: 10 * time.Second, // Long timeout to keep circuit open + }) + + ctx := context.Background() + + // Trigger circuit breaker by failing multiple times + _, err := rl.ExecuteWithRetry(ctx, func() (*Response, error) { + return nil, errRateLimit + }) + + // Verify we got max retries error (not checking circuit status yet) + if 
// NOTE(review): this chunk begins mid-function. The statements below are the
// tail of a circuit-breaker test whose opening lines are outside this view;
// they assert that once the breaker opens after consecutive rate-limit
// failures, the wrapped function is never invoked and ErrCircuitOpen is
// returned. Preserved unchanged rather than guessing at the missing head.
!errors.Is(err, ErrRateLimitExceeded) {
		t.Errorf("Expected ErrRateLimitExceeded, got %v", err)
	}

	// Circuit should now be open (after 3 consecutive rate limit failures)
	callCount := 0
	_, err = rl.ExecuteWithRetry(ctx, func() (*Response, error) {
		callCount++
		t.Error("Function should not be called when circuit is open")
		return nil, nil
	})

	if callCount != 0 {
		t.Errorf("Expected 0 calls when circuit is open, got %d", callCount)
	}
	if err == nil {
		t.Error("Expected error, got nil")
	}
	if !errors.Is(err, ErrCircuitOpen) {
		t.Errorf("Expected ErrCircuitOpen, got %v", err)
	}
}

// TestExecuteWithRetry_ContextCancellation verifies that ExecuteWithRetry
// stops retrying and surfaces context.Canceled once the caller cancels the
// context mid-backoff.
func TestExecuteWithRetry_ContextCancellation(t *testing.T) {
	rl := NewRateLimiter(RateLimiterConfig{
		MaxRetries: 5,
		BaseDelay:  100 * time.Millisecond,
	})

	ctx, cancel := context.WithCancel(context.Background())
	callCount := 0

	// Cancel after first call (the 100ms backoff guarantees the limiter is
	// still waiting when cancel fires).
	go func() {
		time.Sleep(50 * time.Millisecond)
		cancel()
	}()

	_, err := rl.ExecuteWithRetry(ctx, func() (*Response, error) {
		callCount++
		return nil, errRateLimit
	})

	if err == nil {
		t.Error("Expected error, got nil")
	}
	if !errors.Is(err, context.Canceled) {
		t.Errorf("Expected context.Canceled, got %v", err)
	}
	// Should have made at least one call before cancellation
	if callCount < 1 {
		t.Errorf("Expected at least 1 call, got %d", callCount)
	}
}

// TestCircuitStateString checks the String representation of every circuit
// state, including the fallback for unknown values.
func TestCircuitStateString(t *testing.T) {
	tests := []struct {
		state    CircuitState
		expected string
	}{
		{CircuitClosed, "Closed"},
		{CircuitOpen, "Open"},
		{CircuitHalfOpen, "HalfOpen"},
		{CircuitState(99), "Unknown"},
	}

	for _, tt := range tests {
		t.Run(tt.expected, func(t *testing.T) {
			result := tt.state.String()
			if result != tt.expected {
				t.Errorf("Expected %s, got %s", tt.expected, result)
			}
		})
	}
}

// TestExecuteWithRetry_BackoffTiming exercises the exponential backoff
// schedule with jitter disabled. Only the lower bound of the second delay is
// a hard failure; everything else is logged, because wall-clock timing on CI
// machines is approximate.
func TestExecuteWithRetry_BackoffTiming(t *testing.T) {
	rl := NewRateLimiter(RateLimiterConfig{
		MaxRetries:              3,
		BaseDelay:               100 * time.Millisecond,
		MaxDelay:                10 * time.Second,
		JitterFactor:            0.0, // No jitter for predictable timing
		CircuitBreakerThreshold: 10,  // High threshold to prevent interference
	})

	ctx := context.Background()
	attempts := []time.Time{}

	_, err := rl.ExecuteWithRetry(ctx, func() (*Response, error) {
		attempts = append(attempts, time.Now())
		return nil, errRateLimit
	})

	if err == nil {
		t.Error("Expected error after max retries")
	}

	if len(attempts) != 4 { // Initial + 3 retries
		t.Errorf("Expected 4 attempts, got %d", len(attempts))
	}

	// Check backoff timing (allowing generous tolerance for timing variance)
	if len(attempts) >= 2 {
		delay1 := attempts[1].Sub(attempts[0])
		// First retry: baseDelay * 2^0 = 100ms, but allow 90-500ms for variance.
		// Fixed: the original log claimed "(acceptable range)" in exactly the
		// branch where the delay is OUTSIDE the window, which inverted the
		// message's meaning.
		if delay1 < 90*time.Millisecond || delay1 > 500*time.Millisecond {
			t.Logf("First retry delay: %v (outside expected 90ms-500ms window; timing is approximate)", delay1)
		}
	}

	if len(attempts) >= 3 {
		delay2 := attempts[2].Sub(attempts[1])
		// Second retry: baseDelay * 2^1 = 200ms, but timing can vary
		if delay2 < 90*time.Millisecond {
			t.Errorf("Second retry delay too short: %v", delay2)
		}
		// Log but don't fail on upper bound - timing is approximate
		if delay2 > 500*time.Millisecond {
			t.Logf("Second retry delay: %v (higher than expected but acceptable)", delay2)
		}
	}
}

// TestGetContainerSeed verifies the per-container jitter seed is stable
// across calls and never the zero value.
func TestGetContainerSeed(t *testing.T) {
	seed1 := getContainerSeed()
	seed2 := getContainerSeed()

	if seed1 != seed2 {
		t.Error("Container seed should be consistent")
	}

	if seed1 == 0 {
		t.Error("Container seed should not be zero")
	}
}

// TestRateLimitError_VarNames asserts the package sentinel errors are
// non-nil and carry descriptive, matchable messages.
func TestRateLimitError_VarNames(t *testing.T) {
	// Test that error variables are defined and usable
	if ErrRateLimitExceeded == nil {
		t.Error("ErrRateLimitExceeded should not be nil")
	}
	if ErrCircuitOpen == nil {
		t.Error("ErrCircuitOpen should not be nil")
	}

	// Test that errors have descriptive messages
	if !strings.Contains(ErrRateLimitExceeded.Error(), "rate limit") {
		t.Errorf("Expected ErrRateLimitExceeded should mention rate limit, got: %v", ErrRateLimitExceeded)
	}
	if !strings.Contains(ErrCircuitOpen.Error(), "circuit") {
		t.Errorf("Expected ErrCircuitOpen should mention circuit, got: %v", ErrCircuitOpen)
	}
}

// TestUpdateCircuitBreaker covers the breaker's internal bookkeeping:
// failures increment the consecutive-failure counter; a single success
// resets the counter and closes an open circuit.
func TestUpdateCircuitBreaker(t *testing.T) {
	t.Run("consecutive failures increment counter", func(t *testing.T) {
		rl := NewRateLimiter(RateLimiterConfig{
			CircuitBreakerThreshold: 5,
		})

		rl.updateCircuitBreaker(false)
		if rl.consecutiveFailures != 1 {
			t.Errorf("Expected 1 failure, got %d", rl.consecutiveFailures)
		}

		rl.updateCircuitBreaker(false)
		if rl.consecutiveFailures != 2 {
			t.Errorf("Expected 2 failures, got %d", rl.consecutiveFailures)
		}
	})

	t.Run("success resets counter and closes circuit", func(t *testing.T) {
		rl := NewRateLimiter(RateLimiterConfig{
			CircuitBreakerThreshold: 2,
		})

		// Open circuit
		rl.updateCircuitBreaker(false)
		rl.updateCircuitBreaker(false)

		if rl.circuitOpenTime == nil {
			t.Error("Circuit should be open")
		}

		// Success should reset and close
		rl.updateCircuitBreaker(true)

		if rl.consecutiveFailures != 0 {
			t.Errorf("Expected 0 failures after success, got %d", rl.consecutiveFailures)
		}
		if rl.circuitOpenTime != nil {
			t.Error("Circuit should be closed after success")
		}
	})
}

// TestExecuteWithRetry_EdgeCases groups boundary behaviors: no delay on
// first-attempt success, immediate failure on non-rate-limit errors, and a
// single attempt when MaxRetries is zero.
func TestExecuteWithRetry_EdgeCases(t *testing.T) {
	t.Run("immediate success on first attempt", func(t *testing.T) {
		rl := NewRateLimiter(RateLimiterConfig{MaxRetries: 3})
		ctx := context.Background()

		start := time.Now()
		result, err := rl.ExecuteWithRetry(ctx, func() (*Response, error) {
			return &Response{}, nil
		})

		duration := time.Since(start)

		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}
		if result == nil {
			t.Error("Expected result")
		}
		// Should complete quickly without delays
		if duration > 100*time.Millisecond {
			t.Errorf("Should complete immediately, took %v", duration)
		}
	})

	t.Run("alternating rate limit and non-rate-limit errors", func(t *testing.T) {
		rl := NewRateLimiter(RateLimiterConfig{
			MaxRetries: 3,
			BaseDelay:  10 * time.Millisecond,
		})
		ctx := context.Background()
		callCount := 0

		_, err := rl.ExecuteWithRetry(ctx, func() (*Response, error) {
			callCount++
			if callCount == 1 {
				return nil, errRateLimit // Should retry
			}
			return nil, errNotRateLimit // Should fail immediately
		})

		if callCount != 2 {
			t.Errorf("Expected 2 calls (rate limit retry then immediate fail), got %d", callCount)
		}
		if !errors.Is(err, errNotRateLimit) {
			t.Errorf("Expected errNotRateLimit, got %v", err)
		}
	})

	t.Run("zero max retries still attempts once", func(t *testing.T) {
		rl := NewRateLimiter(RateLimiterConfig{
			MaxRetries:              0,
			CircuitBreakerThreshold: 10, // High threshold to prevent circuit breaker from interfering
		})
		ctx := context.Background()
		callCount := 0

		_, err := rl.ExecuteWithRetry(ctx, func() (*Response, error) {
			callCount++
			return nil, errRateLimit
		})

		if callCount != 1 {
			t.Errorf("Expected 1 call (initial attempt), got %d", callCount)
		}
		if !errors.Is(err, ErrRateLimitExceeded) {
			t.Errorf("Expected ErrRateLimitExceeded, got %v", err)
		}
	})
}

// TestRateLimiter_ConcurrentAccess hammers a single limiter from 100
// goroutines to flush out data races in the shared circuit-breaker state
// (run the suite with -race).
func TestRateLimiter_ConcurrentAccess(t *testing.T) {
	rl := NewRateLimiter(RateLimiterConfig{
		MaxRetries:              3,
		BaseDelay:               10 * time.Millisecond,
		CircuitBreakerThreshold: 10, // High threshold to avoid circuit opening during test
	})

	ctx := context.Background()
	var wg sync.WaitGroup
	successCount := 0
	var mu sync.Mutex

	// Run 100 concurrent requests
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()

			// Alternate between success and failure to test concurrent state updates
			result, err := rl.ExecuteWithRetry(ctx, func() (*Response, error) {
				if id%2 == 0 {
					return &Response{}, nil
				}
				return nil, errNotRateLimit // Non-rate-limit error for faster test
			})

			if err == nil && result != nil {
				mu.Lock()
				successCount++
				mu.Unlock()
			}
		}(i)
	}

	wg.Wait()

	// Fixed: the outcome is deterministic — exactly the 50 even-ID
	// goroutines succeed and the odd-ID ones fail immediately — so the
	// original 45-55 tolerance could mask lost or double-counted results.
	if successCount != 50 {
		t.Errorf("Expected exactly 50 successes, got %d", successCount)
	}

	// Should not panic or race
	t.Logf("Concurrent access test passed with %d successes", successCount)
}

// TestExecuteStreamWithRetry_Success streams two chunks on the first
// attempt and verifies pass-through with no retries.
func TestExecuteStreamWithRetry_Success(t *testing.T) {
	rl := NewRateLimiter(RateLimiterConfig{
		MaxRetries: 3,
		BaseDelay:  10 * time.Millisecond,
	})

	ctx := context.Background()
	callCount := 0

	chunkCh, errCh := rl.ExecuteStreamWithRetry(ctx, func() (<-chan StreamChunk, <-chan error) {
		callCount++
		ch := make(chan StreamChunk, 3)
		ec := make(chan error)

		go func() {
			defer close(ch)
			defer close(ec)
			ch <- StreamChunk{Choices: []StreamDelta{{Delta: MessageDelta{Content: "Hello"}}}}
			ch <- StreamChunk{Choices: []StreamDelta{{Delta: MessageDelta{Content: " World"}}}}
		}()

		return ch, ec
	})

	var content strings.Builder
	for chunk := range chunkCh {
		if len(chunk.Choices) > 0 {
			content.WriteString(chunk.Choices[0].Delta.Content)
		}
	}

	err := <-errCh
	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}
	if content.String() != "Hello World" {
		t.Errorf("Expected 'Hello World', got '%s'", content.String())
	}
	if callCount != 1 {
		t.Errorf("Expected 1 call, got %d", callCount)
	}
}

// TestExecuteStreamWithRetry_RateLimitThenSuccess verifies that streams
// which fail with a rate-limit error BEFORE emitting any chunk are retried
// transparently until one succeeds.
func TestExecuteStreamWithRetry_RateLimitThenSuccess(t *testing.T) {
	rl := NewRateLimiter(RateLimiterConfig{
		MaxRetries: 3,
		BaseDelay:  10 * time.Millisecond,
	})

	ctx := context.Background()
	callCount := 0

	chunkCh, errCh := rl.ExecuteStreamWithRetry(ctx, func() (<-chan StreamChunk, <-chan error) {
		callCount++
		ch := make(chan StreamChunk)
		ec := make(chan error, 1)

		go func() {
			defer close(ch)
			defer close(ec)

			if callCount < 3 {
				ec <- errRateLimit
				return
			}

			ch <- StreamChunk{Choices: []StreamDelta{{Delta: MessageDelta{Content: "Success"}}}}
		}()

		return ch, ec
	})

	var content strings.Builder
	for chunk := range chunkCh {
		if len(chunk.Choices) > 0 {
			content.WriteString(chunk.Choices[0].Delta.Content)
		}
	}

	err := <-errCh
	if err != nil {
		t.Errorf("Expected no error after retries, got %v", err)
	}
	if content.String() != "Success" {
		t.Errorf("Expected 'Success', got '%s'", content.String())
	}
	if callCount != 3 {
		t.Errorf("Expected 3 calls (2 rate limits + success), got %d", callCount)
	}
}

// TestExecuteStreamWithRetry_MaxRetriesExceeded verifies the stream variant
// gives up with ErrRateLimitExceeded after MaxRetries+1 failed attempts.
func TestExecuteStreamWithRetry_MaxRetriesExceeded(t *testing.T) {
	rl := NewRateLimiter(RateLimiterConfig{
		MaxRetries: 2,
		BaseDelay:  10 * time.Millisecond,
	})

	ctx := context.Background()
	callCount := 0

	chunkCh, errCh := rl.ExecuteStreamWithRetry(ctx, func() (<-chan StreamChunk, <-chan error) {
		callCount++
		ch := make(chan StreamChunk)
		ec := make(chan error, 1)

		go func() {
			defer close(ch)
			defer close(ec)
			ec <- errRateLimit
		}()

		return ch, ec
	})

	// Consume all chunks (should be none)
	for range chunkCh {
	}

	err := <-errCh
	if err == nil {
		t.Error("Expected error after max retries")
	}
	if !errors.Is(err, ErrRateLimitExceeded) {
		t.Errorf("Expected ErrRateLimitExceeded, got %v", err)
	}
	if callCount != 3 { // maxRetries=2 means 3 total attempts
		t.Errorf("Expected 3 calls, got %d", callCount)
	}
}

// TestExecuteStreamWithRetry_NonRateLimitError verifies non-rate-limit
// stream errors are surfaced immediately without any retry.
func TestExecuteStreamWithRetry_NonRateLimitError(t *testing.T) {
	rl := NewRateLimiter(RateLimiterConfig{
		MaxRetries: 3,
		BaseDelay:  10 * time.Millisecond,
	})

	ctx := context.Background()
	callCount := 0

	chunkCh, errCh := rl.ExecuteStreamWithRetry(ctx, func() (<-chan StreamChunk, <-chan error) {
		callCount++
		ch := make(chan StreamChunk)
		ec := make(chan error, 1)

		go func() {
			defer close(ch)
			defer close(ec)
			ec <- errNotRateLimit
		}()

		return ch, ec
	})

	// Consume all chunks
	for range chunkCh {
	}

	err := <-errCh
	if err == nil {
		t.Error("Expected error")
	}
	if !errors.Is(err, errNotRateLimit) {
		t.Errorf("Expected errNotRateLimit, got %v", err)
	}
	if callCount != 1 {
		t.Errorf("Expected 1 call (no retry for non-rate-limit error), got %d", callCount)
	}
}

// TestGetCircuitState walks the breaker through its full lifecycle:
// Closed -> Open (threshold reached) -> HalfOpen (timeout elapsed) ->
// Closed (probe success), checking the accessor at each stage.
func TestGetCircuitState(t *testing.T) {
	rl := NewRateLimiter(RateLimiterConfig{
		CircuitBreakerThreshold: 2,
		CircuitBreakerTimeout:   100 * time.Millisecond,
	})

	// Initially closed
	if state := rl.GetCircuitState(); state != CircuitClosed {
		t.Errorf("Expected CircuitClosed, got %v", state)
	}

	// Trigger failures to open circuit
	rl.updateCircuitBreaker(false)
	rl.updateCircuitBreaker(false)

	// Should be open
	if state := rl.GetCircuitState(); state != CircuitOpen {
		t.Errorf("Expected CircuitOpen, got %v", state)
	}

	// Check consecutive failures
	if failures := rl.GetConsecutiveFailures(); failures != 2 {
		t.Errorf("Expected 2 consecutive failures, got %d", failures)
	}

	// Wait for timeout
	time.Sleep(150 * time.Millisecond)

	// Should be half-open
	if state := rl.GetCircuitState(); state != CircuitHalfOpen {
		t.Errorf("Expected CircuitHalfOpen, got %v", state)
	}

	// Success should close it
	rl.updateCircuitBreaker(true)

	if state := rl.GetCircuitState(); state != CircuitClosed {
		t.Errorf("Expected CircuitClosed after success, got %v", state)
	}
	if failures := rl.GetConsecutiveFailures(); failures != 0 {
		t.Errorf("Expected 0 consecutive failures after success, got %d", failures)
	}
}

// TestCircuitBreakerCallbacks verifies the OnCircuitOpen/OnCircuitClose
// hooks fire on the corresponding transitions. The flags are safe as plain
// bools because updateCircuitBreaker is invoked synchronously here.
func TestCircuitBreakerCallbacks(t *testing.T) {
	openCalled := false
	closeCalled := false

	rl := NewRateLimiter(RateLimiterConfig{
		CircuitBreakerThreshold: 2,
		OnCircuitOpen: func() {
			openCalled = true
		},
		OnCircuitClose: func() {
			closeCalled = true
		},
	})

	// Trigger circuit open
	rl.updateCircuitBreaker(false)
	rl.updateCircuitBreaker(false)

	if !openCalled {
		t.Error("Expected OnCircuitOpen callback to be called")
	}

	// Trigger circuit close
	rl.updateCircuitBreaker(true)

	if !closeCalled {
		t.Error("Expected OnCircuitClose callback to be called")
	}
}

// TestExecuteStreamWithRetry_NoRetryIfChunksSent verifies the critical
// streaming invariant: once any chunk has reached the consumer, a
// subsequent rate-limit error must NOT trigger a retry (which would
// duplicate content) — it must fail with an error mentioning the partial
// content.
func TestExecuteStreamWithRetry_NoRetryIfChunksSent(t *testing.T) {
	rl := NewRateLimiter(RateLimiterConfig{
		MaxRetries: 3,
		BaseDelay:  10 * time.Millisecond,
	})

	ctx := context.Background()
	callCount := 0

	chunkCh, errCh := rl.ExecuteStreamWithRetry(ctx, func() (<-chan StreamChunk, <-chan error) {
		callCount++
		ch := make(chan StreamChunk)
		ec := make(chan error, 1)

		go func() {
			defer close(ch)
			defer close(ec)

			if callCount == 1 {
				// First call: send one chunk then a rate limit error
				ch <- StreamChunk{Choices: []StreamDelta{{Delta: MessageDelta{Content: "Partial"}}}}
				ec <- errRateLimit
				return
			}

			// Second call (should not happen)
			ch <- StreamChunk{Choices: []StreamDelta{{Delta: MessageDelta{Content: "Should not retry"}}}}
		}()

		return ch, ec
	})

	var content strings.Builder
	for chunk := range chunkCh {
		if len(chunk.Choices) > 0 {
			content.WriteString(chunk.Choices[0].Delta.Content)
		}
	}

	err := <-errCh
	if err == nil {
		t.Error("Expected error after sending chunks then hitting rate limit")
	}

	if !errors.Is(err, ErrRateLimitExceeded) {
		t.Errorf("Expected ErrRateLimitExceeded, got %v", err)
	}

	if !strings.Contains(err.Error(), "partial content already sent") {
		t.Errorf("Expected error message to mention partial content, got: %v", err)
	}

	if callCount != 1 {
		t.Errorf("Expected exactly 1 call, got %d", callCount)
	}

	if content.String() != "Partial" {
		t.Errorf("Expected 'Partial', got '%s'", content.String())
	}
}