diff --git a/docs/arch/11-auth-server-storage.md b/docs/arch/11-auth-server-storage.md index b4c21b681e..001458cbe8 100644 --- a/docs/arch/11-auth-server-storage.md +++ b/docs/arch/11-auth-server-storage.md @@ -64,19 +64,46 @@ The storage layer implements multiple interfaces from the [fosite](https://githu - Memory backend: `pkg/authserver/storage/memory.go` - Redis backend: `pkg/authserver/storage/redis.go` -## Synthesis-mode subjects +## Identity resolution for pure OAuth2 providers -OAuth2 upstreams configured without a userInfo endpoint use a fallback identity-resolution mode: the embedded auth server synthesizes a non-PII subject by hashing the upstream access token. The mode changes what `UserStorage` and `UpstreamTokenStorage` see and is observable to operators inspecting stored state. +For pure OAuth 2.0 upstream providers (`OAuth2Config`), OIDC is unavailable and there is no ID token. `BaseOAuth2Provider.ExchangeCodeForIdentity` resolves user identity through a three-way priority chain. Each path has distinct implications for `UserStorage`, `UpstreamTokenStorage`, and the Redis secondary index. -**When the path triggers.** Pure OAuth 2.0 upstream provider (`OAuth2Config`) configured with `userInfo == nil`. Reached at `BaseOAuth2Provider.ExchangeCodeForIdentity` after token exchange when no userInfo endpoint is available to consult. OIDC providers and OAuth2 providers with `userInfo` configured continue to resolve identity normally and are not affected. +### IdentityFromToken (priority 1) + +An operator opt-in path that extracts identity claims directly from the token endpoint response body, skipping the userinfo HTTP call entirely. + +**When the path triggers.** `IdentityFromToken` is configured on the upstream provider (`p.config.IdentityFromToken != nil`). The `tokenResponseRewriter` intercepts the token endpoint response and runs extraction against the raw pre-rewrite body; the result is available to `ExchangeCodeForIdentity` without an additional round-trip. + +**Subject format.** Real, stable subject string extracted from the token response body via a gjson dot-notation path (e.g. `username`, `authed_user.id`). For token responses that embed a JWT, the `@upstreamjwt` modifier decodes the payload for further drilling (e.g. `access_token|@upstreamjwt|sub`). The `@upstreamjwt` modifier performs no signature verification — it is intended only for JWTs received directly from the upstream token endpoint over a TLS-authenticated channel. The returned `*Identity` carries `Synthetic = false`. Path semantics and trust-model notes are documented on the runtime config struct `IdentityFromTokenConfig` in `pkg/authserver/upstream/identity_from_token.go`. The corresponding CRD type (`cmd/thv-operator/api/v1alpha1.IdentityFromTokenConfig`) is defined in a sibling PR; operator-to-runner translation of this config lands separately. + +**`UserResolver` interaction.** Because `Identity.Synthetic` is false, `callback.go` takes the normal path: `UserResolver.ResolveUser` runs, a row is created (or looked up) in `UserStorage`, a provider-identities entry is written, and `UpdateLastAuthenticated` is called. `UpstreamTokens.UserID` carries the resolved internal user UUID, not the raw operator-supplied subject string. + +**Reverse-index implication (Redis backend).** Stable user IDs mean `KeyTypeUserUpstream` works as designed — one set per user accumulates session IDs across re-authentications. No set churn. + +**Operator visibility.** The `IdentitySynthesized` condition does not fire for upstreams using `IdentityFromToken`. However, `SyntheticIdentityUpstreams()` (the controller-side predicate that drives the condition) currently checks only for `userInfo == nil` and does not yet inspect `IdentityFromToken`. Until the CRD type and controller logic land in a follow-up, an upstream with `IdentityFromToken` configured but no `userInfo` will still trigger `IdentitySynthesizedActive` — even though synthesis is not reached at runtime. + +**Implementation.** +- `pkg/authserver/upstream/oauth2.go` — `ExchangeCodeForIdentity` priority 1 branch +- `pkg/authserver/upstream/identity_from_token.go` — `IdentityFromTokenConfig`, `extractIdentityFromTokenResponse`, `@upstreamjwt` modifier +- `pkg/authserver/upstream/token_exchange.go` — `tokenResponseRewriter.RoundTrip` extracts identity from the raw pre-rewrite body + +### UserInfo endpoint (priority 2) + +Existing behavior. When `IdentityFromToken` is unconfigured and `userInfo` is set, `fetchUserInfo` is called with the upstream access token. Subject, name, and email come from the userinfo response. `UserResolver.ResolveUser` runs normally, `Identity.Synthetic` is false. + +### Synthesis-mode subjects (priority 3) + +Reached when both `IdentityFromToken` is unconfigured AND `userInfo` is absent. The embedded auth server synthesizes a non-PII subject by hashing the upstream access token. The mode changes what `UserStorage` and `UpstreamTokenStorage` see and is observable to operators inspecting stored state. + +**When the path triggers.** Pure OAuth 2.0 upstream provider (`OAuth2Config`) where both `IdentityFromToken` and `userInfo` are unconfigured. Reached at `BaseOAuth2Provider.ExchangeCodeForIdentity` as the final fallback. OIDC providers and OAuth2 providers with either `IdentityFromToken` or `userInfo` configured are not affected. **Subject format.** `tk-` followed by 32 lowercase hex characters (the first 16 bytes of `SHA-256(accessToken)`), e.g. `tk-89abcdef0123456789abcdef01234567`. The output is opaque: assuming the upstream issues opaque (non-JWT) bearer tokens, the digest reveals nothing about the input that an attacker holding a candidate token could not already confirm by re-hashing. The returned `*Identity` carries `Synthetic = true`; the `upstream.IsSynthesizedSubject(string)` predicate lets bare-string consumers recognize the prefix. -**`UserResolver` bypass.** Synthetic identities skip `UserResolver.ResolveUser` entirely — no row is created in `UserStorage`, no entry is written to provider-identities, and `UpdateLastAuthenticated` is not called. The synthesized subject rotates per access token, so persisting it would create a fresh `users` row on every re-authentication. `UpstreamTokens.UserID` therefore carries the `tk-…` value directly rather than a stable internal UUID. +**`UserResolver` bypass.** The bypass is gated on `Identity.Synthetic` in `callback.go` — synthesis is the only path that sets this field. Synthetic identities skip `UserResolver.ResolveUser` entirely — no row is created in `UserStorage`, no entry is written to provider-identities, and `UpdateLastAuthenticated` is not called. The synthesized subject rotates per access token, so persisting it would create a fresh `users` row on every re-authentication. `UpstreamTokens.UserID` therefore carries the `tk-…` value directly rather than a stable internal UUID. **Reverse-index implication (Redis backend).** The `KeyTypeUserUpstream` secondary-index set under `thv:auth:{ns:name}:user:upstream:{userID}` is designed around stable user IDs — one set per user, holding all of that user's session IDs. Under synthesis the userID rotates with every re-authentication, so each session lands in its own one-element set. Reads continue to work, but set churn is much higher than under OIDC. The existing TODO at `pkg/authserver/storage/redis.go:43-45` to scan and clean up stale secondary-index entries applies, and synthesis-mode workloads make a periodic scan more important. -**Operator visibility.** When at least one configured OAuth2 upstream has `userInfo == nil`, the controller surfaces the `IdentitySynthesized` condition on the `MCPExternalAuthConfig` and `VirtualMCPServer` status (Reason `IdentitySynthesizedActive`, naming the affected upstreams). The condition flips to `False` (Reason `IdentitySynthesizedInactive`) once every upstream has `userInfo` configured. +**Operator visibility.** When at least one configured OAuth2 upstream has `userInfo == nil`, the controller surfaces the `IdentitySynthesized` condition on the `MCPExternalAuthConfig` and `VirtualMCPServer` status (Reason `IdentitySynthesizedActive`, naming the affected upstreams). The condition flips to `False` (Reason `IdentitySynthesizedInactive`) once every upstream has `userInfo` configured. Note: the controller predicate (`SyntheticIdentityUpstreams`) checks only for `userInfo == nil` and does not yet account for `IdentityFromToken`; see the known gap noted under priority 1. **Implementation.** - `pkg/authserver/upstream/oauth2.go` — `synthesizeIdentity`, `synthesizeSubjectFromAccessToken`, `IsSynthesizedSubject` diff --git a/pkg/authserver/upstream/oauth2.go b/pkg/authserver/upstream/oauth2.go index 8d0f349685..d6b5c41886 100644 --- a/pkg/authserver/upstream/oauth2.go +++ b/pkg/authserver/upstream/oauth2.go @@ -152,7 +152,16 @@ type OAuth2Config struct { // When set, the provider performs the token exchange HTTP call directly (bypassing // golang.org/x/oauth2) and extracts fields using gjson dot-notation paths. // When nil, standard OAuth 2.0 token response parsing is used. + // See also: IdentityFromToken for extracting user identity from the same response. TokenResponseMapping *TokenResponseMapping `json:"token_response_mapping,omitempty" yaml:"token_response_mapping,omitempty"` + + // IdentityFromToken extracts user identity from the token-endpoint response + // body when the upstream provider includes identity claims there (e.g., + // Snowflake's `username`, Slack's `authed_user.id`). When set, the embedded + // auth server skips the userinfo HTTP call entirely. See the CRD type + // (cmd/thv-operator/api/v1alpha1.IdentityFromTokenConfig) for the + // authoritative trust-model and uniqueness documentation. + IdentityFromToken *IdentityFromTokenConfig `json:"identity_from_token,omitempty" yaml:"identity_from_token,omitempty"` } // TokenResponseMapping configures extraction of token fields from non-standard @@ -195,6 +204,11 @@ func (c *OAuth2Config) Validate() error { return errors.New("token_response_mapping.access_token_path is required when token_response_mapping is set") } } + if c.IdentityFromToken != nil { + if c.IdentityFromToken.SubjectPath == "" { + return errors.New("identity_from_token.subject_path is required when identity_from_token is set") + } + } return c.CommonOAuthConfig.Validate() } @@ -400,39 +414,71 @@ func (p *BaseOAuth2Provider) buildAuthorizationURL( } // ExchangeCodeForIdentity exchanges an authorization code for tokens and resolves -// the user's identity in a single atomic operation. -// For pure OAuth2 providers, identity is resolved via UserInfo when configured; -// otherwise Subject is synthesized via synthesizeIdentity (which rejects empty -// access tokens to prevent the well-known sha256("") subject collision) and -// Name/Email are left empty. The nonce parameter is ignored (no ID token). +// the user's identity in a single atomic operation. For pure OAuth2 providers +// (no ID token) the priority chain is: +// +// 1. IdentityFromToken (operator opt-in): extract identity claims from the +// token-endpoint response body using gjson paths. The userinfo HTTP call +// is skipped entirely. +// 2. UserInfo endpoint: fetch identity from the configured userinfo URL. +// 3. Synthesis: when neither is configured, synthesizeIdentity derives a +// non-PII Subject from the access token (rejects empty tokens to prevent +// the well-known sha256("") collision). Name and Email are empty; +// Synthetic=true tells the callback handler to bypass UserResolver +// because the subject rotates per access token. +// +// The nonce parameter is ignored (no ID token to validate). func (p *BaseOAuth2Provider) ExchangeCodeForIdentity(ctx context.Context, code, codeVerifier, _ string) (*Identity, error) { - tokens, err := p.exchangeCodeForTokens(ctx, code, codeVerifier) + exchanged, err := p.exchangeCodeForTokens(ctx, code, codeVerifier) if err != nil { return nil, err } - // No userinfo: synthesize a non-PII subject from the access token. - // Synthetic=true tells the callback handler to bypass UserResolver — the - // synthesized subject rotates per access token, so persisting it would - // create a new `users` row on every re-authentication. - if p.config.UserInfo == nil { - return synthesizeIdentity(tokens) - } - - userInfo, err := p.fetchUserInfo(ctx, tokens.AccessToken) - if err != nil { - return nil, fmt.Errorf("%w: %w", ErrIdentityResolutionFailed, err) - } - if userInfo == nil || userInfo.Subject == "" { - return nil, ErrIdentityResolutionFailed - } - - return &Identity{ - Tokens: tokens, - Subject: userInfo.Subject, - Name: userInfo.Name, - Email: userInfo.Email, - }, nil + // Priority 1: identityFromToken (configured by operator). + if p.config.IdentityFromToken != nil { + if exchanged.identity == nil { + // The rewriter logged the extraction failure at WARN with the + // operator-supplied path. Surface the same error here so callers + // can report it without requiring WARN-level log access. + if exchanged.extractionErr != nil { + return nil, exchanged.extractionErr + } + // Unreachable in practice: when identity is nil the rewriter always + // sets extractionErr. Kept as a safe fallback. + return nil, fmt.Errorf( + "%w: identityFromToken configured but extraction failed", + ErrIdentityResolutionFailed, + ) + } + return &Identity{ + Tokens: exchanged.tokens, + Subject: exchanged.identity.Subject, + Name: exchanged.identity.Name, + Email: exchanged.identity.Email, + }, nil + } + + // Priority 2: userInfo (existing behavior). + if p.config.UserInfo != nil { + userInfo, err := p.fetchUserInfo(ctx, exchanged.tokens.AccessToken) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrIdentityResolutionFailed, err) + } + if userInfo == nil || userInfo.Subject == "" { + return nil, ErrIdentityResolutionFailed + } + return &Identity{ + Tokens: exchanged.tokens, + Subject: userInfo.Subject, + Name: userInfo.Name, + Email: userInfo.Email, + }, nil + } + + // Priority 3: synthesis (PR 5094). Subject derived from access token; rotates + // per token. The callback handler treats Synthetic=true as opt-out from the + // user-resolver to avoid creating a fresh users row on every re-auth. + return synthesizeIdentity(exchanged.tokens) } // synthesizedSubjectPrefix tags subjects produced by @@ -496,21 +542,45 @@ func synthesizeIdentity(tokens *Tokens) (*Identity, error) { }, nil } +// tokenExchangeResult bundles the outputs of a successful token exchange: +// the obtained tokens, any identity extracted from the token response body, +// and any error from the identity extraction step. The exchange-level error +// is returned as the function's own error return value. +// +// extractionErr is populated when IdentityFromToken is configured but the +// extractor could not resolve the subject path. It carries operator-actionable +// diagnostics (path name, type description) and already wraps +// ErrIdentityResolutionFailed. The caller must check extractionErr when +// identity is nil and IdentityFromToken is set. +type tokenExchangeResult struct { + tokens *Tokens + identity *partialIdentity + extractionErr error +} + // exchangeCodeForTokens exchanges an authorization code for tokens with the upstream IDP. -func (p *BaseOAuth2Provider) exchangeCodeForTokens(ctx context.Context, code, codeVerifier string) (*Tokens, error) { +// It returns a tokenExchangeResult containing the tokens, any identity extracted from +// the token response body, and any extraction error. The function error is the +// exchange-level error (network, HTTP, token parsing). +func (p *BaseOAuth2Provider) exchangeCodeForTokens( + ctx context.Context, code, codeVerifier string, +) (*tokenExchangeResult, error) { if code == "" { return nil, errors.New("authorization code is required") } - slog.Info("exchanging authorization code for tokens", + slog.Debug("exchanging authorization code for tokens", "token_endpoint", p.config.TokenEndpoint, "has_pkce_verifier", codeVerifier != "", ) - // Wrap HTTP client with token response rewriter if mapping is configured. - // This normalizes non-standard responses (e.g., GovSlack's nested fields) - // before the oauth2 library parses them, keeping the standard exchange flow. - httpClient := wrapHTTPClientWithMapping(p.httpClient, p.config.TokenResponseMapping, p.config.TokenEndpoint) + // Wrap HTTP client with token response rewriter if mapping or identity extraction + // is configured. At auth-code time, both mapping (field normalization) and + // identityCfg (identity extraction) may be active together. Keep a reference to + // the rewriter so we can read extractedIdentity and extractionErr after Exchange returns. + httpClient, rewriter := wrapHTTPClientForTokenExchange( + p.httpClient, p.config.TokenResponseMapping, p.config.IdentityFromToken, p.config.TokenEndpoint, + ) ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) // Build exchange options @@ -534,7 +604,16 @@ func (p *BaseOAuth2Provider) exchangeCodeForTokens(ctx context.Context, code, co "expires_at", expiresAtLogValue(tokens.ExpiresAt), ) - return tokens, nil + // Read any identity and extraction error captured by the rewriter during the + // token round-trip. rewriter is nil when neither mapping nor identityCfg is + // configured; nil-safe. + result := &tokenExchangeResult{tokens: tokens} + if rewriter != nil { + result.identity = rewriter.extractedIdentity + result.extractionErr = rewriter.extractionErr + } + + return result, nil } // RefreshTokens refreshes the upstream IDP tokens. @@ -559,7 +638,12 @@ func (p *BaseOAuth2Provider) RefreshTokens(ctx context.Context, refreshToken, _ ) // Wrap HTTP client with token response rewriter if mapping is configured. - httpClient := wrapHTTPClientWithMapping(p.httpClient, p.config.TokenResponseMapping, p.config.TokenEndpoint) + // Identity extraction (identityCfg) is intentionally nil here: per Snowflake's + // contract and the general design, the username/identity field is only present + // in the initial auth-code response and is omitted on refresh. Identity is + // cached at auth-code time and read from session storage on subsequent requests. + // The rewriter is discarded because refresh does not produce identity. + httpClient, _ := wrapHTTPClientForTokenExchange(p.httpClient, p.config.TokenResponseMapping, nil, p.config.TokenEndpoint) ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) opts := []oauth2.AuthCodeOption{ diff --git a/pkg/authserver/upstream/oauth2_test.go b/pkg/authserver/upstream/oauth2_test.go index 0ffcb154db..5770878a20 100644 --- a/pkg/authserver/upstream/oauth2_test.go +++ b/pkg/authserver/upstream/oauth2_test.go @@ -19,10 +19,12 @@ import ( "encoding/hex" "encoding/json" "errors" + "fmt" "net/http" "net/http/httptest" "net/url" "strings" + "sync/atomic" "testing" "time" @@ -408,7 +410,7 @@ func TestBaseOAuth2Provider_exchangeCodeForTokens(t *testing.T) { provider, err := NewOAuth2Provider(config) require.NoError(t, err) - tokens, err := provider.exchangeCodeForTokens(ctx, "test-auth-code", "test-verifier") + result, err := provider.exchangeCodeForTokens(ctx, "test-auth-code", "test-verifier") require.NoError(t, err) // Verify request parameters @@ -420,12 +422,12 @@ func TestBaseOAuth2Provider_exchangeCodeForTokens(t *testing.T) { assert.Equal(t, "http://localhost:8080/callback", receivedParams.Get("redirect_uri")) // Verify response - assert.Equal(t, "exchanged-access-token", tokens.AccessToken) - assert.Equal(t, "exchanged-refresh-token", tokens.RefreshToken) + assert.Equal(t, "exchanged-access-token", result.tokens.AccessToken) + assert.Equal(t, "exchanged-refresh-token", result.tokens.RefreshToken) // Verify expiration is set approximately correctly expectedExpiry := time.Now().Add(7200 * time.Second) - assert.WithinDuration(t, expectedExpiry, tokens.ExpiresAt, 10*time.Second) + assert.WithinDuration(t, expectedExpiry, result.tokens.ExpiresAt, 10*time.Second) }) t.Run("handles error response from token endpoint", func(t *testing.T) { @@ -660,11 +662,11 @@ func TestBaseOAuth2Provider_exchangeCodeForTokens(t *testing.T) { provider, err := NewOAuth2Provider(config) require.NoError(t, err) - tokens, err := provider.exchangeCodeForTokens(ctx, "test-code", "") + result, err := provider.exchangeCodeForTokens(ctx, "test-code", "") require.NoError(t, err) // No expires_in in the response means the token has no expiry. - assert.True(t, tokens.ExpiresAt.IsZero(), "ExpiresAt should be zero for non-expiring tokens") + assert.True(t, result.tokens.ExpiresAt.IsZero(), "ExpiresAt should be zero for non-expiring tokens") }) } @@ -935,9 +937,9 @@ func TestBaseOAuth2Provider_WithOAuth2HTTPClient(t *testing.T) { // Verify the provider works with custom client ctx := context.Background() - tokens, err := provider.exchangeCodeForTokens(ctx, "test-code", "") + result, err := provider.exchangeCodeForTokens(ctx, "test-code", "") require.NoError(t, err) - assert.NotEmpty(t, tokens.AccessToken) + assert.NotEmpty(t, result.tokens.AccessToken) } func TestBaseOAuth2Provider_TokenTypeValidation(t *testing.T) { @@ -1069,11 +1071,11 @@ func TestBaseOAuth2Provider_IDToken(t *testing.T) { provider, err := NewOAuth2Provider(config) require.NoError(t, err) - tokens, err := provider.exchangeCodeForTokens(ctx, "test-code", "") + result, err := provider.exchangeCodeForTokens(ctx, "test-code", "") require.NoError(t, err) // OAuth2 providers can also return ID tokens if they support hybrid flows - assert.Equal(t, "test-id-token.payload.signature", tokens.IDToken) + assert.Equal(t, "test-id-token.payload.signature", result.tokens.IDToken) } func Test_validateRedirectURI(t *testing.T) { @@ -2137,3 +2139,353 @@ func TestAuthorizationURL_AdditionalAuthorizationParams(t *testing.T) { assert.Equal(t, "caller-value", parsed.Query().Get("custom")) }) } + +// tokenBodyWithUsername is a minimal valid token response body that includes a +// "username" field used across several identityFromToken test cases. +const tokenBodyWithUsername = `{"access_token":"a","token_type":"Bearer","username":"u1"}` + +// newTokenResponseServer is a test helper that starts an httptest.Server whose +// /token endpoint returns tokenBody as the JSON response body, and whose +// /authorize endpoint always returns 200 OK. +func newTokenResponseServer(t *testing.T, tokenBody string) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + mux.HandleFunc("/authorize", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("/token", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(tokenBody)) + }) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + return srv +} + +// TestBaseOAuth2Provider_ExchangeCodeForIdentity_IdentityFromToken covers the +// priority-1 path where IdentityFromToken is configured. +func TestBaseOAuth2Provider_ExchangeCodeForIdentity_IdentityFromToken(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("identityFromToken resolves subject name email", func(t *testing.T) { + t.Parallel() + + tokenBody := `{"access_token":"a","token_type":"Bearer","username":"u1","display_name":"User One","email":"u1@example.com"}` + tokenSrv := newTokenResponseServer(t, tokenBody) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + SubjectPath: "username", + NamePath: "display_name", + EmailPath: "email", + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + identity, err := provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.NoError(t, err) + assert.Equal(t, "u1", identity.Subject) + assert.Equal(t, "User One", identity.Name) + assert.Equal(t, "u1@example.com", identity.Email) + assert.NotEmpty(t, identity.Tokens.AccessToken) + }) + + t.Run("@upstreamjwt modifier resolves identity from JWT-shaped access token", func(t *testing.T) { + t.Parallel() + + RegisterModifiers() + + accessToken := makeJWT(`{"sub":"u-jwt-1","name":"JWT User","email":"jwt@example.com"}`) + tokenBody := fmt.Sprintf(`{"access_token":%q,"token_type":"Bearer"}`, accessToken) + tokenSrv := newTokenResponseServer(t, tokenBody) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + SubjectPath: "access_token|@upstreamjwt|sub", + NamePath: "access_token|@upstreamjwt|name", + EmailPath: "access_token|@upstreamjwt|email", + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + identity, err := provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.NoError(t, err) + assert.Equal(t, "u-jwt-1", identity.Subject) + assert.Equal(t, "JWT User", identity.Name) + assert.Equal(t, "jwt@example.com", identity.Email) + assert.False(t, identity.Synthetic) + }) + + t.Run("identityFromToken bypasses userinfo endpoint entirely", func(t *testing.T) { + t.Parallel() + + // httptest.Server dispatches each request on its own goroutine, so + // the counter must be accessed atomically — t.Errorf inside the + // handler is the load-bearing assertion; the counter just gives a + // numeric value the final assertion can read race-free. + var tripwireCallCount atomic.Int32 + tripwire := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + tripwireCallCount.Add(1) + t.Errorf("userinfo endpoint must NOT be called when identityFromToken is configured") + })) + t.Cleanup(tripwire.Close) + + tokenSrv := newTokenResponseServer(t, tokenBodyWithUsername) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + SubjectPath: "username", + }, + // UserInfo is also configured — identityFromToken must win and + // the tripwire userinfo server must never be contacted. + UserInfo: &UserInfoConfig{ + EndpointURL: tripwire.URL, + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + identity, err := provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.NoError(t, err) + assert.Equal(t, "u1", identity.Subject) + assert.Equal(t, int32(0), tripwireCallCount.Load(), "userinfo endpoint must not be called") + }) + + t.Run("identityFromToken extraction failure returns ErrIdentityResolutionFailed without calling userinfo", func(t *testing.T) { + t.Parallel() + + var tripwireCallCount atomic.Int32 + tripwire := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + tripwireCallCount.Add(1) + t.Errorf("userinfo endpoint must NOT be called when identityFromToken is configured") + })) + t.Cleanup(tripwire.Close) + + // Token body does NOT contain "missing_path" + tokenSrv := newTokenResponseServer(t, tokenBodyWithUsername) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + SubjectPath: "missing_path", + }, + UserInfo: &UserInfoConfig{ + EndpointURL: tripwire.URL, + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + _, err = provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrIdentityResolutionFailed)) + assert.Equal(t, int32(0), tripwireCallCount.Load(), "userinfo endpoint must not be called on extraction failure") + }) + + t.Run("extraction failure error surfaces the misconfigured path name", func(t *testing.T) { + t.Parallel() + + tokenSrv := newTokenResponseServer(t, tokenBodyWithUsername) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + // Path that does not exist in the token response so the extractor + // produces a diagnostic that names the path and the failure reason. + SubjectPath: "nonexistent_field", + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + _, err = provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrIdentityResolutionFailed)) + // The error must carry the operator-supplied path name so the + // misconfiguration is diagnosable without log access. + assert.Contains(t, err.Error(), "nonexistent_field", + "error must contain the misconfigured subject path name") + assert.Contains(t, err.Error(), "not found", + "error must describe why extraction failed") + }) + + t.Run("userInfo-only path is unchanged when identityFromToken is not set", func(t *testing.T) { + t.Parallel() + + userInfoSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "sub": "u-456", + "name": "Bob", + "email": "bob@example.com", + }) + })) + t.Cleanup(userInfoSrv.Close) + + tokenBody := `{"access_token":"a","token_type":"Bearer"}` + tokenSrv := newTokenResponseServer(t, tokenBody) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + UserInfo: &UserInfoConfig{ + EndpointURL: userInfoSrv.URL, + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + identity, err := provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.NoError(t, err) + assert.Equal(t, "u-456", identity.Subject) + assert.Equal(t, "Bob", identity.Name) + assert.Equal(t, "bob@example.com", identity.Email) + }) + + t.Run("neither identityFromToken nor userInfo set falls through to synthesis", func(t *testing.T) { + t.Parallel() + + tokenBody := `{"access_token":"a","token_type":"Bearer"}` + tokenSrv := newTokenResponseServer(t, tokenBody) + + // Both identity surfaces absent — the priority chain falls through to + // synthesizeIdentity (PR 5094): a non-PII Subject derived from the + // access token, Synthetic=true, no Name/Email. + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + identity, err := provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.NoError(t, err) + assert.True(t, identity.Synthetic, "expected synthesized identity when neither IdentityFromToken nor UserInfo is set") + assert.True(t, IsSynthesizedSubject(identity.Subject), + "expected synthesized subject prefix; got %q", identity.Subject) + assert.Empty(t, identity.Name, "synthesized identity has no name") + assert.Empty(t, identity.Email, "synthesized identity has no email") + }) + + t.Run("refresh path does not trigger identity extraction", func(t *testing.T) { + t.Parallel() + + tokenBody := `{"access_token":"refreshed","token_type":"Bearer","refresh_token":"new-rt","username":"u1"}` + tokenSrv := newTokenResponseServer(t, tokenBody) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + SubjectPath: "username", + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + // RefreshTokens must not blow up and must return valid tokens. + tokens, err := provider.RefreshTokens(ctx, "old-refresh-token", "") + require.NoError(t, err) + assert.Equal(t, "refreshed", tokens.AccessToken) + + // Verify that the wrapped client on the refresh path uses nil identityCfg + // by inspecting the transport after constructing it the same way RefreshTokens does. + _, rewriter := wrapHTTPClientForTokenExchange(provider.httpClient, provider.config.TokenResponseMapping, nil, provider.config.TokenEndpoint) + // When only mapping is nil and identityCfg is nil, wrapHTTPClientForTokenExchange + // returns the original client and nil rewriter. When mapping is non-nil, + // the rewriter has nil identityCfg. Here, both are nil, so rewriter is nil. + assert.Nil(t, rewriter, "refresh path rewriter must be nil when both mapping and identityCfg are nil") + }) + + t.Run("error message does not contain token body content", func(t *testing.T) { + t.Parallel() + + secretMarker := "SUPER_SECRET_TOKEN_BODY_MARKER_XYZ789" + tokenBody := `{"access_token":"` + secretMarker + `","token_type":"Bearer","username":"u1"}` + tokenSrv := newTokenResponseServer(t, tokenBody) + + config := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "http://localhost:8080/callback", + }, + AuthorizationEndpoint: tokenSrv.URL + "/authorize", + TokenEndpoint: tokenSrv.URL + "/token", + IdentityFromToken: &IdentityFromTokenConfig{ + // Deliberately wrong path to trigger extraction failure. + SubjectPath: "missing_field", + }, + } + + provider, err := NewOAuth2Provider(config) + require.NoError(t, err) + + _, err = provider.ExchangeCodeForIdentity(ctx, "test-code", "", "") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrIdentityResolutionFailed)) + // The error must not leak any part of the token response body. + assert.NotContains(t, err.Error(), secretMarker, + "error message must not contain token body content") + }) +} diff --git a/pkg/authserver/upstream/oidc.go b/pkg/authserver/upstream/oidc.go index 1ea53c4424..ed1b653e91 100644 --- a/pkg/authserver/upstream/oidc.go +++ b/pkg/authserver/upstream/oidc.go @@ -242,27 +242,36 @@ func (p *OIDCProviderImpl) ExchangeCodeForIdentity( return nil, errors.New("OIDC endpoints not discovered") } - tokens, err := p.exchangeCodeForTokens(ctx, code, codeVerifier) + // OIDC resolves identity from the validated ID token, not the token response + // body, so the extracted identity return value is intentionally discarded. + // Defense-in-depth: warn if the OAuth2 base config carries a non-nil + // IdentityFromToken — the field is structurally absent on the OIDC CRD type + // today, but a future config-loader bug or hand-built runtime config could + // silently drop the operator's intent without this signal. + if p.config.IdentityFromToken != nil { + slog.Warn("OIDC provider ignoring IdentityFromToken; identity is resolved from the validated ID token") + } + exchanged, err := p.exchangeCodeForTokens(ctx, code, codeVerifier) if err != nil { return nil, err } // OIDC-specific: ID token MUST be present per Section 3.1.3.3. - if tokens.IDToken == "" { + if exchanged.tokens.IDToken == "" { return nil, fmt.Errorf("%w: ID token required for OIDC provider", ErrIdentityResolutionFailed) } // Validate ID token with nonce in a single pass — no double-validation. - validatedToken, err := p.validateIDToken(ctx, tokens.IDToken, nonce) + validatedToken, err := p.validateIDToken(ctx, exchanged.tokens.IDToken, nonce) if err != nil { slog.Debug("id token validation failed", "error", err) return nil, fmt.Errorf("%w: %w", ErrIdentityResolutionFailed, err) } slog.Debug("authorization code exchange successful", - "has_refresh_token", tokens.RefreshToken != "", - "has_id_token", tokens.IDToken != "", - "expires_at", expiresAtLogValue(tokens.ExpiresAt), + "has_refresh_token", exchanged.tokens.RefreshToken != "", + "has_id_token", exchanged.tokens.IDToken != "", + "expires_at", expiresAtLogValue(exchanged.tokens.ExpiresAt), ) // Extract optional standard claims (name, email) from ID token @@ -278,7 +287,7 @@ func (p *OIDCProviderImpl) ExchangeCodeForIdentity( } return &Identity{ - Tokens: tokens, + Tokens: exchanged.tokens, Subject: validatedToken.Subject, Name: idClaims.Name, Email: idClaims.Email, diff --git a/pkg/authserver/upstream/token_exchange.go b/pkg/authserver/upstream/token_exchange.go index 16fa65c1da..4203e23ce2 100644 --- a/pkg/authserver/upstream/token_exchange.go +++ b/pkg/authserver/upstream/token_exchange.go @@ -7,28 +7,44 @@ import ( "bytes" "encoding/json" "io" + "log/slog" "net/http" "github.com/tidwall/gjson" ) // tokenResponseRewriter is an http.RoundTripper that normalizes non-standard -// OAuth token responses before the golang.org/x/oauth2 library parses them. +// OAuth token responses before the golang.org/x/oauth2 library parses them, +// and optionally extracts user identity claims from the raw response body. // // Some providers (e.g., GovSlack) nest token fields under non-standard paths // like "authed_user.access_token" instead of the top-level "access_token". // This RoundTripper intercepts the response, extracts fields using gjson // dot-notation paths, and rewrites the response body with standard top-level // field names so the oauth2 library can parse them normally. +// +// When identityCfg is set, identity extraction runs on the RAW pre-rewrite +// body before any field relocation occurs. The extracted identity is stored +// in extractedIdentity and consumed by the caller after Exchange returns. +// If extraction fails, extractionErr holds the cause so callers can surface +// operator-actionable diagnostics (path names, type descriptions) without +// re-reading the response body. type tokenResponseRewriter struct { - base http.RoundTripper - mapping *TokenResponseMapping - tokenURL string + base http.RoundTripper + mapping *TokenResponseMapping + identityCfg *IdentityFromTokenConfig + tokenURL string + extractedIdentity *partialIdentity + extractionErr error } // RoundTrip intercepts HTTP responses from the token endpoint and rewrites // the JSON body to place mapped fields at the top level. Non-token-endpoint // requests (e.g., userInfo) pass through unchanged. +// +// When identityCfg is set, identity is extracted from the raw response body +// BEFORE the rewrite step so that identity paths are resolved against the +// original provider response, not the normalized form. func (t *tokenResponseRewriter) RoundTrip(req *http.Request) (*http.Response, error) { resp, err := t.base.RoundTrip(req) if err != nil { @@ -51,7 +67,31 @@ func (t *tokenResponseRewriter) RoundTrip(req *http.Request) (*http.Response, er return nil, err } - rewritten := rewriteTokenResponse(body, t.mapping) + // Extract from the raw body before rewriteTokenResponse runs. The rewrite + // never touches identity-shaped fields today, but extracting first makes + // the ordering invariant independent of that assumption. + if t.identityCfg != nil { + result, extractErr := extractIdentityFromTokenResponse(body, t.identityCfg) + if extractErr != nil { + // WARN so an operator misconfiguration (e.g., wrong subjectPath) is + // visible without enabling DEBUG. The error is safe to log: it + // contains operator-supplied paths and type descriptions, never + // any portion of the response body. + slog.Warn("identity extraction from token response failed", "error", extractErr) + t.extractionErr = extractErr + } else { + t.extractedIdentity = &result + } + } + + // Only run the field-rewrite step when a mapping is configured. + // When mapping is nil (e.g. only identityCfg is set), pass the body through unchanged. + var rewritten []byte + if t.mapping != nil { + rewritten = rewriteTokenResponse(body, t.mapping) + } else { + rewritten = body + } resp.Body = io.NopCloser(bytes.NewReader(rewritten)) resp.ContentLength = int64(len(rewritten)) @@ -107,12 +147,20 @@ func rewriteTokenResponse(body []byte, mapping *TokenResponseMapping) []byte { return rewritten } -// wrapHTTPClientWithMapping wraps an HTTP client's transport with a -// tokenResponseRewriter when a TokenResponseMapping is configured. -// Returns the original client unchanged if mapping is nil. -func wrapHTTPClientWithMapping(client *http.Client, mapping *TokenResponseMapping, tokenURL string) *http.Client { - if mapping == nil { - return client +// wrapHTTPClientForTokenExchange wraps an HTTP client's transport with a +// tokenResponseRewriter when either a TokenResponseMapping or an +// IdentityFromTokenConfig is configured (or both). Returns the original client +// and nil rewriter when both are nil, so the standard oauth2 library path is +// used. The returned rewriter (when non-nil) can be read after Exchange returns +// to retrieve any identity extracted during the token round-trip. +func wrapHTTPClientForTokenExchange( + client *http.Client, + mapping *TokenResponseMapping, + identityCfg *IdentityFromTokenConfig, + tokenURL string, +) (*http.Client, *tokenResponseRewriter) { + if mapping == nil && identityCfg == nil { + return client, nil } base := client.Transport @@ -120,14 +168,17 @@ func wrapHTTPClientWithMapping(client *http.Client, mapping *TokenResponseMappin base = http.DefaultTransport } + rewriter := &tokenResponseRewriter{ + base: base, + mapping: mapping, + identityCfg: identityCfg, + tokenURL: tokenURL, + } + // Create a shallow copy to avoid mutating the original client wrapped := *client - wrapped.Transport = &tokenResponseRewriter{ - base: base, - mapping: mapping, - tokenURL: tokenURL, - } - return &wrapped + wrapped.Transport = rewriter + return &wrapped, rewriter } // pathOrDefault returns the path if non-empty, otherwise returns the default. diff --git a/pkg/authserver/upstream/token_exchange_test.go b/pkg/authserver/upstream/token_exchange_test.go index 125c78750e..d60270313d 100644 --- a/pkg/authserver/upstream/token_exchange_test.go +++ b/pkg/authserver/upstream/token_exchange_test.go @@ -15,6 +15,14 @@ import ( "github.com/stretchr/testify/require" ) +// tokenEndpointHandler returns an HTTP handler that responds with the given JSON body. +func tokenEndpointHandler(body string) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(body)) + } +} + func TestRewriteTokenResponse(t *testing.T) { t.Parallel() @@ -154,7 +162,7 @@ func TestTokenResponseRewriter_TokenEndpoint(t *testing.T) { ScopePath: "authed_user.scope", } - client := wrapHTTPClientWithMapping(http.DefaultClient, mapping, tokenServer.URL) + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, mapping, nil, tokenServer.URL) req, err := http.NewRequest("POST", tokenServer.URL, strings.NewReader("grant_type=authorization_code")) require.NoError(t, err) @@ -187,7 +195,7 @@ func TestTokenResponseRewriter_NonTokenEndpoint(t *testing.T) { mapping := &TokenResponseMapping{AccessTokenPath: "authed_user.access_token"} // Token URL points elsewhere, so this server's responses should pass through unchanged - client := wrapHTTPClientWithMapping(http.DefaultClient, mapping, "https://other.example.com/token") + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, mapping, nil, "https://other.example.com/token") req, err := http.NewRequest("GET", server.URL, nil) require.NoError(t, err) @@ -211,8 +219,9 @@ func TestWrapHTTPClientWithMapping_NilMapping(t *testing.T) { t.Parallel() original := &http.Client{} - result := wrapHTTPClientWithMapping(original, nil, "https://example.com/token") + result, rewriter := wrapHTTPClientForTokenExchange(original, nil, nil, "https://example.com/token") assert.Same(t, original, result) + assert.Nil(t, rewriter) } func TestTokenResponseRewriter_ErrorResponse(t *testing.T) { @@ -225,7 +234,7 @@ func TestTokenResponseRewriter_ErrorResponse(t *testing.T) { defer tokenServer.Close() mapping := &TokenResponseMapping{AccessTokenPath: "authed_user.access_token"} - client := wrapHTTPClientWithMapping(http.DefaultClient, mapping, tokenServer.URL) + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, mapping, nil, tokenServer.URL) req, err := http.NewRequest("POST", tokenServer.URL, strings.NewReader("grant_type=authorization_code")) require.NoError(t, err) @@ -280,3 +289,193 @@ func TestOAuth2Config_Validate_TokenResponseMapping(t *testing.T) { }) } } + +func TestOAuth2Config_Validate_IdentityFromToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + identityCfg *IdentityFromTokenConfig + wantErr bool + errContains string + }{ + {name: "nil identity config is valid", identityCfg: nil, wantErr: false}, + { + name: "valid identity config with subject path", + identityCfg: &IdentityFromTokenConfig{SubjectPath: "username"}, + wantErr: false, + }, + { + name: "missing subject path", + identityCfg: &IdentityFromTokenConfig{NamePath: "name"}, + wantErr: true, + errContains: "identity_from_token.subject_path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + cfg := &OAuth2Config{ + CommonOAuthConfig: CommonOAuthConfig{ClientID: "test", RedirectURI: "http://localhost/callback"}, + AuthorizationEndpoint: "https://example.com/authorize", + TokenEndpoint: "https://example.com/token", + IdentityFromToken: tt.identityCfg, + } + err := cfg.Validate() + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + } + }) + } +} + +// TestTokenResponseRewriter_IdentityCfgNil verifies that when identityCfg is nil, +// extractedIdentity remains nil after RoundTrip even when a mapping is configured. +func TestTokenResponseRewriter_IdentityCfgNil(t *testing.T) { + t.Parallel() + + body := `{"access_token":"tok","token_type":"Bearer","username":"u1"}` + server := httptest.NewServer(tokenEndpointHandler(body)) + t.Cleanup(server.Close) + + mapping := &TokenResponseMapping{AccessTokenPath: "access_token"} + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, mapping, nil, server.URL) + + transport, ok := client.Transport.(*tokenResponseRewriter) + require.True(t, ok, "expected *tokenResponseRewriter transport") + + req, err := http.NewRequest("POST", server.URL, strings.NewReader("grant_type=authorization_code")) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { _ = resp.Body.Close() }) + + assert.Nil(t, transport.extractedIdentity) +} + +// TestTokenResponseRewriter_IdentityCfgSet verifies that when identityCfg is set +// and the body contains the subject path, extractedIdentity is populated. +func TestTokenResponseRewriter_IdentityCfgSet(t *testing.T) { + t.Parallel() + + body := `{"access_token":"a","token_type":"Bearer","username":"u1"}` + server := httptest.NewServer(tokenEndpointHandler(body)) + t.Cleanup(server.Close) + + identityCfg := &IdentityFromTokenConfig{SubjectPath: "username"} + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, nil, identityCfg, server.URL) + + transport, ok := client.Transport.(*tokenResponseRewriter) + require.True(t, ok, "expected *tokenResponseRewriter transport") + + req, err := http.NewRequest("POST", server.URL, strings.NewReader("grant_type=authorization_code")) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { _ = resp.Body.Close() }) + + require.NotNil(t, transport.extractedIdentity) + assert.Equal(t, "u1", transport.extractedIdentity.Subject) +} + +// TestTokenResponseRewriter_IdentityFromRawBody verifies the raw-body +// invariant: identity is resolved against the pre-rewrite body. The fixture +// places the subject at a nested path ("authed_user.id") and configures a +// mapping that lifts access_token to the top level; the assertion confirms +// extraction reads the original nested location, not a post-rewrite shape. +func TestTokenResponseRewriter_IdentityFromRawBody(t *testing.T) { + t.Parallel() + + rawBody := `{"authed_user":{"access_token":"x","id":"U1234"},"username":""}` + server := httptest.NewServer(tokenEndpointHandler(rawBody)) + t.Cleanup(server.Close) + + mapping := &TokenResponseMapping{AccessTokenPath: "authed_user.access_token"} + identityCfg := &IdentityFromTokenConfig{SubjectPath: "authed_user.id"} + client, _ := wrapHTTPClientForTokenExchange(http.DefaultClient, mapping, identityCfg, server.URL) + + transport, ok := client.Transport.(*tokenResponseRewriter) + require.True(t, ok, "expected *tokenResponseRewriter transport") + + req, err := http.NewRequest("POST", server.URL, strings.NewReader("grant_type=authorization_code")) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { _ = resp.Body.Close() }) + + // Identity should be extracted from the raw body where authed_user.id == "U1234" + require.NotNil(t, transport.extractedIdentity) + assert.Equal(t, "U1234", transport.extractedIdentity.Subject) + + // The rewritten body should have access_token lifted to top-level + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal(respBody, &parsed)) + assert.Equal(t, "x", parsed["access_token"]) + assert.Equal(t, "Bearer", parsed["token_type"]) +} + +// TestWrapHTTPClientForTokenExchange_OnlyIdentityCfg verifies that wrapping occurs +// when only identityCfg is set (mapping is nil). +func TestWrapHTTPClientForTokenExchange_OnlyIdentityCfg(t *testing.T) { + t.Parallel() + + original := &http.Client{} + identityCfg := &IdentityFromTokenConfig{SubjectPath: "username"} + result, rewriter := wrapHTTPClientForTokenExchange(original, nil, identityCfg, "https://example.com/token") + + assert.NotSame(t, original, result) + assert.NotNil(t, rewriter) + _, ok := result.Transport.(*tokenResponseRewriter) + assert.True(t, ok, "expected *tokenResponseRewriter transport when only identityCfg is set") +} + +// TestWrapHTTPClientForTokenExchange_OnlyMapping verifies that wrapping occurs +// when only mapping is set (identityCfg is nil), preserving existing behavior. +func TestWrapHTTPClientForTokenExchange_OnlyMapping(t *testing.T) { + t.Parallel() + + original := &http.Client{} + mapping := &TokenResponseMapping{AccessTokenPath: "access_token"} + result, rewriter := wrapHTTPClientForTokenExchange(original, mapping, nil, "https://example.com/token") + + assert.NotSame(t, original, result) + assert.NotNil(t, rewriter) + _, ok := result.Transport.(*tokenResponseRewriter) + assert.True(t, ok, "expected *tokenResponseRewriter transport when only mapping is set") +} + +// TestWrapHTTPClientForTokenExchange_BothSet verifies that wrapping occurs +// when both mapping and identityCfg are set. +func TestWrapHTTPClientForTokenExchange_BothSet(t *testing.T) { + t.Parallel() + + original := &http.Client{} + mapping := &TokenResponseMapping{AccessTokenPath: "access_token"} + identityCfg := &IdentityFromTokenConfig{SubjectPath: "username"} + result, rewriter := wrapHTTPClientForTokenExchange(original, mapping, identityCfg, "https://example.com/token") + + assert.NotSame(t, original, result) + assert.NotNil(t, rewriter) + _, ok := result.Transport.(*tokenResponseRewriter) + assert.True(t, ok, "expected *tokenResponseRewriter transport when both are set") +} + +// TestWrapHTTPClientForTokenExchange_BothNil verifies that the original client +// is returned unchanged when both mapping and identityCfg are nil. +func TestWrapHTTPClientForTokenExchange_BothNil(t *testing.T) { + t.Parallel() + + original := &http.Client{} + result, rewriter := wrapHTTPClientForTokenExchange(original, nil, nil, "https://example.com/token") + assert.Same(t, original, result) + assert.Nil(t, rewriter) +}