diff --git a/_example/authgate-cli/main.go b/_example/authgate-cli/main.go index 2149bc6..59ee75c 100644 --- a/_example/authgate-cli/main.go +++ b/_example/authgate-cli/main.go @@ -588,7 +588,11 @@ func exchangeDeviceCode( } // Validate token response - if err := validateTokenResponse(tokenResp.AccessToken, tokenResp.TokenType, tokenResp.ExpiresIn); err != nil { + if err := validateTokenResponse( + tokenResp.AccessToken, + tokenResp.TokenType, + tokenResp.ExpiresIn, + ); err != nil { return nil, fmt.Errorf("invalid token response: %w", err) } @@ -781,7 +785,11 @@ func refreshAccessToken(refreshToken string) (*TokenStorage, error) { } // Validate token response - if err := validateTokenResponse(tokenResp.AccessToken, tokenResp.TokenType, tokenResp.ExpiresIn); err != nil { + if err := validateTokenResponse( + tokenResp.AccessToken, + tokenResp.TokenType, + tokenResp.ExpiresIn, + ); err != nil { return nil, fmt.Errorf("invalid token response: %w", err) } diff --git a/internal/auth/local.go b/internal/auth/local.go index f2c993e..f6a9016 100644 --- a/internal/auth/local.go +++ b/internal/auth/local.go @@ -28,7 +28,10 @@ func (p *LocalAuthProvider) Authenticate( return nil, ErrInvalidCredentials } - if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil { + if err := bcrypt.CompareHashAndPassword( + []byte(user.PasswordHash), + []byte(password), + ); err != nil { return nil, ErrInvalidCredentials } diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index 0f7fe86..1fa1f7e 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -10,6 +10,7 @@ import ( "time" "github.com/appleboy/authgate/internal/auth" + "github.com/appleboy/authgate/internal/metrics" "github.com/appleboy/authgate/internal/middleware" "github.com/appleboy/authgate/internal/services" "github.com/appleboy/authgate/internal/templates" @@ -97,6 +98,7 @@ type AuthHandler struct { baseURL string sessionFingerprintEnabled bool sessionFingerprintIncludeIP bool + metrics metrics.MetricsRecorder } func NewAuthHandler( @@ -104,12 +106,14 @@ func NewAuthHandler( baseURL string, fingerprintEnabled bool, fingerprintIncludeIP bool, + m metrics.MetricsRecorder, ) *AuthHandler { return &AuthHandler{ userService: us, baseURL: baseURL, sessionFingerprintEnabled: fingerprintEnabled, sessionFingerprintIncludeIP: fingerprintIncludeIP, + metrics: m, } } @@ -170,6 +174,7 @@ func (h *AuthHandler) LoginPageWithOAuth( func (h *AuthHandler) Login(c *gin.Context, oauthProviders map[string]*auth.OAuthProvider, ) { + start := time.Now() username := c.PostForm("username") password := c.PostForm("password") redirectTo := c.PostForm("redirect") @@ -181,6 +186,11 @@ func (h *AuthHandler) Login(c *gin.Context, user, err := h.userService.Authenticate(c.Request.Context(), username, password) if err != nil { + // Record failed login + duration := time.Since(start) + h.metrics.RecordLogin("local", false) + h.metrics.RecordAuthAttempt("local", false, duration) + var errorMsg string // Check for specific error types @@ -212,6 +222,15 @@ func (h *AuthHandler) Login(c *gin.Context, return } + // Record successful login + duration := time.Since(start) + authSource := user.AuthSource + if authSource == "" { + authSource = "local" + } + h.metrics.RecordLogin(authSource, true) + h.metrics.RecordAuthAttempt(authSource, true, duration) + // Set session session := sessions.Default(c) session.Set(SessionUserID, user.ID) @@ -249,6 +268,16 @@ func (h *AuthHandler) Login(c *gin.Context, // Logout clears the session and redirects to login func (h *AuthHandler) Logout(c *gin.Context) { session := sessions.Default(c) + + // Calculate session duration if available + var sessionDuration time.Duration + if createdAtUnix := session.Get(SessionLastActivity); createdAtUnix != nil { + if createdAtInt64, ok := createdAtUnix.(int64); ok { + createdAt := time.Unix(createdAtInt64, 0) + sessionDuration = time.Since(createdAt) + } + } + session.Clear() if err := session.Save(); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -256,5 +285,9 @@ func (h *AuthHandler) Logout(c *gin.Context) { }) return } + + // Record logout + h.metrics.RecordLogout(sessionDuration) + c.Redirect(http.StatusFound, "/login") } diff --git a/internal/handlers/oauth_handler.go b/internal/handlers/oauth_handler.go index d797ffe..950070f 100644 --- a/internal/handlers/oauth_handler.go +++ b/internal/handlers/oauth_handler.go @@ -12,6 +12,7 @@ import ( "time" "github.com/appleboy/authgate/internal/auth" + "github.com/appleboy/authgate/internal/metrics" "github.com/appleboy/authgate/internal/services" "github.com/appleboy/authgate/internal/templates" @@ -38,6 +39,7 @@ type OAuthHandler struct { httpClient *http.Client // Custom HTTP client for OAuth requests sessionFingerprintEnabled bool sessionFingerprintIncludeIP bool + metrics metrics.MetricsRecorder } // NewOAuthHandler creates a new OAuth handler @@ -47,6 +49,7 @@ func NewOAuthHandler( httpClient *http.Client, fingerprintEnabled bool, fingerprintIncludeIP bool, + m metrics.MetricsRecorder, ) *OAuthHandler { return &OAuthHandler{ providers: providers, @@ -54,6 +57,7 @@ func NewOAuthHandler( httpClient: httpClient, sessionFingerprintEnabled: fingerprintEnabled, sessionFingerprintIncludeIP: fingerprintIncludeIP, + metrics: m, } } @@ -200,6 +204,9 @@ func (h *OAuthHandler) OAuthCallback(c *gin.Context) { token, ) if err != nil { + // Record failure + h.metrics.RecordOAuthCallback(provider, false) + log.Printf("[OAuth] Authentication failed: %v", err) // Handle specific errors @@ -225,6 +232,9 @@ func (h *OAuthHandler) OAuthCallback(c *gin.Context) { return } + // Record success + h.metrics.RecordOAuthCallback(provider, true) + // Clear OAuth session data session.Delete("oauth_state") session.Delete("oauth_provider") diff --git a/internal/handlers/session.go b/internal/handlers/session.go index 5f176b8..6056bb9 100644 --- a/internal/handlers/session.go +++ b/internal/handlers/session.go @@ -162,7 +162,11 @@ func (h *SessionHandler) RevokeSession(c *gin.Context) { userID, _ := c.Get("user_id") // Revoke the token - if err := h.tokenService.RevokeTokenByID(c.Request.Context(), tokenID, userID.(string)); err != nil { + if err := h.tokenService.RevokeTokenByID( + c.Request.Context(), + tokenID, + userID.(string), + ); err != nil { templates.RenderTempl( c, http.StatusInternalServerError, @@ -214,7 +218,11 @@ func (h *SessionHandler) DisableSession(c *gin.Context) { userID, _ := c.Get("user_id") // Disable the token - if err := h.tokenService.DisableToken(c.Request.Context(), tokenID, userID.(string)); err != nil { + if err := h.tokenService.DisableToken( + c.Request.Context(), + tokenID, + userID.(string), + ); err != nil { templates.RenderTempl( c, http.StatusInternalServerError, @@ -238,7 +246,11 @@ func (h *SessionHandler) EnableSession(c *gin.Context) { userID, _ := c.Get("user_id") // Enable the token - if err := h.tokenService.EnableToken(c.Request.Context(), tokenID, userID.(string)); err != nil { + if err := h.tokenService.EnableToken( + c.Request.Context(), + tokenID, + userID.(string), + ); err != nil { templates.RenderTempl( c, http.StatusInternalServerError, diff --git a/internal/metrics/INTEGRATION.md b/internal/metrics/INTEGRATION.md deleted file mode 100644 index 6fb266c..0000000 --- a/internal/metrics/INTEGRATION.md +++ /dev/null @@ -1,518 +0,0 @@ -# Metrics Integration Guide - -This document provides examples of how to integrate Prometheus metrics into AuthGate services and handlers. - -## Overview - -The metrics system consists of: - -- **metrics.go**: Core metrics definitions and initialization -- **http.go**: HTTP middleware and helper methods for recording metrics -- **main.go**: Initialization and `/metrics` endpoint registration - -## Current Status - -✅ **Implemented (Core Infrastructure)** - -- HTTP request metrics (automatic via middleware) -- Metrics initialization with singleton pattern -- `/metrics` endpoint for Prometheus scraping -- Helper methods for recording OAuth, Auth, and Token metrics - -⚠️ **Pending (Service Integration)** - -- Device service integration -- Token service integration -- Auth handler integration -- Session management integration - -## Quick Start - -The metrics system is already initialized in `main.go` and the HTTP metrics middleware is active. You can access metrics at: - -```bash -curl http://localhost:8080/metrics -``` - -## HTTP Metrics (✅ Auto-enabled) - -HTTP request metrics are automatically collected via the `HTTPMetricsMiddleware`: - -```go -// Automatically tracked for all routes: -- http_requests_total{method, path, status} -- http_request_duration_seconds{method, path} -- http_requests_in_flight -``` - -Example output: - -``` -http_requests_total{method="POST",path="/oauth/token",status="200"} 42 -http_request_duration_seconds_bucket{method="POST",path="/oauth/token",le="0.1"} 40 -http_requests_in_flight 3 -``` - -## Integration Examples - -### 1. Device Service Integration - -To record device code metrics, update `internal/services/device.go`: - -```go -import "github.com/appleboy/authgate/internal/metrics" - -type DeviceService struct { - store *store.Store - config *config.Config - auditService *AuditService - metrics *metrics.Metrics // Add this field -} - -func NewDeviceService( - s *store.Store, - cfg *config.Config, - auditService *AuditService, - m *metrics.Metrics, // Add parameter -) *DeviceService { - return &DeviceService{ - store: s, - config: cfg, - auditService: auditService, - metrics: m, // Store metrics instance - } -} - -func (s *DeviceService) GenerateDeviceCode( - ctx context.Context, - clientID, scope string, -) (*models.DeviceCode, error) { - // ... existing validation code ... - - if err := s.store.CreateDeviceCode(deviceCode); err != nil { - // Record failure - if s.metrics != nil { - s.metrics.RecordOAuthDeviceCodeGenerated(false) - } - return nil, err - } - - // Record success - if s.metrics != nil { - s.metrics.RecordOAuthDeviceCodeGenerated(true) - } - - // ... existing audit logging ... - - return deviceCode, nil -} - -func (s *DeviceService) VerifyUserCode( - ctx context.Context, - userCode string, - userID uint, -) error { - // ... existing validation code ... - - // Calculate authorization duration - authDuration := time.Since(dc.CreatedAt) - - // Record authorization - if s.metrics != nil { - s.metrics.RecordOAuthDeviceCodeAuthorized(authDuration) - } - - // ... rest of the code ... -} -``` - -### 2. Token Service Integration - -To record token metrics, update `internal/services/token.go`: - -```go -import "github.com/appleboy/authgate/internal/metrics" - -type TokenService struct { - store *store.Store - config *config.Config - deviceService *DeviceService - localTokenProvider *token.LocalTokenProvider - httpTokenProvider *token.HTTPTokenProvider - tokenProviderMode string - auditService *AuditService - metrics *metrics.Metrics // Add this field -} - -func (s *TokenService) IssueTokenForDeviceCode( - ctx context.Context, - deviceCode string, -) (*TokenResponse, error) { - start := time.Now() - - // ... existing validation code ... - - // Validate device code - dc, err := s.deviceService.GetDeviceCode(deviceCode) - if err != nil { - if s.metrics != nil { - result := "invalid" - if err == ErrDeviceCodeExpired { - result = "expired" - } - s.metrics.RecordOAuthDeviceCodeValidation(result) - } - return nil, err - } - - if !dc.Authorized { - if s.metrics != nil { - s.metrics.RecordOAuthDeviceCodeValidation("pending") - } - return nil, ErrAuthorizationPending - } - - // Record successful validation - if s.metrics != nil { - s.metrics.RecordOAuthDeviceCodeValidation("success") - } - - // ... generate tokens ... - - // Record token issuance - if s.metrics != nil { - duration := time.Since(start) - provider := s.tokenProviderMode - s.metrics.RecordTokenIssued("access", "device_code", duration, provider) - if refreshToken != "" { - s.metrics.RecordTokenIssued("refresh", "device_code", duration, provider) - } - } - - return &TokenResponse{ - AccessToken: accessToken, - TokenType: "Bearer", - ExpiresIn: int(accessExpiry.Seconds()), - RefreshToken: refreshToken, - }, nil -} - -func (s *TokenService) RefreshToken( - ctx context.Context, - refreshTokenPlaintext string, -) (*TokenResponse, error) { - // ... existing code ... - - // Record refresh attempt - if err != nil { - if s.metrics != nil { - s.metrics.RecordTokenRefresh(false) - } - return nil, err - } - - if s.metrics != nil { - s.metrics.RecordTokenRefresh(true) - } - - // ... rest of the code ... -} - -func (s *TokenService) RevokeToken( - ctx context.Context, - tokenPlaintext, tokenTypeHint string, - reason string, -) error { - // ... existing revocation code ... - - // Record revocation - if s.metrics != nil { - tokenType := "access" // or "refresh" based on token category - s.metrics.RecordTokenRevoked(tokenType, reason) - } - - return nil -} -``` - -### 3. Auth Handler Integration - -To record authentication metrics, update `internal/handlers/auth.go`: - -```go -import "github.com/appleboy/authgate/internal/metrics" - -type AuthHandler struct { - userService *services.UserService - baseURL string - sessionFingerprint bool - sessionFingerprintIP bool - metrics *metrics.Metrics // Add this field -} - -func (h *AuthHandler) Login(c *gin.Context, oauthProviders map[string]*auth.OAuthProvider) { - start := time.Now() - - // ... parse credentials ... - - // Attempt authentication - user, err := h.userService.AuthenticateUser(c, username, password) - - authSource := "local" // or determine from user.AuthSource - - if err != nil { - // Record failed login - if h.metrics != nil { - h.metrics.RecordLogin(authSource, false) - duration := time.Since(start) - h.metrics.RecordAuthAttempt("local", false, duration) - } - - // ... existing error handling ... - return - } - - // Record successful login - if h.metrics != nil { - h.metrics.RecordLogin(authSource, true) - duration := time.Since(start) - h.metrics.RecordAuthAttempt("local", true, duration) - } - - // ... rest of the code ... -} - -func (h *AuthHandler) Logout(c *gin.Context) { - session := sessions.Default(c) - - // Calculate session duration if available - var sessionDuration time.Duration - if createdAt, ok := session.Get("created_at").(time.Time); ok { - sessionDuration = time.Since(createdAt) - } - - session.Clear() - session.Options(sessions.Options{MaxAge: -1}) - _ = session.Save() - - // Record logout - if h.metrics != nil { - h.metrics.RecordLogout(sessionDuration) - } - - c.Redirect(http.StatusFound, "/login") -} -``` - -### 4. OAuth Handler Integration - -For OAuth callbacks, update `internal/handlers/oauth_handler.go`: - -```go -func (h *OAuthHandler) OAuthCallback(c *gin.Context) { - provider := c.Param("provider") - - // ... existing OAuth flow ... - - if err != nil { - if h.metrics != nil { - h.metrics.RecordOAuthCallback(provider, false) - } - // ... error handling ... - return - } - - if h.metrics != nil { - h.metrics.RecordOAuthCallback(provider, true) - } - - // ... rest of the code ... -} -``` - -## Updating main.go - -After integrating metrics into services, update `main.go` to pass the metrics instance: - -```go -// Current (already done): -prometheusMetrics := metrics.Init() - -// Update service initialization (to be done): -deviceService := services.NewDeviceService(db, cfg, auditService, prometheusMetrics) -tokenService := services.NewTokenService( - db, - cfg, - deviceService, - localTokenProvider, - httpTokenProvider, - cfg.TokenProviderMode, - auditService, - prometheusMetrics, // Add this -) - -// Update handler initialization: -authHandler := handlers.NewAuthHandler( - userService, - cfg.BaseURL, - cfg.SessionFingerprint, - cfg.SessionFingerprintIP, - prometheusMetrics, // Add this -) -``` - -## Periodic Gauge Updates - -For gauge metrics that track current state (active tokens, sessions, device codes), add a background job in `main.go`: - -```go -// Add metrics update job (runs every 30 seconds) -m.AddRunningJob(func(ctx context.Context) error { - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - // Update active tokens count - activeAccessTokens, _ := db.CountActiveTokens("access") - activeRefreshTokens, _ := db.CountActiveTokens("refresh") - prometheusMetrics.SetActiveTokensCount("access", activeAccessTokens) - prometheusMetrics.SetActiveTokensCount("refresh", activeRefreshTokens) - - // Update active device codes count - totalDeviceCodes, pendingDeviceCodes, _ := db.CountDeviceCodes() - prometheusMetrics.SetActiveDeviceCodesCount(totalDeviceCodes, pendingDeviceCodes) - - // Update active sessions count (if tracking sessions in DB) - // activeSessions, _ := db.CountActiveSessions() - // prometheusMetrics.SetActiveSessionsCount(activeSessions) - - case <-ctx.Done(): - return nil - } - } -}) -``` - -## Grafana Dashboard Example - -Example PromQL queries for Grafana: - -```promql -# Request rate by endpoint -rate(http_requests_total[5m]) - -# Error rate -rate(http_requests_total{status=~"5.."}[5m]) / rate(http_requests_total[5m]) - -# P95 latency -histogram_quantile(0.95, rate(http_request_duration_seconds_bucket[5m])) - -# Active device codes -oauth_device_codes_active - -# Token issuance rate -rate(oauth_tokens_issued_total[5m]) - -# Failed authentication rate -rate(auth_login_total{result="failure"}[5m]) -``` - -## Testing Metrics - -Start the server and generate some traffic: - -```bash -# Start server -./bin/authgate server - -# Check metrics endpoint -curl http://localhost:8080/metrics | grep oauth - -# Generate some metrics: -# 1. Request device code -curl -X POST http://localhost:8080/oauth/device/code \ - -H "Content-Type: application/json" \ - -d '{"client_id":"your-client-id","scope":"read:user"}' - -# 2. Login -curl -X POST http://localhost:8080/login \ - -d "username=admin&password=yourpassword" - -# Check metrics again -curl http://localhost:8080/metrics | grep -E "(oauth|auth|http_request)" -``` - -## Best Practices - -1. **Always check for nil**: Metrics instance might be nil in tests - - ```go - if s.metrics != nil { - s.metrics.RecordSomething() - } - ``` - -2. **Record timing at operation boundaries**: Start timer at the beginning, record at the end - - ```go - start := time.Now() - // ... do work ... - if s.metrics != nil { - s.metrics.RecordAuthAttempt("local", success, time.Since(start)) - } - ``` - -3. **Use meaningful labels**: Keep cardinality low (no user IDs, tokens, etc.) - - ```go - // Good: bounded set of values - s.metrics.RecordLogin("local", true) - - // Bad: unbounded cardinality - // s.metrics.RecordLogin(username, true) // DON'T DO THIS - ``` - -4. **Record both success and failure**: Always track outcomes - - ```go - if err != nil { - s.metrics.RecordTokenRefresh(false) - return err - } - s.metrics.RecordTokenRefresh(true) - ``` - -## Next Steps - -To complete the metrics integration: - -1. Update DeviceService to include metrics parameter and calls -2. Update TokenService to include metrics parameter and calls -3. Update AuthHandler to include metrics parameter and calls -4. Update OAuthHandler to include metrics parameter and calls -5. Add periodic gauge update job in main.go -6. Add store methods for counting active resources (optional) -7. Create Grafana dashboard (see docs/MONITORING.md) -8. Update configuration docs with metrics endpoint info - -## Configuration - -Add to `.env` or environment variables: - -```bash -# Metrics are always enabled -# Access at: http://localhost:8080/metrics - -# For production, consider: -# - Restricting /metrics endpoint to internal network -# - Using Prometheus scrape configs with authentication -# - Setting up Grafana dashboards -``` - -## See Also - -- [docs/MONITORING.md](../../docs/MONITORING.md) - Monitoring best practices -- [Prometheus Documentation](https://prometheus.io/docs/) -- [Gin Prometheus Middleware](https://github.com/zsais/go-gin-prometheus) diff --git a/internal/metrics/http.go b/internal/metrics/http.go index 9db04ff..c4a758e 100644 --- a/internal/metrics/http.go +++ b/internal/metrics/http.go @@ -14,7 +14,23 @@ const ( ) // HTTPMetricsMiddleware creates a Gin middleware that records HTTP metrics -func HTTPMetricsMiddleware(m *Metrics) gin.HandlerFunc { +func HTTPMetricsMiddleware(m MetricsRecorder) gin.HandlerFunc { + // If NoopMetrics, return a lightweight middleware that does nothing + if _, ok := m.(*NoopMetrics); ok { + return func(c *gin.Context) { + c.Next() + } + } + + // Type assert to concrete Metrics for Prometheus access + metrics, ok := m.(*Metrics) + if !ok { + // Fallback if unknown implementation + return func(c *gin.Context) { + c.Next() + } + } + return func(c *gin.Context) { // Skip metrics endpoint to avoid self-recording if c.Request.URL.Path == "/metrics" { @@ -25,8 +41,8 @@ func HTTPMetricsMiddleware(m *Metrics) gin.HandlerFunc { start := time.Now() // Increment in-flight counter - m.HTTPRequestsInFlight.Inc() - defer m.HTTPRequestsInFlight.Dec() + metrics.HTTPRequestsInFlight.Inc() + defer metrics.HTTPRequestsInFlight.Dec() // Process request c.Next() @@ -38,10 +54,10 @@ func HTTPMetricsMiddleware(m *Metrics) gin.HandlerFunc { status := strconv.Itoa(c.Writer.Status()) // Record request count - m.HTTPRequestsTotal.WithLabelValues(method, path, status).Inc() + metrics.HTTPRequestsTotal.WithLabelValues(method, path, status).Inc() // Record request duration - m.HTTPRequestDuration.WithLabelValues(method, path).Observe(duration) + metrics.HTTPRequestDuration.WithLabelValues(method, path).Observe(duration) } } diff --git a/internal/metrics/interface.go b/internal/metrics/interface.go new file mode 100644 index 0000000..cf1edfe --- /dev/null +++ b/internal/metrics/interface.go @@ -0,0 +1,34 @@ +package metrics + +import "time" + +// MetricsRecorder defines the interface for recording application metrics +// Implementations include Metrics (Prometheus-based) and NoopMetrics (no-op) +type MetricsRecorder interface { + // OAuth Device Flow + RecordOAuthDeviceCodeGenerated(success bool) + RecordOAuthDeviceCodeAuthorized(authorizationTime time.Duration) + RecordOAuthDeviceCodeValidation(result string) + + // Token Operations + RecordTokenIssued(tokenType, grantType string, generationTime time.Duration, provider string) + RecordTokenRevoked(tokenType, reason string) + RecordTokenRefresh(success bool) + RecordTokenValidation(result string, duration time.Duration, provider string) + + // Authentication + RecordAuthAttempt(method string, success bool, duration time.Duration) + RecordLogin(authSource string, success bool) + RecordLogout(sessionDuration time.Duration) + RecordOAuthCallback(provider string, success bool) + RecordExternalAPICall(provider string, duration time.Duration) + + // Session Management + RecordSessionExpired(reason string, duration time.Duration) + RecordSessionInvalidated(reason string) + + // Gauge Setters (for periodic updates) + SetActiveTokensCount(tokenType string, count int) + SetActiveDeviceCodesCount(total, pending int) + SetActiveSessionsCount(count int) +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 8020fcb..6a91739 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -7,6 +7,9 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" ) +// Ensure Metrics implements MetricsRecorder interface at compile time +var _ MetricsRecorder = (*Metrics)(nil) + // Metrics holds all Prometheus metrics for the application type Metrics struct { // OAuth Device Flow Metrics @@ -52,9 +55,15 @@ var ( once sync.Once ) -// Init initializes all Prometheus metrics -// Uses sync.Once to ensure metrics are only registered once -func Init() *Metrics { +// Init initializes metrics based on enabled flag +// If enabled=true, returns Prometheus-based Metrics +// If enabled=false, returns NoopMetrics (zero overhead) +// Uses sync.Once to ensure Prometheus metrics are only registered once +func Init(enabled bool) MetricsRecorder { + if !enabled { + return NewNoopMetrics() + } + once.Do(func() { defaultMetrics = initMetrics() }) @@ -300,9 +309,13 @@ func initMetrics() *Metrics { } // GetMetrics returns the global metrics instance +// +// Deprecated: Use Init(true) instead func GetMetrics() *Metrics { if defaultMetrics == nil { - return Init() + once.Do(func() { + defaultMetrics = initMetrics() + }) } return defaultMetrics } diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index f4ec60f..a327f36 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -8,12 +8,25 @@ import ( ) func TestInit(t *testing.T) { - m := Init() + m := Init(true) assert.NotNil(t, m) - assert.NotNil(t, m.DeviceCodesTotal) - assert.NotNil(t, m.TokensIssuedTotal) - assert.NotNil(t, m.AuthAttemptsTotal) - assert.NotNil(t, m.HTTPRequestsTotal) + + // Type assert to concrete Metrics to access fields + metrics, ok := m.(*Metrics) + assert.True(t, ok, "Init(true) should return *Metrics") + assert.NotNil(t, metrics.DeviceCodesTotal) + assert.NotNil(t, metrics.TokensIssuedTotal) + assert.NotNil(t, metrics.AuthAttemptsTotal) + assert.NotNil(t, metrics.HTTPRequestsTotal) +} + +func TestInitNoop(t *testing.T) { + m := Init(false) + assert.NotNil(t, m) + + // Type assert to NoopMetrics + _, ok := m.(*NoopMetrics) + assert.True(t, ok, "Init(false) should return *NoopMetrics") } func TestGetMetrics(t *testing.T) { @@ -26,21 +39,21 @@ func TestGetMetrics(t *testing.T) { } func TestRecordOAuthDeviceCodeGenerated(t *testing.T) { - m := Init() + m := Init(true) m.RecordOAuthDeviceCodeGenerated(true) // No error means success - prometheus metrics don't return errors for recording } func TestRecordOAuthDeviceCodeAuthorized(t *testing.T) { - m := Init() + m := Init(true) m.RecordOAuthDeviceCodeAuthorized(5 * time.Second) // No error means success } func TestRecordOAuthDeviceCodeValidation(t *testing.T) { - m := Init() + m := Init(true) // First generate a device code m.RecordOAuthDeviceCodeGenerated(true) @@ -51,7 +64,7 @@ func TestRecordOAuthDeviceCodeValidation(t *testing.T) { } func TestRecordTokenIssued(t *testing.T) { - m := Init() + m := Init(true) m.RecordTokenIssued("access", "device_code", 100*time.Millisecond, "local") m.RecordTokenIssued("refresh", "device_code", 150*time.Millisecond, "local") @@ -59,7 +72,7 @@ func TestRecordTokenIssued(t *testing.T) { } func TestRecordTokenRevoked(t *testing.T) { - m := Init() + m := Init(true) // First issue a token m.RecordTokenIssued("access", "device_code", 100*time.Millisecond, "local") @@ -70,7 +83,7 @@ func TestRecordTokenRevoked(t *testing.T) { } func TestRecordTokenRefresh(t *testing.T) { - m := Init() + m := Init(true) m.RecordTokenRefresh(true) m.RecordTokenRefresh(false) @@ -78,7 +91,7 @@ func TestRecordTokenRefresh(t *testing.T) { } func TestRecordTokenValidation(t *testing.T) { - m := Init() + m := Init(true) m.RecordTokenValidation("valid", 50*time.Millisecond, "local") m.RecordTokenValidation("invalid", 30*time.Millisecond, "local") @@ -87,7 +100,7 @@ func TestRecordTokenValidation(t *testing.T) { } func TestRecordAuthAttempt(t *testing.T) { - m := Init() + m := Init(true) m.RecordAuthAttempt("local", true, 200*time.Millisecond) m.RecordAuthAttempt("local", false, 150*time.Millisecond) @@ -96,7 +109,7 @@ func TestRecordAuthAttempt(t *testing.T) { } func TestRecordLogin(t *testing.T) { - m := Init() + m := Init(true) m.RecordLogin("local", true) m.RecordLogin("local", false) @@ -105,7 +118,7 @@ func TestRecordLogin(t *testing.T) { } func TestRecordLogout(t *testing.T) { - m := Init() + m := Init(true) // First create a session m.RecordLogin("local", true) @@ -116,7 +129,7 @@ func TestRecordLogout(t *testing.T) { } func TestRecordOAuthCallback(t *testing.T) { - m := Init() + m := Init(true) m.RecordOAuthCallback("microsoft", true) m.RecordOAuthCallback("github", false) @@ -124,14 +137,14 @@ func TestRecordOAuthCallback(t *testing.T) { } func TestRecordExternalAPICall(t *testing.T) { - m := Init() + m := Init(true) m.RecordExternalAPICall("http_api", 300*time.Millisecond) // No error means success } func TestRecordSessionExpired(t *testing.T) { - m := Init() + m := Init(true) // First create a session m.RecordLogin("local", true) @@ -142,7 +155,7 @@ func TestRecordSessionExpired(t *testing.T) { } func TestRecordSessionInvalidated(t *testing.T) { - m := Init() + m := Init(true) // First create a session m.RecordLogin("local", true) @@ -153,7 +166,7 @@ func TestRecordSessionInvalidated(t *testing.T) { } func TestSetActiveTokensCount(t *testing.T) { - m := Init() + m := Init(true) m.SetActiveTokensCount("access", 100) m.SetActiveTokensCount("refresh", 50) @@ -161,14 +174,14 @@ func TestSetActiveTokensCount(t *testing.T) { } func TestSetActiveDeviceCodesCount(t *testing.T) { - m := Init() + m := Init(true) m.SetActiveDeviceCodesCount(20, 5) // No error means success } func TestSetActiveSessionsCount(t *testing.T) { - m := Init() + m := Init(true) m.SetActiveSessionsCount(42) // No error means success diff --git a/internal/metrics/noop.go b/internal/metrics/noop.go new file mode 100644 index 0000000..44b43b3 --- /dev/null +++ b/internal/metrics/noop.go @@ -0,0 +1,61 @@ +package metrics + +import "time" + +// NoopMetrics is a no-operation implementation of MetricsRecorder +// All methods are empty and do nothing, providing zero overhead when metrics are disabled +type NoopMetrics struct{} + +// Ensure NoopMetrics implements MetricsRecorder interface at compile time +var _ MetricsRecorder = (*NoopMetrics)(nil) + +// NewNoopMetrics creates a new no-operation metrics recorder +func NewNoopMetrics() MetricsRecorder { + return &NoopMetrics{} +} + +// OAuth Device Flow - noop implementations +func (n *NoopMetrics) RecordOAuthDeviceCodeGenerated(success bool) {} +func (n *NoopMetrics) RecordOAuthDeviceCodeAuthorized(authorizationTime time.Duration) {} +func (n *NoopMetrics) RecordOAuthDeviceCodeValidation(result string) {} + +// Token Operations - noop implementations +func (n *NoopMetrics) RecordTokenIssued( + tokenType, grantType string, + generationTime time.Duration, + provider string, +) { +} + +func (n *NoopMetrics) RecordTokenRevoked( + tokenType, reason string, +) { +} + +func (n *NoopMetrics) RecordTokenRefresh( + success bool, +) { +} + +func (n *NoopMetrics) RecordTokenValidation( + result string, + duration time.Duration, + provider string, +) { +} + +// Authentication - noop implementations +func (n *NoopMetrics) RecordAuthAttempt(method string, success bool, duration time.Duration) {} +func (n *NoopMetrics) RecordLogin(authSource string, success bool) {} +func (n *NoopMetrics) RecordLogout(sessionDuration time.Duration) {} +func (n *NoopMetrics) RecordOAuthCallback(provider string, success bool) {} +func (n *NoopMetrics) RecordExternalAPICall(provider string, duration time.Duration) {} + +// Session Management - noop implementations +func (n *NoopMetrics) RecordSessionExpired(reason string, duration time.Duration) {} +func (n *NoopMetrics) RecordSessionInvalidated(reason string) {} + +// Gauge Setters - noop implementations +func (n *NoopMetrics) SetActiveTokensCount(tokenType string, count int) {} +func (n *NoopMetrics) SetActiveDeviceCodesCount(total, pending int) {} +func (n *NoopMetrics) SetActiveSessionsCount(count int) {} diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go index 15cb545..d11a31d 100644 --- a/internal/middleware/ratelimit.go +++ b/internal/middleware/ratelimit.go @@ -76,7 +76,11 @@ func NewRateLimiter(config RateLimitConfig) (gin.HandlerFunc, error) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := client.Ping(ctx).Err(); err != nil { - return nil, fmt.Errorf("failed to connect to Redis at %s: %w", config.RedisAddr, err) + return nil, fmt.Errorf( + "failed to connect to Redis at %s: %w", + config.RedisAddr, + err, + ) } } diff --git a/internal/services/client.go b/internal/services/client.go index f0bcc75..4a3c7b2 100644 --- a/internal/services/client.go +++ b/internal/services/client.go @@ -334,7 +334,10 @@ func (s *ClientService) VerifyClientSecret(clientID, clientSecret string) error return ErrClientNotFound } - if err := bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)); err != nil { + if err := bcrypt.CompareHashAndPassword( + []byte(client.ClientSecret), + []byte(clientSecret), + ); err != nil { return errors.New("invalid client secret") } diff --git a/internal/services/device.go b/internal/services/device.go index 15470f6..e0c59dc 100644 --- a/internal/services/device.go +++ b/internal/services/device.go @@ -12,6 +12,7 @@ import ( "time" "github.com/appleboy/authgate/internal/config" + "github.com/appleboy/authgate/internal/metrics" "github.com/appleboy/authgate/internal/models" "github.com/appleboy/authgate/internal/store" "github.com/appleboy/authgate/internal/util" @@ -29,17 +30,20 @@ type DeviceService struct { store *store.Store config *config.Config auditService *AuditService + metrics metrics.MetricsRecorder } func NewDeviceService( s *store.Store, cfg *config.Config, auditService *AuditService, + m metrics.MetricsRecorder, ) *DeviceService { return &DeviceService{ store: s, config: cfg, auditService: auditService, + metrics: m, } } @@ -90,9 +94,14 @@ func (s *DeviceService) GenerateDeviceCode( } if err := s.store.CreateDeviceCode(deviceCode); err != nil { + // Record failure + s.metrics.RecordOAuthDeviceCodeGenerated(false) return nil, err } + // Record success + s.metrics.RecordOAuthDeviceCodeGenerated(true) + // Log device code generation if s.auditService != nil { s.auditService.Log(ctx, AuditLogEntry{ @@ -194,6 +203,10 @@ func (s *DeviceService) AuthorizeDeviceCode( return err } + // Record authorization with duration + authDuration := time.Since(dc.CreatedAt) + s.metrics.RecordOAuthDeviceCodeAuthorized(authDuration) + // Log device code authorization if s.auditService != nil { s.auditService.Log(ctx, AuditLogEntry{ diff --git a/internal/services/device_security_test.go b/internal/services/device_security_test.go index b4f0328..cfb2a9f 100644 --- a/internal/services/device_security_test.go +++ b/internal/services/device_security_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/appleboy/authgate/internal/config" + "github.com/appleboy/authgate/internal/metrics" "github.com/appleboy/authgate/internal/models" "github.com/appleboy/authgate/internal/store" "github.com/appleboy/authgate/internal/util" @@ -24,7 +25,7 @@ func setupTestService(t *testing.T) (*DeviceService, *models.OAuthApplication) { st, err := store.New("sqlite", ":memory:", cfg) require.NoError(t, err) - service := NewDeviceService(st, cfg, nil) + service := NewDeviceService(st, cfg, nil, metrics.NewNoopMetrics()) // Create test client client := &models.OAuthApplication{ diff --git a/internal/services/device_test.go b/internal/services/device_test.go index 2a773c4..918c6db 100644 --- a/internal/services/device_test.go +++ b/internal/services/device_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/appleboy/authgate/internal/config" + "github.com/appleboy/authgate/internal/metrics" "github.com/appleboy/authgate/internal/models" "github.com/appleboy/authgate/internal/store" @@ -56,7 +57,7 @@ func TestGenerateDeviceCode_ActiveClient(t *testing.T) { DeviceCodeExpiration: 30 * time.Minute, PollingInterval: 5, } - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Create an active client client := createTestClient(t, s, true) @@ -80,7 +81,7 @@ func TestGenerateDeviceCode_InactiveClient(t *testing.T) { DeviceCodeExpiration: 30 * time.Minute, PollingInterval: 5, } - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Create an inactive client client := createTestClient(t, s, false) @@ -105,7 +106,7 @@ func TestGenerateDeviceCode_InvalidClient(t *testing.T) { DeviceCodeExpiration: 30 * time.Minute, PollingInterval: 5, } - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Try to generate device code with non-existent client dc, err := deviceService.GenerateDeviceCode( @@ -126,7 +127,7 @@ func TestAuthorizeDeviceCode_Success(t *testing.T) { DeviceCodeExpiration: 30 * time.Minute, PollingInterval: 5, } - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Create an active client and device code client := createTestClient(t, s, true) @@ -154,7 +155,7 @@ func TestAuthorizeDeviceCode_InvalidUserCode(t *testing.T) { DeviceCodeExpiration: 30 * time.Minute, PollingInterval: 5, } - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Try to authorize with invalid user code err := deviceService.AuthorizeDeviceCode( @@ -175,7 +176,7 @@ func TestGetClientNameByUserCode_Success(t *testing.T) { DeviceCodeExpiration: 30 * time.Minute, PollingInterval: 5, } - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Create an active client and device code client := createTestClient(t, s, true) @@ -196,7 +197,7 @@ func TestUserCodeNormalization(t *testing.T) { DeviceCodeExpiration: 30 * time.Minute, PollingInterval: 5, } - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Create an active client and device code client := createTestClient(t, s, true) diff --git a/internal/services/token.go b/internal/services/token.go index aa005ae..eb42b19 100644 --- a/internal/services/token.go +++ b/internal/services/token.go @@ -9,6 +9,7 @@ import ( "time" "github.com/appleboy/authgate/internal/config" + "github.com/appleboy/authgate/internal/metrics" "github.com/appleboy/authgate/internal/models" "github.com/appleboy/authgate/internal/store" "github.com/appleboy/authgate/internal/token" @@ -37,6 +38,7 @@ type TokenService struct { httpTokenProvider *token.HTTPTokenProvider tokenProviderMode string auditService *AuditService + metrics metrics.MetricsRecorder } func NewTokenService( @@ -47,6 +49,7 @@ func NewTokenService( httpProvider *token.HTTPTokenProvider, providerMode string, auditService *AuditService, + m metrics.MetricsRecorder, ) *TokenService { return &TokenService{ store: s, @@ -56,6 +59,7 @@ func NewTokenService( httpTokenProvider: httpProvider, tokenProviderMode: providerMode, auditService: auditService, + metrics: m, } } @@ -66,6 +70,12 @@ func (s *TokenService) ExchangeDeviceCode( ) (*models.AccessToken, *models.AccessToken, error) { dc, err := s.deviceService.GetDeviceCode(deviceCode) if err != nil { + // Record validation result + result := "invalid" + if errors.Is(err, ErrDeviceCodeExpired) { + result = "expired" + } + s.metrics.RecordOAuthDeviceCodeValidation(result) if errors.Is(err, ErrDeviceCodeExpired) { return nil, nil, ErrExpiredToken } @@ -74,24 +84,32 @@ func (s *TokenService) ExchangeDeviceCode( // Check if client matches if dc.ClientID != clientID { + s.metrics.RecordOAuthDeviceCodeValidation("invalid") return nil, nil, ErrAccessDenied } // Check if client is active client, err := s.store.GetClient(clientID) if err != nil { + s.metrics.RecordOAuthDeviceCodeValidation("invalid") return nil, nil, ErrAccessDenied } if !client.IsActive { + s.metrics.RecordOAuthDeviceCodeValidation("invalid") return nil, nil, ErrAccessDenied } // Check if authorized if !dc.Authorized { + s.metrics.RecordOAuthDeviceCodeValidation("pending") return nil, nil, ErrAuthorizationPending } + // Record successful validation + s.metrics.RecordOAuthDeviceCodeValidation("success") + // Generate access token using provider + start := time.Now() var accessTokenResult *token.TokenResult var providerErr error @@ -218,6 +236,11 @@ func (s *TokenService) ExchangeDeviceCode( return nil, nil, fmt.Errorf("failed to commit transaction: %w", err) } + // Record token issuance metrics + duration := time.Since(start) + s.metrics.RecordTokenIssued("access", "device_code", duration, s.tokenProviderMode) + s.metrics.RecordTokenIssued("refresh", "device_code", duration, s.tokenProviderMode) + // Delete the used device code _ = s.store.DeleteDeviceCodeByID(dc.ID) @@ -336,6 +359,9 @@ func (s *TokenService) RevokeTokenByID(ctx context.Context, tokenID, actorUserID return err } + // Record revocation + s.metrics.RecordTokenRevoked(token.TokenCategory, "user_request") + // Log token revocation if s.auditService != nil { s.auditService.Log(ctx, AuditLogEntry{ @@ -468,29 +494,35 @@ func (s *TokenService) RefreshAccessToken( // 1. Get refresh token from database refreshToken, err := s.store.GetAccessToken(refreshTokenString) if err != nil { + s.metrics.RecordTokenRefresh(false) return nil, nil, token.ErrInvalidRefreshToken } // 2. Verify token category and status if refreshToken.TokenCategory != "refresh" { + s.metrics.RecordTokenRefresh(false) return nil, nil, token.ErrInvalidRefreshToken } if refreshToken.Status != "active" { + s.metrics.RecordTokenRefresh(false) return nil, nil, token.ErrInvalidRefreshToken // Token disabled or revoked } // 3. Verify expiration if refreshToken.IsExpired() { + s.metrics.RecordTokenRefresh(false) return nil, nil, token.ErrExpiredRefreshToken } // 4. Verify client_id if refreshToken.ClientID != clientID { + s.metrics.RecordTokenRefresh(false) return nil, nil, ErrAccessDenied } // 5. Verify scope (cannot upgrade) if !s.validateScopes(refreshToken.Scopes, requestedScopes) { + s.metrics.RecordTokenRefresh(false) return nil, nil, token.ErrInvalidScope } @@ -524,10 +556,12 @@ func (s *TokenService) RefreshAccessToken( if providerErr != nil { log.Printf("[Token] Refresh failed provider=%s: %v", s.tokenProviderMode, providerErr) + s.metrics.RecordTokenRefresh(false) return nil, nil, providerErr } if !refreshResult.Success { + s.metrics.RecordTokenRefresh(false) return nil, nil, fmt.Errorf("token refresh unsuccessful") } @@ -598,9 +632,13 @@ func (s *TokenService) RefreshAccessToken( } if err := tx.Commit().Error; err != nil { + s.metrics.RecordTokenRefresh(false) return nil, nil, fmt.Errorf("failed to commit transaction: %w", err) } + // Record successful refresh + s.metrics.RecordTokenRefresh(true) + // Log token refresh if s.auditService != nil { details := models.AuditDetails{ diff --git a/internal/services/token_test.go b/internal/services/token_test.go index 0939362..ec75247 100644 --- a/internal/services/token_test.go +++ b/internal/services/token_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/appleboy/authgate/internal/config" + "github.com/appleboy/authgate/internal/metrics" "github.com/appleboy/authgate/internal/models" "github.com/appleboy/authgate/internal/store" "github.com/appleboy/authgate/internal/token" @@ -16,9 +17,18 @@ import ( ) func createTestTokenService(s *store.Store, cfg *config.Config) *TokenService { - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) localProvider := token.NewLocalTokenProvider(cfg) - return NewTokenService(s, cfg, deviceService, localProvider, nil, "local", nil) + return NewTokenService( + s, + cfg, + deviceService, + localProvider, + nil, + "local", + nil, + metrics.NewNoopMetrics(), + ) } func createAuthorizedDeviceCode(t *testing.T, s *store.Store, clientID string) *models.DeviceCode { @@ -26,7 +36,7 @@ func createAuthorizedDeviceCode(t *testing.T, s *store.Store, clientID string) * DeviceCodeExpiration: 30 * time.Minute, PollingInterval: 5, } - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Generate device code dc, err := deviceService.GenerateDeviceCode(context.Background(), clientID, "read write") @@ -144,7 +154,7 @@ func TestExchangeDeviceCode_NotAuthorized(t *testing.T) { BaseURL: "http://localhost:8080", } tokenService := createTestTokenService(s, cfg) - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Create an active client and device code but don't authorize it client := createTestClient(t, s, true) @@ -174,7 +184,7 @@ func TestExchangeDeviceCode_ExpiredCode(t *testing.T) { BaseURL: "http://localhost:8080", } tokenService := createTestTokenService(s, cfg) - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Create an active client and device code (it will be expired) client := createTestClient(t, s, true) @@ -389,7 +399,7 @@ func TestGetUserTokens_Success(t *testing.T) { BaseURL: "http://localhost:8080", } tokenService := createTestTokenService(s, cfg) - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Create an active client client := createTestClient(t, s, true) @@ -445,7 +455,7 @@ func TestRevokeAllUserTokens_Success(t *testing.T) { BaseURL: "http://localhost:8080", } tokenService := createTestTokenService(s, cfg) - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Create an active client client := createTestClient(t, s, true) @@ -496,7 +506,7 @@ func TestGetUserTokensWithClient_Success(t *testing.T) { BaseURL: "http://localhost:8080", } tokenService := createTestTokenService(s, cfg) - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Create an active client client := createTestClient(t, s, true) @@ -541,7 +551,7 @@ func TestGetUserTokensWithClient_MultipleClients(t *testing.T) { BaseURL: "http://localhost:8080", } tokenService := createTestTokenService(s, cfg) - deviceService := NewDeviceService(s, cfg, nil) + deviceService := NewDeviceService(s, cfg, nil, metrics.NewNoopMetrics()) // Create two different clients client1 := createTestClient(t, s, true) diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go index 5c0abba..e5c188d 100644 --- a/internal/store/sqlite.go +++ b/internal/store/sqlite.go @@ -194,7 +194,9 @@ func (s *Store) GetUserByID(id string) (*models.User, error) { // GetUserByExternalID finds a user by their external ID and auth source func (s *Store) GetUserByExternalID(externalID, authSource string) (*models.User, error) { var user models.User - if err := s.db.Where("external_id = ? AND auth_source = ?", externalID, authSource).First(&user).Error; err != nil { + if err := s.db.Where("external_id = ? AND auth_source = ?", externalID, authSource). + First(&user). + Error; err != nil { return nil, err } return &user, nil @@ -771,3 +773,34 @@ func (s *Store) GetAuditLogStats(startTime, endTime time.Time) (AuditLogStats, e return stats, nil } + +// CountActiveTokensByCategory counts active, non-expired tokens by category +func (s *Store) CountActiveTokensByCategory(category string) (int64, error) { + var count int64 + err := s.db.Model(&models.AccessToken{}). + Where("token_category = ? AND status = ? AND expires_at > ?", + category, "active", time.Now()). + Count(&count). + Error + return count, err +} + +// CountDeviceCodes returns (total active, pending authorization) +func (s *Store) CountDeviceCodes() (total int64, pending int64, err error) { + // Count all non-expired device codes + err = s.db.Model(&models.DeviceCode{}). + Where("expires_at > ?", time.Now()). + Count(&total). + Error + if err != nil { + return 0, 0, err + } + + // Count pending (not yet authorized) + err = s.db.Model(&models.DeviceCode{}). + Where("expires_at > ? AND authorized = ?", time.Now(), false). + Count(&pending). + Error + + return total, pending, err +} diff --git a/main.go b/main.go index cfa7574..2d52a9b 100644 --- a/main.go +++ b/main.go @@ -131,9 +131,13 @@ func runServer() { log.Fatalf("Failed to initialize database: %v", err) } - // Initialize Prometheus metrics - prometheusMetrics := metrics.Init() - log.Println("Prometheus metrics initialized") + // Initialize metrics + prometheusMetrics := metrics.Init(cfg.MetricsEnabled) + if cfg.MetricsEnabled { + log.Println("Prometheus metrics initialized") + } else { + log.Println("Metrics disabled (using noop implementation)") + } // Initialize audit service auditService := services.NewAuditService(db, cfg.EnableAuditLogging, cfg.AuditLogBufferSize) @@ -155,7 +159,7 @@ func runServer() { cfg.OAuthAutoRegister, auditService, ) - deviceService := services.NewDeviceService(db, cfg, auditService) + deviceService := services.NewDeviceService(db, cfg, auditService, prometheusMetrics) tokenService := services.NewTokenService( db, cfg, @@ -164,6 +168,7 @@ func runServer() { httpTokenProvider, cfg.TokenProviderMode, auditService, + prometheusMetrics, ) clientService := services.NewClientService(db, auditService) @@ -180,6 +185,7 @@ func runServer() { cfg.BaseURL, cfg.SessionFingerprint, cfg.SessionFingerprintIP, + prometheusMetrics, ) deviceHandler := handlers.NewDeviceHandler(deviceService, userService, cfg) tokenHandler := handlers.NewTokenHandler(tokenService, cfg) @@ -191,15 +197,16 @@ func runServer() { oauthHTTPClient, cfg.SessionFingerprint, cfg.SessionFingerprintIP, + prometheusMetrics, ) auditHandler := handlers.NewAuditHandler(auditService) // Setup Gin setupGinMode(cfg) - r := gin.Default() - + r := gin.New() // Setup Prometheus metrics middleware (must be before other routes) r.Use(metrics.HTTPMetricsMiddleware(prometheusMetrics)) + r.Use(gin.Logger(), gin.Recovery()) // Setup IP middleware (for audit logging) r.Use(util.IPMiddleware()) @@ -413,7 +420,9 @@ func runServer() { for { select { case <-ticker.C: - if deleted, err := auditService.CleanupOldLogs(cfg.AuditLogRetention); err != nil { + if deleted, err := auditService.CleanupOldLogs( + cfg.AuditLogRetention, + ); err != nil { log.Printf("Failed to cleanup old audit logs: %v", err) } else if deleted > 0 { log.Printf("Cleaned up %d old audit logs", deleted) @@ -425,6 +434,26 @@ func runServer() { }) } + // Add metrics gauge update job (runs every 30 seconds) + if cfg.MetricsEnabled { + m.AddRunningJob(func(ctx context.Context) error { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + // Update immediately on startup + updateGaugeMetrics(db, prometheusMetrics) + + for { + select { + case <-ticker.C: + updateGaugeMetrics(db, prometheusMetrics) + case <-ctx.Done(): + return nil + } + } + }) + } + // Wait for graceful shutdown <-m.Done() } @@ -751,3 +780,29 @@ var ginModeLogMessage = map[bool]string{ true: "Release (production)", false: "Debug (development)", } + +// updateGaugeMetrics updates gauge metrics with current database state +func updateGaugeMetrics(db *store.Store, m metrics.MetricsRecorder) { + // Update active tokens count + activeAccessTokens, err := db.CountActiveTokensByCategory("access") + if err != nil { + log.Printf("Failed to count access tokens: %v", err) + } else { + m.SetActiveTokensCount("access", int(activeAccessTokens)) + } + + activeRefreshTokens, err := db.CountActiveTokensByCategory("refresh") + if err != nil { + log.Printf("Failed to count refresh tokens: %v", err) + } else { + m.SetActiveTokensCount("refresh", int(activeRefreshTokens)) + } + + // Update active device codes count + totalDeviceCodes, pendingDeviceCodes, err := db.CountDeviceCodes() + if err != nil { + log.Printf("Failed to count device codes: %v", err) + } else { + m.SetActiveDeviceCodesCount(int(totalDeviceCodes), int(pendingDeviceCodes)) + } +}