diff --git a/pkg/auth/monitored_token_source.go b/pkg/auth/monitored_token_source.go index 60357a2a39..34a8839829 100644 --- a/pkg/auth/monitored_token_source.go +++ b/pkg/auth/monitored_token_source.go @@ -283,11 +283,11 @@ func (mts *MonitoredTokenSource) Stopped() <-chan struct{} { return mts.stopped } -// Token retrieves a token, retrying with exponential backoff on transient errors -// (see isTransientNetworkError for the full list). On non-transient errors -// (OAuth 4xx, TLS failures) it marks the workload as unauthenticated and returns -// immediately. Context cancellation (workload removal) stops the retry without -// marking the workload as unauthenticated. +// Token retrieves a token, retrying with exponential backoff on transient +// errors and marking the workload as unauthenticated on non-transient errors. +// See isTransientNetworkError for the classification rule. Context +// cancellation (workload removal) stops the retry without marking the +// workload as unauthenticated. // // Concurrent callers are deduplicated via singleflight so that only one retry // loop runs at a time during transient failures. @@ -382,13 +382,20 @@ func (mts *MonitoredTokenSource) onTick() (bool, time.Duration) { return false, wait } -// isTransientNetworkError reports whether err represents a transient condition -// (DNS failure, TCP transport error, timeout, OAuth server 5xx, unparsable -// token response) that is likely to resolve on its own. +// isTransientNetworkError reports whether err represents a transient +// condition that is likely to resolve on its own. The categories are: // -// OAuth2 client-level auth failures (invalid_grant, 401, 400) and TLS errors -// (certificate verification, handshake failure) are NOT considered transient and -// return false so the workload is marked unauthenticated immediately. +// - Network-level failures: DNS lookup errors, TCP transport errors, +// timeouts. 
+// - OAuth token-endpoint responses classified as transient by +// isTransientRetrieveError. +// - Unparsable token responses on a 2xx status (typically an HTML page +// from a load balancer or CDN). +// +// All other errors return false, causing the workload to be marked +// unauthenticated. TLS failures (certificate verification, handshake +// failure) are intentionally non-transient even though they surface +// through net.OpError like transport-level errors. // // The function is side-effect free; callers that want to emit a DCR // remediation hint on a permanent 4xx must do so themselves at the @@ -400,17 +407,8 @@ func isTransientNetworkError(err error) bool { return false } - // OAuth HTTP-level errors: 5xx (Bad Gateway, Service Unavailable, Gateway - // Timeout) are transient server-side issues that typically resolve on their - // own. 4xx errors (invalid_grant, invalid_client) are permanent auth failures. if retrieveErr, ok := errors.AsType[*oauth2.RetrieveError](err); ok { - if retrieveErr.Response != nil && retrieveErr.Response.StatusCode >= 500 { - slog.Debug("treating OAuth server error as transient", - "status_code", retrieveErr.Response.StatusCode, - ) - return true - } - return false + return isTransientRetrieveError(retrieveErr) } // Non-JSON responses from the OAuth server (e.g. load balancer HTML pages). @@ -445,33 +443,76 @@ func isTransientNetworkError(err error) bool { } // isPermanentTokenEndpointError reports whether err is an *oauth2.RetrieveError -// whose status implies the cached client credentials are themselves the -// problem — specifically 400 (invalid_grant / invalid_client), 401, or -// 403. Used at state-transition boundaries to decide whether to emit a -// DCR/CIMD remediation hint alongside the unauthentication. +// whose response carries a structured RFC 6749 'error' code, implying the +// OAuth server itself rendered a verdict on the cached credentials +// (invalid_grant, invalid_client, etc.). 
Used at state-transition +// boundaries to decide whether to emit a DCR/CIMD remediation hint +// alongside the unauthentication. // -// Other 4xx codes are intentionally NOT treated as permanent here even -// though isTransientNetworkError classifies the whole RetrieveError -// branch as non-transient. 408 (Request Timeout) and 429 (Too Many -// Requests) are typically transient back-pressure that the operator -// cannot remediate by deleting cached credentials; firing the -// "delete the cached credentials and restart" Warn on those would -// mislead operators chasing a transient hiccup. The narrower allowlist -// keeps the remediation hint truthful. +// This is the strict inverse of isTransientRetrieveError on the +// *oauth2.RetrieveError branch: a response is "permanent" iff the +// classifier would NOT call it transient. Concretely, the Warn fires +// only when ErrorCode is populated. 4xx responses without an OAuth +// error code (HTML pages from a WAF, CDN, or reverse proxy) — like +// 5xx, 429, 408, and nil-Response shapes — are treated as +// non-permanent because we have no OAuth-protocol verdict to act on. +// Recommending the user delete cached credentials based on a non- +// spec-compliant response would frequently mislead operators whose +// real problem is upstream of the OAuth server. func isPermanentTokenEndpointError(err error) bool { retrieveErr, ok := errors.AsType[*oauth2.RetrieveError](err) - if !ok { + if !ok || retrieveErr.Response == nil { return false } + return !isTransientRetrieveError(retrieveErr) +} + +// isTransientRetrieveError reports whether an *oauth2.RetrieveError should +// be treated as transient. The classification rules are: +// +// - nil Response: non-transient. There is no signal to act on, so we fall +// through to the unauthenticated path rather than retry blindly. +// - 5xx status: transient (server-side issue, likely to resolve). +// - 429 Too Many Requests: transient regardless of body (HTTP standard). 
+// - 4xx with an empty ErrorCode: transient. The oauth2 library populates +// ErrorCode from the RFC 6749 'error' field in a JSON response body. An +// empty ErrorCode means the response was not a parseable OAuth error — +// typically an HTML page from a WAF, CDN, or reverse proxy that +// intercepted the request before it reached the OAuth server. These +// infrastructure errors (Cloudflare blocks, residential-IP allowlist +// misses, transient bad-config deploys) commonly resolve on their own. +// - 4xx with a populated ErrorCode: permanent. The OAuth server returned +// a structured error code (invalid_grant, invalid_client, etc.) telling +// us specifically what's wrong; retrying won't help. +func isTransientRetrieveError(retrieveErr *oauth2.RetrieveError) bool { if retrieveErr.Response == nil { return false } - switch retrieveErr.Response.StatusCode { - case http.StatusBadRequest, http.StatusUnauthorized, http.StatusForbidden: + statusCode := retrieveErr.Response.StatusCode + + if statusCode >= 500 { + slog.Debug("treating OAuth server error as transient", + "status_code", statusCode, + ) return true - default: - return false } + + if statusCode == http.StatusTooManyRequests { + slog.Debug("treating OAuth rate-limit response as transient", + "status_code", statusCode, + "error_code", retrieveErr.ErrorCode, + ) + return true + } + + if retrieveErr.ErrorCode == "" { + slog.Debug("treating OAuth 4xx without error code as transient", + "status_code", statusCode, + ) + return true + } + + return false } // isOAuthParseError detects errors from the oauth2 library that indicate the diff --git a/pkg/auth/monitored_token_source_test.go b/pkg/auth/monitored_token_source_test.go index 247765d322..5860cc33a6 100644 --- a/pkg/auth/monitored_token_source_test.go +++ b/pkg/auth/monitored_token_source_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "net/http" + "net/http/httptest" "net/url" "os" "strings" @@ -86,7 +87,10 @@ func (m *mockTokenSource) Token() (*oauth2.Token, 
error) { return tok, err } -// createRetrieveError creates an error for testing token failures +// createRetrieveError creates an error for testing token failures. ErrorCode +// is left unset, mirroring what golang.org/x/oauth2 produces when the response +// body is not a parseable RFC 6749 error response (e.g. an HTML page from a +// WAF or load balancer). func createRetrieveError(statusCode int, body string) *oauth2.RetrieveError { response := &http.Response{ StatusCode: statusCode, @@ -98,6 +102,16 @@ func createRetrieveError(statusCode int, body string) *oauth2.RetrieveError { } } +// createRetrieveErrorWithCode is like createRetrieveError but also sets the +// ErrorCode field, mirroring what golang.org/x/oauth2 populates when the +// server responds with a parseable JSON error body containing an "error" +// field. +func createRetrieveErrorWithCode(statusCode int, errorCode, body string) *oauth2.RetrieveError { + err := createRetrieveError(statusCode, body) + err.ErrorCode = errorCode + return err +} + func TestMonitoredTokenSource_SuccessfulTokenRetrieval(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -143,7 +157,7 @@ func TestMonitoredTokenSource_AuthenticationErrorMarksUnauthenticated(t *testing tokenSource := newMockTokenSource() // Create an error that simulates token retrieval failure - retrieveErr := createRetrieveError(http.StatusBadRequest, `{"error":"invalid_grant","error_description":"refresh token expired"}`) + retrieveErr := createRetrieveErrorWithCode(http.StatusBadRequest, "invalid_grant", `{"error":"invalid_grant","error_description":"refresh token expired"}`) tokenSource.setTokenFn(func() (*oauth2.Token, error) { return nil, retrieveErr }) @@ -238,7 +252,7 @@ func TestMonitoredTokenSource_BackgroundMonitoring(t *testing.T) { }, nil } // Subsequent calls: return authentication error - retrieveErr := createRetrieveError(http.StatusUnauthorized, `{"error":"invalid_token"}`) + retrieveErr := 
createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_token", `{"error":"invalid_token"}`) return nil, retrieveErr }) @@ -399,7 +413,7 @@ func TestMonitoredTokenSource_MultipleCallsToToken(t *testing.T) { statusUpdater, statusManager := newMockStatusUpdater(ctrl) tokenSource := newMockTokenSource() - retrieveErr := createRetrieveError(http.StatusUnauthorized, `{"error":"invalid_token"}`) + retrieveErr := createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_token", `{"error":"invalid_token"}`) tokenSource.setTokenFn(func() (*oauth2.Token, error) { return nil, retrieveErr }) @@ -578,12 +592,14 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T err error isTransient bool // true → monitor retries; false → monitor marks unauthenticated }{ - // Non-transient: plain and auth-level errors must fail fast. + // Non-transient: plain errors and OAuth protocol failures (4xx with a + // populated RFC 6749 error code) must fail fast. {name: "plain error", err: errors.New("some error"), isTransient: false}, {name: "context.Canceled", err: context.Canceled, isTransient: false}, {name: "context.DeadlineExceeded", err: context.DeadlineExceeded, isTransient: false}, - {name: "oauth2.RetrieveError 401", err: createRetrieveError(http.StatusUnauthorized, "unauthorized"), isTransient: false}, - {name: "oauth2.RetrieveError 400 invalid_grant", err: createRetrieveError(http.StatusBadRequest, "invalid_grant"), isTransient: false}, + {name: "oauth2.RetrieveError 400 invalid_grant", err: createRetrieveErrorWithCode(http.StatusBadRequest, "invalid_grant", `{"error":"invalid_grant"}`), isTransient: false}, + {name: "oauth2.RetrieveError 401 invalid_client", err: createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_client", `{"error":"invalid_client"}`), isTransient: false}, + {name: "oauth2.RetrieveError 403 unauthorized_client", err: createRetrieveErrorWithCode(http.StatusForbidden, "unauthorized_client", 
`{"error":"unauthorized_client"}`), isTransient: false}, {name: "oauth2.RetrieveError nil response", err: &oauth2.RetrieveError{}, isTransient: false}, // Transient: network-level errors must be retried. {name: "*net.DNSError timeout", err: &net.DNSError{Err: "i/o timeout", Name: "example.com", IsTimeout: true}, isTransient: true}, @@ -595,6 +611,17 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T {name: "oauth2.RetrieveError 502", err: createRetrieveError(http.StatusBadGateway, "Bad Gateway"), isTransient: true}, {name: "oauth2.RetrieveError 503", err: createRetrieveError(http.StatusServiceUnavailable, "Service Unavailable"), isTransient: true}, {name: "oauth2.RetrieveError 504", err: createRetrieveError(http.StatusGatewayTimeout, "Gateway Timeout"), isTransient: true}, + // Transient: 4xx without an RFC 6749 error code in the body. + // These are infrastructure-level errors (WAF, CDN, proxy) that + // commonly resolve on their own, not OAuth protocol failures. + {name: "oauth2.RetrieveError 401 with HTML body", err: createRetrieveError(http.StatusUnauthorized, "Unauthorized"), isTransient: true}, + {name: "oauth2.RetrieveError 403 WAF block", err: createRetrieveError(http.StatusForbidden, "Cloudflare Firewall Block"), isTransient: true}, + {name: "oauth2.RetrieveError 400 with empty body", err: createRetrieveError(http.StatusBadRequest, ""), isTransient: true}, + {name: "oauth2.RetrieveError 408 request timeout", err: createRetrieveError(http.StatusRequestTimeout, ""), isTransient: true}, + // Transient: 429 Too Many Requests is retryable per HTTP standard + // regardless of body content. 
+ {name: "oauth2.RetrieveError 429 empty body", err: createRetrieveError(http.StatusTooManyRequests, ""), isTransient: true}, + {name: "oauth2.RetrieveError 429 with rate-limit error code", err: createRetrieveErrorWithCode(http.StatusTooManyRequests, "rate_limit_exceeded", `{"error":"rate_limit_exceeded"}`), isTransient: true}, // Transient: unparsable OAuth responses (HTML from load balancer on 200). {name: "oauth2 cannot parse json", err: fmt.Errorf("oauth2: cannot parse json: invalid character '<'"), isTransient: true}, {name: "wrapped oauth2 parse error", err: fmt.Errorf("refresh failed: %w", fmt.Errorf("oauth2: cannot parse json: invalid character '<'")), isTransient: true}, @@ -657,6 +684,145 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T } } +// TestIsPermanentTokenEndpointError verifies that isPermanentTokenEndpointError +// is the strict inverse of isTransientRetrieveError on the +// *oauth2.RetrieveError branch (with a non-nil Response). The DCR/CIMD +// remediation Warn fires only when the OAuth server returned a structured +// RFC 6749 error code; non-spec-compliant responses (HTML pages from a WAF, +// CDN, or reverse proxy) should not trigger that Warn because they carry no +// OAuth-protocol verdict. +// +// Existing Token() / markAsUnauthenticated tests reach this function through +// indirect call paths and yield 100% line coverage, but none of them assert +// on the boolean it returns. This test pins the behavioral contract directly. +func TestIsPermanentTokenEndpointError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + isPermanent bool + }{ + // Not an *oauth2.RetrieveError at all — no OAuth verdict to act on. + {name: "plain error", err: errors.New("some error"), isPermanent: false}, + {name: "*oauth2.RetrieveError with nil Response", err: &oauth2.RetrieveError{}, isPermanent: false}, + // Transient HTTP-level conditions — never permanent.
+ {name: "5xx server error", err: createRetrieveError(http.StatusInternalServerError, "Internal Server Error"), isPermanent: false}, + {name: "429 Too Many Requests", err: createRetrieveError(http.StatusTooManyRequests, ""), isPermanent: false}, + {name: "408 Request Timeout", err: createRetrieveError(http.StatusRequestTimeout, ""), isPermanent: false}, + // 4xx without an RFC 6749 error code — infrastructure response, no OAuth verdict. + {name: "401 with HTML body (WAF)", err: createRetrieveError(http.StatusUnauthorized, "Unauthorized"), isPermanent: false}, + {name: "403 with HTML body (Cloudflare)", err: createRetrieveError(http.StatusForbidden, "Firewall Block"), isPermanent: false}, + // 4xx with an RFC 6749 error code — OAuth server rendered a verdict. + {name: "400 invalid_grant", err: createRetrieveErrorWithCode(http.StatusBadRequest, "invalid_grant", `{"error":"invalid_grant"}`), isPermanent: true}, + {name: "401 invalid_client", err: createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_client", `{"error":"invalid_client"}`), isPermanent: true}, + {name: "403 unauthorized_client", err: createRetrieveErrorWithCode(http.StatusForbidden, "unauthorized_client", `{"error":"unauthorized_client"}`), isPermanent: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := isPermanentTokenEndpointError(tt.err) + if got != tt.isPermanent { + t.Errorf("isPermanentTokenEndpointError(%v) = %v, want %v", + tt.err, got, tt.isPermanent) + } + }) + } +} + +// TestIsTransientNetworkError_AgainstRealOAuth2Library is a contract test, +// not a unit test of the package's own logic. It pins the assumption that +// isTransientRetrieveError relies on: that golang.org/x/oauth2 populates +// RetrieveError.ErrorCode iff the response carries a parseable RFC 6749 +// 'error' field (whether JSON or form-encoded), and leaves ErrorCode empty +// for non-spec-compliant response shapes (HTML pages from a WAF, CDN, or +// reverse proxy). 
+// +// Cases here are deliberately limited to response shapes where the +// synthetic test helpers (createRetrieveError, createRetrieveErrorWithCode) +// could plausibly diverge from reality. Cases unambiguously covered by the +// synthetic table (clearly populated ErrorCode JSON, status-code-only +// branches like 5xx and 429) are intentionally not duplicated here. +func TestIsTransientNetworkError_AgainstRealOAuth2Library(t *testing.T) { + t.Parallel() + + const refreshToken = "test-refresh-token" + + tests := []struct { + name string + handler http.HandlerFunc + isTransient bool + }{ + { + // HTML 4xx is the canonical WAF/CDN block shape. Pinning that + // the library leaves ErrorCode empty here is what underpins the + // "infrastructure error" branch of isTransientRetrieveError. + name: "403 with HTML body (Cloudflare WAF)", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte("Cloudflare Firewall Block")) + }, + isTransient: true, + }, + { + name: "401 with HTML body", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("Unauthorized")) + }, + isTransient: true, + }, + { + // Form-encoded error responses are non-spec but supported by + // the library. Pinning that the library DOES populate ErrorCode + // from form-encoded bodies — a synthetic helper used naively + // (createRetrieveError without WithCode) would lie about this. 
+ name: "400 with form-encoded invalid_grant", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("error=invalid_grant&error_description=refresh+token+expired")) + }, + isTransient: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(tt.handler) + t.Cleanup(server.Close) + + cfg := &oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + Endpoint: oauth2.Endpoint{TokenURL: server.URL}, + } + + expired := &oauth2.Token{ + AccessToken: "expired-access-token", + RefreshToken: refreshToken, + Expiry: time.Now().Add(-time.Hour), + } + + _, err := cfg.TokenSource(context.Background(), expired).Token() + if err == nil { + t.Fatalf("expected refresh to fail, got nil error") + } + + got := isTransientNetworkError(err) + if got != tt.isTransient { + t.Errorf("isTransientNetworkError(%v) = %v, want %v", + err, got, tt.isTransient) + } + }) + } +} + // --- background monitor transient-error behaviour --- // TestMonitoredTokenSource_TransientErrorRetriesAndSucceeds verifies that when the @@ -771,7 +937,7 @@ func TestMonitoredTokenSource_TransientThenNonTransientMarksUnauthenticated(t *t Times(1) transientErr := &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}} - nonTransientErr := createRetrieveError(http.StatusUnauthorized, `{"error":"invalid_token"}`) + nonTransientErr := createRetrieveErrorWithCode(http.StatusUnauthorized, "invalid_token", `{"error":"invalid_token"}`) tokenSource.setTokenFn(func() (*oauth2.Token, error) { switch tokenSource.callCount {