From 9d9f68bfd7eea6b9b03bab86df80ee8df329ebff Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Wed, 29 Apr 2026 12:55:53 +0100 Subject: [PATCH] Extract and consume identity from OAuth2 token response MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wire identityFromToken into the embedded auth server's OAuth2 upstream provider. Extension point: the existing tokenResponseRewriter (which already reads and parses every successful token-endpoint response to normalise non-standard envelopes) gains a parallel responsibility — extract user identity claims from the same body when the operator configures IdentityFromTokenConfig with gjson dot-notation paths. Identity extraction runs on the RAW pre-rewrite body, so paths are resolved against the original provider response even when TokenResponseMapping is also configured. The rewriter passes the extracted *partialIdentity back to exchangeCodeForTokens via a returned reference; RefreshTokens passes nil and the rewriter is either omitted entirely or runs with identityCfg=nil because providers like Snowflake omit username on refresh and identity is cached at auth-code time in session storage. The new priority chain in BaseOAuth2Provider.ExchangeCodeForIdentity: 1. IdentityFromToken — when configured, return the extracted identity. If extraction failed (path didn't resolve), return ErrIdentityResolutionFailed without consulting userInfo or synthesising — the operator's "identity is in the token" claim is explicit and we surface its failure rather than silently fall through. 2. UserInfo — existing fetchUserInfo path, unchanged. 3. Synthesis — existing synthesizeIdentity path (PR 5094), unchanged. OIDC providers always have ID-token-derived identity, so the OIDC provider's ExchangeCodeForIdentity discards the rewriter's identityFromToken return value with a defensive WARN if a future config-loader bug ever sets IdentityFromToken on an OIDC base config (structurally absent on the OIDC CRD type today). The tripwire test asserts userinfo HTTP is never contacted when identityFromToken is configured, including on extraction failure. Other new tests cover the happy path, userInfo-only regression, the refresh path (no identity extraction), and information disclosure (raw body never appears in error messages or logs above DEBUG). Two existing slog.Info calls in exchangeCodeForTokens are downgraded to slog.Debug to comply with the silent-success-at-INFO rule. Closes: #5156 --- docs/arch/11-auth-server-storage.md | 37 +- pkg/authserver/upstream/oauth2.go | 156 ++++++-- pkg/authserver/upstream/oauth2_test.go | 372 +++++++++++++++++- pkg/authserver/upstream/oidc.go | 23 +- pkg/authserver/upstream/token_exchange.go | 85 +++- .../upstream/token_exchange_test.go | 207 +++++++++- 6 files changed, 801 insertions(+), 79 deletions(-) diff --git a/docs/arch/11-auth-server-storage.md b/docs/arch/11-auth-server-storage.md index b4c21b681e..001458cbe8 100644 --- a/docs/arch/11-auth-server-storage.md +++ b/docs/arch/11-auth-server-storage.md @@ -64,19 +64,46 @@ The storage layer implements multiple interfaces from the [fosite](https://githu - Memory backend: `pkg/authserver/storage/memory.go` - Redis backend: `pkg/authserver/storage/redis.go` -## Synthesis-mode subjects +## Identity resolution for pure OAuth2 providers -OAuth2 upstreams configured without a userInfo endpoint use a fallback identity-resolution mode: the embedded auth server synthesizes a non-PII subject by hashing the upstream access token. The mode changes what `UserStorage` and `UpstreamTokenStorage` see and is observable to operators inspecting stored state. +For pure OAuth 2.0 upstream providers (`OAuth2Config`), OIDC is unavailable and there is no ID token. `BaseOAuth2Provider.ExchangeCodeForIdentity` resolves user identity through a three-way priority chain. Each path has distinct implications for `UserStorage`, `UpstreamTokenStorage`, and the Redis secondary index. -**When the path triggers.** Pure OAuth 2.0 upstream provider (`OAuth2Config`) configured with `userInfo == nil`. Reached at `BaseOAuth2Provider.ExchangeCodeForIdentity` after token exchange when no userInfo endpoint is available to consult. OIDC providers and OAuth2 providers with `userInfo` configured continue to resolve identity normally and are not affected. +### IdentityFromToken (priority 1) + +An operator opt-in path that extracts identity claims directly from the token endpoint response body, skipping the userinfo HTTP call entirely. + +**When the path triggers.** `IdentityFromToken` is configured on the upstream provider (`p.config.IdentityFromToken != nil`). The `tokenResponseRewriter` intercepts the token endpoint response and runs extraction against the raw pre-rewrite body; the result is available to `ExchangeCodeForIdentity` without an additional round-trip. + +**Subject format.** Real, stable subject string extracted from the token response body via a gjson dot-notation path (e.g. `username`, `authed_user.id`). For token responses that embed a JWT, the `@upstreamjwt` modifier decodes the payload for further drilling (e.g. `access_token|@upstreamjwt|sub`). The `@upstreamjwt` modifier performs no signature verification — it is intended only for JWTs received directly from the upstream token endpoint over a TLS-authenticated channel. The returned `*Identity` carries `Synthetic = false`. Path semantics and trust-model notes are documented on the runtime config struct `IdentityFromTokenConfig` in `pkg/authserver/upstream/identity_from_token.go`. The corresponding CRD type (`cmd/thv-operator/api/v1alpha1.IdentityFromTokenConfig`) is defined in a sibling PR; operator-to-runner translation of this config lands separately. + +**`UserResolver` interaction.** Because `Identity.Synthetic` is false, `callback.go` takes the normal path: `UserResolver.ResolveUser` runs, a row is created (or looked up) in `UserStorage`, a provider-identities entry is written, and `UpdateLastAuthenticated` is called. `UpstreamTokens.UserID` carries the resolved internal user UUID, not the raw operator-supplied subject string. + +**Reverse-index implication (Redis backend).** Stable user IDs mean `KeyTypeUserUpstream` works as designed — one set per user accumulates session IDs across re-authentications. No set churn. + +**Operator visibility.** The `IdentitySynthesized` condition does not fire for upstreams using `IdentityFromToken`. However, `SyntheticIdentityUpstreams()` (the controller-side predicate that drives the condition) currently checks only for `userInfo == nil` and does not yet inspect `IdentityFromToken`. Until the CRD type and controller logic land in a follow-up, an upstream with `IdentityFromToken` configured but no `userInfo` will still trigger `IdentitySynthesizedActive` — even though synthesis is not reached at runtime. + +**Implementation.** +- `pkg/authserver/upstream/oauth2.go` — `ExchangeCodeForIdentity` priority 1 branch +- `pkg/authserver/upstream/identity_from_token.go` — `IdentityFromTokenConfig`, `extractIdentityFromTokenResponse`, `@upstreamjwt` modifier +- `pkg/authserver/upstream/token_exchange.go` — `tokenResponseRewriter.RoundTrip` extracts identity from the raw pre-rewrite body + +### UserInfo endpoint (priority 2) + +Existing behavior. When `IdentityFromToken` is unconfigured and `userInfo` is set, `fetchUserInfo` is called with the upstream access token. Subject, name, and email come from the userinfo response. `UserResolver.ResolveUser` runs normally, `Identity.Synthetic` is false. + +### Synthesis-mode subjects (priority 3) + +Reached when both `IdentityFromToken` is unconfigured AND `userInfo` is absent. The embedded auth server synthesizes a non-PII subject by hashing the upstream access token. The mode changes what `UserStorage` and `UpstreamTokenStorage` see and is observable to operators inspecting stored state. + +**When the path triggers.** Pure OAuth 2.0 upstream provider (`OAuth2Config`) where both `IdentityFromToken` and `userInfo` are unconfigured. Reached at `BaseOAuth2Provider.ExchangeCodeForIdentity` as the final fallback. OIDC providers and OAuth2 providers with either `IdentityFromToken` or `userInfo` configured are not affected. **Subject format.** `tk-` followed by 32 lowercase hex characters (the first 16 bytes of `SHA-256(accessToken)`), e.g. `tk-89abcdef0123456789abcdef01234567`. The output is opaque: assuming the upstream issues opaque (non-JWT) bearer tokens, the digest reveals nothing about the input that an attacker holding a candidate token could not already confirm by re-hashing. The returned `*Identity` carries `Synthetic = true`; the `upstream.IsSynthesizedSubject(string)` predicate lets bare-string consumers recognize the prefix. -**`UserResolver` bypass.** Synthetic identities skip `UserResolver.ResolveUser` entirely — no row is created in `UserStorage`, no entry is written to provider-identities, and `UpdateLastAuthenticated` is not called. The synthesized subject rotates per access token, so persisting it would create a fresh `users` row on every re-authentication. `UpstreamTokens.UserID` therefore carries the `tk-…` value directly rather than a stable internal UUID. +**`UserResolver` bypass.** The bypass is gated on `Identity.Synthetic` in `callback.go` — synthesis is the only path that sets this field. Synthetic identities skip `UserResolver.ResolveUser` entirely — no row is created in `UserStorage`, no entry is written to provider-identities, and `UpdateLastAuthenticated` is not called. The synthesized subject rotates per access token, so persisting it would create a fresh `users` row on every re-authentication. `UpstreamTokens.UserID` therefore carries the `tk-…` value directly rather than a stable internal UUID. **Reverse-index implication (Redis backend).** The `KeyTypeUserUpstream` secondary-index set under `thv:auth:{ns:name}:user:upstream:{userID}` is designed around stable user IDs — one set per user, holding all of that user's session IDs. Under synthesis the userID rotates with every re-authentication, so each session lands in its own one-element set. Reads continue to work, but set churn is much higher than under OIDC. The existing TODO at `pkg/authserver/storage/redis.go:43-45` to scan and clean up stale secondary-index entries applies, and synthesis-mode workloads make a periodic scan more important. -**Operator visibility.** When at least one configured OAuth2 upstream has `userInfo == nil`, the controller surfaces the `IdentitySynthesized` condition on the `MCPExternalAuthConfig` and `VirtualMCPServer` status (Reason `IdentitySynthesizedActive`, naming the affected upstreams). The condition flips to `False` (Reason `IdentitySynthesizedInactive`) once every upstream has `userInfo` configured. +**Operator visibility.** When at least one configured OAuth2 upstream has `userInfo == nil`, the controller surfaces the `IdentitySynthesized` condition on the `MCPExternalAuthConfig` and `VirtualMCPServer` status (Reason `IdentitySynthesizedActive`, naming the affected upstreams). The condition flips to `False` (Reason `IdentitySynthesizedInactive`) once every upstream has `userInfo` configured. Note: the controller predicate (`SyntheticIdentityUpstreams`) checks only for `userInfo == nil` and does not yet account for `IdentityFromToken`; see the known gap noted under priority 1. **Implementation.** - `pkg/authserver/upstream/oauth2.go` — `synthesizeIdentity`, `synthesizeSubjectFromAccessToken`, `IsSynthesizedSubject` diff --git a/pkg/authserver/upstream/oauth2.go b/pkg/authserver/upstream/oauth2.go index 8d0f349685..d6b5c41886 100644 --- a/pkg/authserver/upstream/oauth2.go +++ b/pkg/authserver/upstream/oauth2.go @@ -152,7 +152,16 @@ type OAuth2Config struct { // When set, the provider performs the token exchange HTTP call directly (bypassing // golang.org/x/oauth2) and extracts fields using gjson dot-notation paths. // When nil, standard OAuth 2.0 token response parsing is used. + // See also: IdentityFromToken for extracting user identity from the same response. TokenResponseMapping *TokenResponseMapping `json:"token_response_mapping,omitempty" yaml:"token_response_mapping,omitempty"` + + // IdentityFromToken extracts user identity from the token-endpoint response + // body when the upstream provider includes identity claims there (e.g., + // Snowflake's `username`, Slack's `authed_user.id`). When set, the embedded + // auth server skips the userinfo HTTP call entirely. See the CRD type + // (cmd/thv-operator/api/v1alpha1.IdentityFromTokenConfig) for the + // authoritative trust-model and uniqueness documentation. + IdentityFromToken *IdentityFromTokenConfig `json:"identity_from_token,omitempty" yaml:"identity_from_token,omitempty"` } // TokenResponseMapping configures extraction of token fields from non-standard @@ -195,6 +204,11 @@ func (c *OAuth2Config) Validate() error { return errors.New("token_response_mapping.access_token_path is required when token_response_mapping is set") } } + if c.IdentityFromToken != nil { + if c.IdentityFromToken.SubjectPath == "" { + return errors.New("identity_from_token.subject_path is required when identity_from_token is set") + } + } return c.CommonOAuthConfig.Validate() } @@ -400,39 +414,71 @@ func (p *BaseOAuth2Provider) buildAuthorizationURL( } // ExchangeCodeForIdentity exchanges an authorization code for tokens and resolves -// the user's identity in a single atomic operation. -// For pure OAuth2 providers, identity is resolved via UserInfo when configured; -// otherwise Subject is synthesized via synthesizeIdentity (which rejects empty -// access tokens to prevent the well-known sha256("") subject collision) and -// Name/Email are left empty. The nonce parameter is ignored (no ID token). +// the user's identity in a single atomic operation. For pure OAuth2 providers +// (no ID token) the priority chain is: +// +// 1. IdentityFromToken (operator opt-in): extract identity claims from the +// token-endpoint response body using gjson paths. The userinfo HTTP call +// is skipped entirely. +// 2. UserInfo endpoint: fetch identity from the configured userinfo URL. +// 3. Synthesis: when neither is configured, synthesizeIdentity derives a +// non-PII Subject from the access token (rejects empty tokens to prevent +// the well-known sha256("") collision). Name and Email are empty; +// Synthetic=true tells the callback handler to bypass UserResolver +// because the subject rotates per access token. +// +// The nonce parameter is ignored (no ID token to validate). func (p *BaseOAuth2Provider) ExchangeCodeForIdentity(ctx context.Context, code, codeVerifier, _ string) (*Identity, error) { - tokens, err := p.exchangeCodeForTokens(ctx, code, codeVerifier) + exchanged, err := p.exchangeCodeForTokens(ctx, code, codeVerifier) if err != nil { return nil, err } - // No userinfo: synthesize a non-PII subject from the access token. - // Synthetic=true tells the callback handler to bypass UserResolver — the - // synthesized subject rotates per access token, so persisting it would - // create a new `users` row on every re-authentication. - if p.config.UserInfo == nil { - return synthesizeIdentity(tokens) - } - - userInfo, err := p.fetchUserInfo(ctx, tokens.AccessToken) - if err != nil { - return nil, fmt.Errorf("%w: %w", ErrIdentityResolutionFailed, err) - } - if userInfo == nil || userInfo.Subject == "" { - return nil, ErrIdentityResolutionFailed - } - - return &Identity{ - Tokens: tokens, - Subject: userInfo.Subject, - Name: userInfo.Name, - Email: userInfo.Email, - }, nil + // Priority 1: identityFromToken (configured by operator). + if p.config.IdentityFromToken != nil { + if exchanged.identity == nil { + // The rewriter logged the extraction failure at WARN with the + // operator-supplied path. Surface the same error here so callers + // can report it without requiring WARN-level log access. + if exchanged.extractionErr != nil { + return nil, exchanged.extractionErr + } + // Unreachable in practice: when identity is nil the rewriter always + // sets extractionErr. Kept as a safe fallback. + return nil, fmt.Errorf( + "%w: identityFromToken configured but extraction failed", + ErrIdentityResolutionFailed, + ) + } + return &Identity{ + Tokens: exchanged.tokens, + Subject: exchanged.identity.Subject, + Name: exchanged.identity.Name, + Email: exchanged.identity.Email, + }, nil + } + + // Priority 2: userInfo (existing behavior). + if p.config.UserInfo != nil { + userInfo, err := p.fetchUserInfo(ctx, exchanged.tokens.AccessToken) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrIdentityResolutionFailed, err) + } + if userInfo == nil || userInfo.Subject == "" { + return nil, ErrIdentityResolutionFailed + } + return &Identity{ + Tokens: exchanged.tokens, + Subject: userInfo.Subject, + Name: userInfo.Name, + Email: userInfo.Email, + }, nil + } + + // Priority 3: synthesis (PR 5094). Subject derived from access token; rotates + // per token. The callback handler treats Synthetic=true as opt-out from the + // user-resolver to avoid creating a fresh users row on every re-auth. + return synthesizeIdentity(exchanged.tokens) } // synthesizedSubjectPrefix tags subjects produced by @@ -496,21 +542,45 @@ func synthesizeIdentity(tokens *Tokens) (*Identity, error) { }, nil } +// tokenExchangeResult bundles the outputs of a successful token exchange: +// the obtained tokens, any identity extracted from the token response body, +// and any error from the identity extraction step. The exchange-level error +// is returned as the function's own error return value. +// +// extractionErr is populated when IdentityFromToken is configured but the +// extractor could not resolve the subject path. It carries operator-actionable +// diagnostics (path name, type description) and already wraps +// ErrIdentityResolutionFailed. The caller must check extractionErr when +// identity is nil and IdentityFromToken is set. +type tokenExchangeResult struct { + tokens *Tokens + identity *partialIdentity + extractionErr error +} + // exchangeCodeForTokens exchanges an authorization code for tokens with the upstream IDP. -func (p *BaseOAuth2Provider) exchangeCodeForTokens(ctx context.Context, code, codeVerifier string) (*Tokens, error) { +// It returns a tokenExchangeResult containing the tokens, any identity extracted from +// the token response body, and any extraction error. The function error is the +// exchange-level error (network, HTTP, token parsing). +func (p *BaseOAuth2Provider) exchangeCodeForTokens( + ctx context.Context, code, codeVerifier string, +) (*tokenExchangeResult, error) { if code == "" { return nil, errors.New("authorization code is required") } - slog.Info("exchanging authorization code for tokens", + slog.Debug("exchanging authorization code for tokens", "token_endpoint", p.config.TokenEndpoint, "has_pkce_verifier", codeVerifier != "", ) - // Wrap HTTP client with token response rewriter if mapping is configured. - // This normalizes non-standard responses (e.g., GovSlack's nested fields) - // before the oauth2 library parses them, keeping the standard exchange flow. - httpClient := wrapHTTPClientWithMapping(p.httpClient, p.config.TokenResponseMapping, p.config.TokenEndpoint) + // Wrap HTTP client with token response rewriter if mapping or identity extraction + // is configured. At auth-code time, both mapping (field normalization) and + // identityCfg (identity extraction) may be active together. Keep a reference to + // the rewriter so we can read extractedIdentity and extractionErr after Exchange returns. + httpClient, rewriter := wrapHTTPClientForTokenExchange( + p.httpClient, p.config.TokenResponseMapping, p.config.IdentityFromToken, p.config.TokenEndpoint, + ) ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) // Build exchange options @@ -534,7 +604,16 @@ func (p *BaseOAuth2Provider) exchangeCodeForTokens(ctx context.Context, code, co "expires_at", expiresAtLogValue(tokens.ExpiresAt), ) - return tokens, nil + // Read any identity and extraction error captured by the rewriter during the + // token round-trip. rewriter is nil when neither mapping nor identityCfg is + // configured; nil-safe. + result := &tokenExchangeResult{tokens: tokens} + if rewriter != nil { + result.identity = rewriter.extractedIdentity + result.extractionErr = rewriter.extractionErr + } + + return result, nil } // RefreshTokens refreshes the upstream IDP tokens. @@ -559,7 +638,12 @@ func (p *BaseOAuth2Provider) RefreshTokens(ctx context.Context, refreshToken, _ ) // Wrap HTTP client with token response rewriter if mapping is configured. - httpClient := wrapHTTPClientWithMapping(p.httpClient, p.config.TokenResponseMapping, p.config.TokenEndpoint) + // Identity extraction (identityCfg) is intentionally nil here: per Snowflake's + // contract and the general design, the username/identity field is only present + // in the initial auth-code response and is omitted on refresh. Identity is + // cached at auth-code time and read from session storage on subsequent requests. + // The rewriter is discarded because refresh does not produce identity. + httpClient, _ := wrapHTTPClientForTokenExchange(p.httpClient, p.config.TokenResponseMapping, nil, p.config.TokenEndpoint) ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) opts := []oauth2.AuthCodeOption{ diff --git a/pkg/authserver/upstream/oauth2_test.go b/pkg/authserver/upstream/oauth2_test.go index 0ffcb154db..5770878a20 100644 --- a/pkg/authserver/upstream/oauth2_test.go +++ b/pkg/authserver/upstream/oauth2_test.go @@ -19,10 +19,12 @@ import ( "encoding/hex" "encoding/json" "errors" + "fmt" "net/http" "net/http/httptest" "net/url" "strings" + "sync/atomic" "testing" "time" @@ -408,7 +410,7 @@ func TestBaseOAuth2Provider_exchangeCodeForTokens(t *testing.T) { provider, err := NewOAuth2Provider(config) require.NoError(t, err) - tokens, err := provider.exchangeCodeForTokens(ctx, "test-auth-code", "test-verifier") + result, err := provider.exchangeCodeForTokens(ctx, "test-auth-code", "test-verifier") require.NoError(t, err) // Verify request parameters @@ -420,12 +422,12 @@ func TestBaseOAuth2Provider_exchangeCodeForTokens(t *testing.T) { assert.Equal(t, "http://localhost:8080/callback", receivedParams.Get("redirect_uri")) // Verify response - assert.Equal(t, "exchanged-access-token", tokens.AccessToken) - assert.Equal(t, "exchanged-refresh-token", tokens.RefreshToken) + assert.Equal(t, "exchanged-access-token", result.tokens.AccessToken) + assert.Equal(t, "exchanged-refresh-token", result.tokens.RefreshToken) // Verify expiration is set approximately correctly expectedExpiry := time.Now().Add(7200 * time.Second) - assert.WithinDuration(t, expectedExpiry, tokens.ExpiresAt, 10*time.Second) + assert.WithinDuration(t, expectedExpiry, result.tokens.ExpiresAt, 10*time.Second) }) t.Run("handles error response from token endpoint", func(t *testing.T) { @@ -660,11 +662,11 @@ func TestBaseOAuth2Provider_exchangeCodeForTokens(t *testing.T) { provider, err := NewOAuth2Provider(config) require.NoError(t, err) - tokens, err := provider.exchangeCodeForTokens(ctx, "test-code", "") + result, err := provider.exchangeCodeForTokens(ctx, "test-code", "") require.NoError(t, err) // No expires_in in the response means the token has no expiry. - assert.True(t, tokens.ExpiresAt.IsZero(), "ExpiresAt should be zero for non-expiring tokens") + assert.True(t, result.tokens.ExpiresAt.IsZero(), "ExpiresAt should be zero for non-expiring tokens") }) } @@ -935,9 +937,9 @@ func TestBaseOAuth2Provider_WithOAuth2HTTPClient(t *testing.T) { // Verify the provider works with custom client ctx := context.Background() - tokens, err := provider.exchangeCodeForTokens(ctx, "test-code", "") + result, err := provider.exchangeCodeForTokens(ctx, "test-code", "") require.NoError(t, err) - assert.NotEmpty(t, tokens.AccessToken) + assert.NotEmpty(t, result.tokens.AccessToken) } func TestBaseOAuth2Provider_TokenTypeValidation(t *testing.T) { @@ -1069,11 +1071,11 @@ func TestBaseOAuth2Provider_IDToken(t *testing.T) { provider, err := NewOAuth2Provider(config) require.NoError(t, err) - tokens, err := provider.exchangeCodeForTokens(ctx, "test-code", "") + result, err := provider.exchangeCodeForTokens(ctx, "test-code", "") require.NoError(t, err) // OAuth2 providers can also return ID tokens if they support hybrid flows - assert.Equal(t, "test-id-token.payload.signature", tokens.IDToken) + assert.Equal(t, "test-id-token.payload.signature", result.tokens.IDToken) } func Test_validateRedirectURI(t *testing.T) { @@ -2137,3 +2139,353 @@ func TestAuthorizationURL_AdditionalAuthorizationParams(t *testing.T) { assert.Equal(t, "caller-value", parsed.Query().Get("custom")) }) } + +// tokenBodyWithUsername is a minimal valid token response body that includes a +// "username" field used across several identityFromToken test cases. +const tokenBodyWithUsername = `{"access_token":"a","token_type":"Bearer","username":"u1"}` + +// newTokenResponseServer is a test helper that starts an httptest.Server whose +// /token endpoint returns tokenBody as the JSON response body, and whose +// /authorize endpoint always returns 200 OK. +func newTokenResponseServer(t *testing.T, tokenBody string) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + mux.HandleFunc("/authorize", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("/token", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(tokenBody)) + }) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + return srv +} + +// TestBaseOAuth2Provider_ExchangeCodeForIdentity_IdentityFromToken covers the +// priority-1 path where IdentityFromToken is configured. +func TestBaseOAuth2Provider_ExchangeCodeForIdentity_IdentityFromToken(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("identityFromToken resolves subject name email", func(t *testing.T) { + t.Parallel() + + tokenBody := `{"access_token":"a","token_type":"Bearer","username":"u1","display_name":"User One","email":"u1@example.com"}` + tokenSrv := newTokenResponseServer(t, tokenBody) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + SubjectPath: "username", + NamePath: "display_name", + EmailPath: "email", + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + identity, err := provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.NoError(t, err) + assert.Equal(t, "u1", identity.Subject) + assert.Equal(t, "User One", identity.Name) + assert.Equal(t, "u1@example.com", identity.Email) + assert.NotEmpty(t, identity.Tokens.AccessToken) + }) + + t.Run("@upstreamjwt modifier resolves identity from JWT-shaped access token", func(t *testing.T) { + t.Parallel() + + RegisterModifiers() + + accessToken := makeJWT(`{"sub":"u-jwt-1","name":"JWT User","email":"jwt@example.com"}`) + tokenBody := fmt.Sprintf(`{"access_token":%q,"token_type":"Bearer"}`, accessToken) + tokenSrv := newTokenResponseServer(t, tokenBody) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + SubjectPath: "access_token|@upstreamjwt|sub", + NamePath: "access_token|@upstreamjwt|name", + EmailPath: "access_token|@upstreamjwt|email", + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + identity, err := provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.NoError(t, err) + assert.Equal(t, "u-jwt-1", identity.Subject) + assert.Equal(t, "JWT User", identity.Name) + assert.Equal(t, "jwt@example.com", identity.Email) + assert.False(t, identity.Synthetic) + }) + + t.Run("identityFromToken bypasses userinfo endpoint entirely", func(t *testing.T) { + t.Parallel() + + // httptest.Server dispatches each request on its own goroutine, so + // the counter must be accessed atomically — t.Errorf inside the + // handler is the load-bearing assertion; the counter just gives a + // numeric value the final assertion can read race-free. + var tripwireCallCount atomic.Int32 + tripwire := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + tripwireCallCount.Add(1) + t.Errorf("userinfo endpoint must NOT be called when identityFromToken is configured") + })) + t.Cleanup(tripwire.Close) + + tokenSrv := newTokenResponseServer(t, tokenBodyWithUsername) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + SubjectPath: "username", + }, + // UserInfo is also configured — identityFromToken must win and + // the tripwire userinfo server must never be contacted. + UserInfo: &UserInfoConfig{ + EndpointURL: tripwire.URL, + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + identity, err := provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.NoError(t, err) + assert.Equal(t, "u1", identity.Subject) + assert.Equal(t, int32(0), tripwireCallCount.Load(), "userinfo endpoint must not be called") + }) + + t.Run("identityFromToken extraction failure returns ErrIdentityResolutionFailed without calling userinfo", func(t *testing.T) { + t.Parallel() + + var tripwireCallCount atomic.Int32 + tripwire := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + tripwireCallCount.Add(1) + t.Errorf("userinfo endpoint must NOT be called when identityFromToken is configured") + })) + t.Cleanup(tripwire.Close) + + // Token body does NOT contain "missing_path" + tokenSrv := newTokenResponseServer(t, tokenBodyWithUsername) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + SubjectPath: "missing_path", + }, + UserInfo: &UserInfoConfig{ + EndpointURL: tripwire.URL, + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + _, err = provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrIdentityResolutionFailed)) + assert.Equal(t, int32(0), tripwireCallCount.Load(), "userinfo endpoint must not be called on extraction failure") + }) + + t.Run("extraction failure error surfaces the misconfigured path name", func(t *testing.T) { + t.Parallel() + + tokenSrv := newTokenResponseServer(t, tokenBodyWithUsername) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + // Path that does not exist in the token response so the extractor + // produces a diagnostic that names the path and the failure reason. + SubjectPath: "nonexistent_field", + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + _, err = provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrIdentityResolutionFailed)) + // The error must carry the operator-supplied path name so the + // misconfiguration is diagnosable without log access. + assert.Contains(t, err.Error(), "nonexistent_field", + "error must contain the misconfigured subject path name") + assert.Contains(t, err.Error(), "not found", + "error must describe why extraction failed") + }) + + t.Run("userInfo-only path is unchanged when identityFromToken is not set", func(t *testing.T) { + t.Parallel() + + userInfoSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "sub": "u-456", + "name": "Bob", + "email": "bob@example.com", + }) + })) + t.Cleanup(userInfoSrv.Close) + + tokenBody := `{"access_token":"a","token_type":"Bearer"}` + tokenSrv := newTokenResponseServer(t, tokenBody) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + UserInfo: &UserInfoConfig{ + EndpointURL: userInfoSrv.URL, + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + identity, err := provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.NoError(t, err) + assert.Equal(t, "u-456", identity.Subject) + assert.Equal(t, "Bob", identity.Name) + assert.Equal(t, "bob@example.com", identity.Email) + }) + + t.Run("neither identityFromToken nor userInfo set falls through to synthesis", func(t *testing.T) { + t.Parallel() + + tokenBody := `{"access_token":"a","token_type":"Bearer"}` + tokenSrv := newTokenResponseServer(t, tokenBody) + + // Both identity surfaces absent — the priority chain falls through to + // synthesizeIdentity (PR 5094): a non-PII Subject derived from the + // access token, Synthetic=true, no Name/Email. + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + identity, err := provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.NoError(t, err) + assert.True(t, identity.Synthetic, "expected synthesized identity when neither IdentityFromToken nor UserInfo is set") + assert.True(t, IsSynthesizedSubject(identity.Subject), + "expected synthesized subject prefix; got %q", identity.Subject) + assert.Empty(t, identity.Name, "synthesized identity has no name") + assert.Empty(t, identity.Email, "synthesized identity has no email") + }) + + t.Run("refresh path does not trigger identity extraction", func(t *testing.T) { + t.Parallel() + + tokenBody := `{"access_token":"refreshed","token_type":"Bearer","refresh_token":"new-rt","username":"u1"}` + tokenSrv := newTokenResponseServer(t, tokenBody) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + SubjectPath: "username", + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + // RefreshTokens must not blow up and must return valid tokens. + tokens, err := provider.RefreshTokens(ctx, "old-refresh-token", "") + require.NoError(t, err) + assert.Equal(t, "refreshed", tokens.AccessToken) + + // Verify that the wrapped client on the refresh path uses nil identityCfg + // by inspecting the transport after constructing it the same way RefreshTokens does. + _, rewriter := wrapHTTPClientForTokenExchange(provider.httpClient, provider.config.TokenResponseMapping, nil, provider.config.TokenEndpoint) + // When only mapping is nil and identityCfg is nil, wrapHTTPClientForTokenExchange + // returns the original client and nil rewriter. When mapping is non-nil, + // the rewriter has nil identityCfg. Here, both are nil, so rewriter is nil. + assert.Nil(t, rewriter, "refresh path rewriter must be nil when both mapping and identityCfg are nil") + }) + + t.Run("error message does not contain token body content", func(t *testing.T) { + t.Parallel() + + secretMarker := "SUPER_SECRET_TOKEN_BODY_MARKER_XYZ789" + tokenBody := `{"access_token":"` + secretMarker + `","token_type":"Bearer","username":"u1"}` + tokenSrv := newTokenResponseServer(t, tokenBody) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + // Deliberately wrong path to trigger extraction failure. + SubjectPath: "missing_field", + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + _, err = provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrIdentityResolutionFailed)) + // The error must not leak any part of the token response body. + assert.NotContains(t, err.Error(), secretMarker, + "error message must not contain token body content") + }) +} diff --git a/pkg/authserver/upstream/oidc.go b/pkg/authserver/upstream/oidc.go index 1ea53c4424..ed1b653e91 100644 --- a/pkg/authserver/upstream/oidc.go +++ b/pkg/authserver/upstream/oidc.go @@ -242,27 +242,36 @@ func (p *OIDCProviderImpl) ExchangeCodeForIdentity( return nil, errors.New("OIDC endpoints not discovered") } - tokens, err := p.exchangeCodeForTokens(ctx, code, codeVerifier) + // OIDC resolves identity from the validated ID token, not the token response + // body, so the extracted identity return value is intentionally discarded. + // Defense-in-depth: warn if the OAuth2 base config carries a non-nil + // IdentityFromToken — the field is structurally absent on the OIDC CRD type + // today, but a future config-loader bug or hand-built runtime config could + // silently drop the operator's intent without this signal. + if p.config.IdentityFromToken != nil { + slog.Warn("OIDC provider ignoring IdentityFromToken; identity is resolved from the validated ID token") + } + exchanged, err := p.exchangeCodeForTokens(ctx, code, codeVerifier) if err != nil { return nil, err } // OIDC-specific: ID token MUST be present per Section 3.1.3.3. - if tokens.IDToken == "" { + if exchanged.tokens.IDToken == "" { return nil, fmt.Errorf("%w: ID token required for OIDC provider", ErrIdentityResolutionFailed) } // Validate ID token with nonce in a single pass — no double-validation. - validatedToken, err := p.validateIDToken(ctx, tokens.IDToken, nonce) + validatedToken, err := p.validateIDToken(ctx, exchanged.tokens.IDToken, nonce) if err != nil { slog.Debug("id token validation failed", "error", err) return nil, fmt.Errorf("%w: %w", ErrIdentityResolutionFailed, err) } slog.Debug("authorization code exchange successful", - "has_refresh_token", tokens.RefreshToken != "", - "has_id_token", tokens.IDToken != "", - "expires_at", expiresAtLogValue(tokens.ExpiresAt), + "has_refresh_token", exchanged.tokens.RefreshToken != "", + "has_id_token", exchanged.tokens.IDToken != "", + "expires_at", expiresAtLogValue(exchanged.tokens.ExpiresAt), ) // Extract optional standard claims (name, email) from ID token @@ -278,7 +287,7 @@ func (p *OIDCProviderImpl) ExchangeCodeForIdentity( } return &Identity{ - Tokens: tokens, + Tokens: exchanged.tokens, Subject: validatedToken.Subject, Name: idClaims.Name, Email: idClaims.Email, diff --git a/pkg/authserver/upstream/token_exchange.go b/pkg/authserver/upstream/token_exchange.go index 16fa65c1da..4203e23ce2 100644 --- a/pkg/authserver/upstream/token_exchange.go +++ b/pkg/authserver/upstream/token_exchange.go @@ -7,28 +7,44 @@ import ( "bytes" "encoding/json" "io" + "log/slog" "net/http" "github.com/tidwall/gjson" ) // tokenResponseRewriter is an http.RoundTripper that normalizes non-standard -// OAuth token responses before the golang.org/x/oauth2 library parses them. +// OAuth token responses before the golang.org/x/oauth2 library parses them, +// and optionally extracts user identity claims from the raw response body. // // Some providers (e.g., GovSlack) nest token fields under non-standard paths // like "authed_user.access_token" instead of the top-level "access_token". // This RoundTripper intercepts the response, extracts fields using gjson // dot-notation paths, and rewrites the response body with standard top-level // field names so the oauth2 library can parse them normally. +// +// When identityCfg is set, identity extraction runs on the RAW pre-rewrite +// body before any field relocation occurs. The extracted identity is stored +// in extractedIdentity and consumed by the caller after Exchange returns. +// If extraction fails, extractionErr holds the cause so callers can surface +// operator-actionable diagnostics (path names, type descriptions) without +// re-reading the response body. type tokenResponseRewriter struct { - base http.RoundTripper - mapping *TokenResponseMapping - tokenURL string + base http.RoundTripper + mapping *TokenResponseMapping + identityCfg *IdentityFromTokenConfig + tokenURL string + extractedIdentity *partialIdentity + extractionErr error } // RoundTrip intercepts HTTP responses from the token endpoint and rewrites // the JSON body to place mapped fields at the top level. Non-token-endpoint // requests (e.g., userInfo) pass through unchanged. +// +// When identityCfg is set, identity is extracted from the raw response body +// BEFORE the rewrite step so that identity paths are resolved against the +// original provider response, not the normalized form. func (t *tokenResponseRewriter) RoundTrip(req *http.Request) (*http.Response, error) { resp, err := t.base.RoundTrip(req) if err != nil { @@ -51,7 +67,31 @@ func (t *tokenResponseRewriter) RoundTrip(req *http.Request) (*http.Response, er return nil, err } - rewritten := rewriteTokenResponse(body, t.mapping) + // Extract from the raw body before rewriteTokenResponse runs. The rewrite + // never touches identity-shaped fields today, but extracting first makes + // the ordering invariant independent of that assumption. + if t.identityCfg != nil { + result, extractErr := extractIdentityFromTokenResponse(body, t.identityCfg) + if extractErr != nil { + // WARN so an operator misconfiguration (e.g., wrong subjectPath) is + // visible without enabling DEBUG. The error is safe to log: it + // contains operator-supplied paths and type descriptions, never + // any portion of the response body. + slog.Warn("identity extraction from token response failed", "error", extractErr) + t.extractionErr = extractErr + } else { + t.extractedIdentity = &result + } + } + + // Only run the field-rewrite step when a mapping is configured. + // When mapping is nil (e.g. only identityCfg is set), pass the body through unchanged. + var rewritten []byte + if t.mapping != nil { + rewritten = rewriteTokenResponse(body, t.mapping) + } else { + rewritten = body + } resp.Body = io.NopCloser(bytes.NewReader(rewritten)) resp.ContentLength = int64(len(rewritten)) @@ -107,12 +147,20 @@ func rewriteTokenResponse(body []byte, mapping *TokenResponseMapping) []byte { return rewritten } -// wrapHTTPClientWithMapping wraps an HTTP client's transport with a -// tokenResponseRewriter when a TokenResponseMapping is configured. -// Returns the original client unchanged if mapping is nil. -func wrapHTTPClientWithMapping(client *http.Client, mapping *TokenResponseMapping, tokenURL string) *http.Client { - if mapping == nil { - return client +// wrapHTTPClientForTokenExchange wraps an HTTP client's transport with a +// tokenResponseRewriter when either a TokenResponseMapping or an +// IdentityFromTokenConfig is configured (or both). Returns the original client +// and nil rewriter when both are nil, so the standard oauth2 library path is +// used. The returned rewriter (when non-nil) can be read after Exchange returns +// to retrieve any identity extracted during the token round-trip. +func wrapHTTPClientForTokenExchange( + client *http.Client, + mapping *TokenResponseMapping, + identityCfg *IdentityFromTokenConfig, + tokenURL string, +) (*http.Client, *tokenResponseRewriter) { + if mapping == nil && identityCfg == nil { + return client, nil } base := client.Transport @@ -120,14 +168,17 @@ func wrapHTTPClientWithMapping(client *http.Client, mapping *TokenResponseMappin base = http.DefaultTransport } + rewriter := &tokenResponseRewriter{ + base: base, + mapping: mapping, + identityCfg: identityCfg, + tokenURL: tokenURL, + } + // Create a shallow copy to avoid mutating the original client wrapped := *client - wrapped.Transport = &tokenResponseRewriter{ - base: base, - mapping: mapping, - tokenURL: tokenURL, - } - return &wrapped + wrapped.Transport = rewriter + return &wrapped, rewriter } // pathOrDefault returns the path if non-empty, otherwise returns the default. diff --git a/pkg/authserver/upstream/token_exchange_test.go b/pkg/authserver/upstream/token_exchange_test.go index 125c78750e..d60270313d 100644 --- a/pkg/authserver/upstream/token_exchange_test.go +++ b/pkg/authserver/upstream/token_exchange_test.go @@ -15,6 +15,14 @@ import ( "github.com/stretchr/testify/require" ) +// tokenEndpointHandler returns an HTTP handler that responds with the given JSON body. +func tokenEndpointHandler(body string) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(body)) + } +} + func TestRewriteTokenResponse(t *testing.T) { t.Parallel() @@ -154,7 +162,7 @@ func TestTokenResponseRewriter_TokenEndpoint(t *testing.T) { ScopePath: "authed_user.scope", } - client := wrapHTTPClientWithMapping(http.DefaultClient, mapping, tokenServer.URL) + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, mapping, nil, tokenServer.URL) req, err := http.NewRequest("POST", tokenServer.URL, strings.NewReader("grant_type=authorization_code")) require.NoError(t, err) @@ -187,7 +195,7 @@ func TestTokenResponseRewriter_NonTokenEndpoint(t *testing.T) { mapping := &TokenResponseMapping{AccessTokenPath: "authed_user.access_token"} // Token URL points elsewhere, so this server's responses should pass through unchanged - client := wrapHTTPClientWithMapping(http.DefaultClient, mapping, "https://other.example.com/token") + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, mapping, nil, "https://other.example.com/token") req, err := http.NewRequest("GET", server.URL, nil) require.NoError(t, err) @@ -211,8 +219,9 @@ func TestWrapHTTPClientWithMapping_NilMapping(t *testing.T) { t.Parallel() original := &http.Client{} - result := wrapHTTPClientWithMapping(original, nil, "https://example.com/token") + result, rewriter := wrapHTTPClientForTokenExchange(original, nil, nil, "https://example.com/token") assert.Same(t, original, result) + assert.Nil(t, rewriter) } func TestTokenResponseRewriter_ErrorResponse(t *testing.T) { @@ -225,7 +234,7 @@ func TestTokenResponseRewriter_ErrorResponse(t *testing.T) { defer tokenServer.Close() mapping := &TokenResponseMapping{AccessTokenPath: "authed_user.access_token"} - client := wrapHTTPClientWithMapping(http.DefaultClient, mapping, tokenServer.URL) + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, mapping, nil, tokenServer.URL) req, err := http.NewRequest("POST", tokenServer.URL, strings.NewReader("grant_type=authorization_code")) require.NoError(t, err) @@ -280,3 +289,193 @@ func TestOAuth2Config_Validate_TokenResponseMapping(t *testing.T) { }) } } + +func TestOAuth2Config_Validate_IdentityFromToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + identityCfg *IdentityFromTokenConfig + wantErr bool + errContains string + }{ + {name: "nil identity config is valid", identityCfg: nil, wantErr: false}, + { + name: "valid identity config with subject path", + identityCfg: &IdentityFromTokenConfig{SubjectPath: "username"}, + wantErr: false, + }, + { + name: "missing subject path", + identityCfg: &IdentityFromTokenConfig{NamePath: "name"}, + wantErr: true, + errContains: "identity_from_token.subject_path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + cfg := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ClientID: "test", RedirectURI: "http://localhost/callback"}, + AuthorizationEndpoint: "https://example.com/authorize", + TokenEndpoint: "https://example.com/token", + IdentityFromToken: tt.identityCfg, + } + err := cfg.Validate() + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + } + }) + } +} + +// TestTokenResponseRewriter_IdentityCfgNil verifies that when identityCfg is nil, +// extractedIdentity remains nil after RoundTrip even when a mapping is configured. +func TestTokenResponseRewriter_IdentityCfgNil(t *testing.T) { + t.Parallel() + + body := `{"access_token":"tok","token_type":"Bearer","username":"u1"}` + server := httptest.NewServer(tokenEndpointHandler(body)) + t.Cleanup(server.Close) + + mapping := &TokenResponseMapping{AccessTokenPath: "access_token"} + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, mapping, nil, server.URL) + + transport, ok := client.Transport.(*tokenResponseRewriter) + require.True(t, ok, "expected *tokenResponseRewriter transport") + + req, err := http.NewRequest("POST", server.URL, strings.NewReader("grant_type=authorization_code")) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { _ = resp.Body.Close() }) + + assert.Nil(t, transport.extractedIdentity) +} + +// TestTokenResponseRewriter_IdentityCfgSet verifies that when identityCfg is set +// and the body contains the subject path, extractedIdentity is populated. +func TestTokenResponseRewriter_IdentityCfgSet(t *testing.T) { + t.Parallel() + + body := `{"access_token":"a","token_type":"Bearer","username":"u1"}` + server := httptest.NewServer(tokenEndpointHandler(body)) + t.Cleanup(server.Close) + + identityCfg := &IdentityFromTokenConfig{SubjectPath: "username"} + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, nil, identityCfg, server.URL) + + transport, ok := client.Transport.(*tokenResponseRewriter) + require.True(t, ok, "expected *tokenResponseRewriter transport") + + req, err := http.NewRequest("POST", server.URL, strings.NewReader("grant_type=authorization_code")) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { _ = resp.Body.Close() }) + + require.NotNil(t, transport.extractedIdentity) + assert.Equal(t, "u1", transport.extractedIdentity.Subject) +} + +// TestTokenResponseRewriter_IdentityFromRawBody verifies the raw-body +// invariant: identity is resolved against the pre-rewrite body. The fixture +// places the subject at a nested path ("authed_user.id") and configures a +// mapping that lifts access_token to the top level; the assertion confirms +// extraction reads the original nested location, not a post-rewrite shape. +func TestTokenResponseRewriter_IdentityFromRawBody(t *testing.T) { + t.Parallel() + + rawBody := `{"authed_user":{"access_token":"x","id":"U1234"},"username":""}` + server := httptest.NewServer(tokenEndpointHandler(rawBody)) + t.Cleanup(server.Close) + + mapping := &TokenResponseMapping{AccessTokenPath: "authed_user.access_token"} + identityCfg := &IdentityFromTokenConfig{SubjectPath: "authed_user.id"} + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, mapping, identityCfg, server.URL) + + transport, ok := client.Transport.(*tokenResponseRewriter) + require.True(t, ok, "expected *tokenResponseRewriter transport") + + req, err := http.NewRequest("POST", server.URL, strings.NewReader("grant_type=authorization_code")) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { _ = resp.Body.Close() }) + + // Identity should be extracted from the raw body where authed_user.id == "U1234" + require.NotNil(t, transport.extractedIdentity) + assert.Equal(t, "U1234", transport.extractedIdentity.Subject) + + // The rewritten body should have access_token lifted to top-level + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal(respBody, &parsed)) + assert.Equal(t, "x", parsed["access_token"]) + assert.Equal(t, "Bearer", parsed["token_type"]) +} + +// TestWrapHTTPClientForTokenExchange_OnlyIdentityCfg verifies that wrapping occurs +// when only identityCfg is set (mapping is nil). +func TestWrapHTTPClientForTokenExchange_OnlyIdentityCfg(t *testing.T) { + t.Parallel() + + original := &http.Client{} + identityCfg := &IdentityFromTokenConfig{SubjectPath: "username"} + result, rewriter := wrapHTTPClientForTokenExchange(original, nil, identityCfg, "https://example.com/token") + + assert.NotSame(t, original, result) + assert.NotNil(t, rewriter) + _, ok := result.Transport.(*tokenResponseRewriter) + assert.True(t, ok, "expected *tokenResponseRewriter transport when only identityCfg is set") +} + +// TestWrapHTTPClientForTokenExchange_OnlyMapping verifies that wrapping occurs +// when only mapping is set (identityCfg is nil), preserving existing behavior. +func TestWrapHTTPClientForTokenExchange_OnlyMapping(t *testing.T) { + t.Parallel() + + original := &http.Client{} + mapping := &TokenResponseMapping{AccessTokenPath: "access_token"} + result, rewriter := wrapHTTPClientForTokenExchange(original, mapping, nil, "https://example.com/token") + + assert.NotSame(t, original, result) + assert.NotNil(t, rewriter) + _, ok := result.Transport.(*tokenResponseRewriter) + assert.True(t, ok, "expected *tokenResponseRewriter transport when only mapping is set") +} + +// TestWrapHTTPClientForTokenExchange_BothSet verifies that wrapping occurs +// when both mapping and identityCfg are set. +func TestWrapHTTPClientForTokenExchange_BothSet(t *testing.T) { + t.Parallel() + + original := &http.Client{} + mapping := &TokenResponseMapping{AccessTokenPath: "access_token"} + identityCfg := &IdentityFromTokenConfig{SubjectPath: "username"} + result, rewriter := wrapHTTPClientForTokenExchange(original, mapping, identityCfg, "https://example.com/token") + + assert.NotSame(t, original, result) + assert.NotNil(t, rewriter) + _, ok := result.Transport.(*tokenResponseRewriter) + assert.True(t, ok, "expected *tokenResponseRewriter transport when both are set") +} + +// TestWrapHTTPClientForTokenExchange_BothNil verifies that the original client +// is returned unchanged when both mapping and identityCfg are nil. +func TestWrapHTTPClientForTokenExchange_BothNil(t *testing.T) { + t.Parallel() + + original := &http.Client{} + result, rewriter := wrapHTTPClientForTokenExchange(original, nil, nil, "https://example.com/token") + assert.Same(t, original, result) + assert.Nil(t, rewriter) +}