From 1736a6e94cfa44d021f3bfe26dd5dc009db2bc40 Mon Sep 17 00:00:00 2001 From: Trey Date: Wed, 6 May 2026 10:15:17 -0700 Subject: [PATCH 1/3] Add Redis backend for DCRCredentialStore Closes #5184 --- pkg/authserver/storage/redis.go | 182 +++++++++ .../storage/redis_integration_test.go | 176 +++++++++ pkg/authserver/storage/redis_keys.go | 24 ++ pkg/authserver/storage/redis_test.go | 373 ++++++++++++++++++ 4 files changed, 755 insertions(+) diff --git a/pkg/authserver/storage/redis.go b/pkg/authserver/storage/redis.go index 6b3fe4f19a..b9a3e3ab17 100644 --- a/pkg/authserver/storage/redis.go +++ b/pkg/authserver/storage/redis.go @@ -32,6 +32,17 @@ const ( // nullMarker is used to store nil upstream tokens in Redis. const nullMarker = "null" +// pastExpiryDCRTTL is the bounded TTL applied when a caller writes DCR +// credentials whose ClientSecretExpiresAt is already in the past. The row is +// still accepted (the resolver decides when to re-register) but it self-evicts +// almost immediately so the store does not hold an already-expired secret +// forever. One second is small enough that no caller can usefully read the +// row, large enough that real Redis applies the TTL reliably (sub-second TTLs +// are technically supported via PEXPIRE but the second-grain TTL command is +// the broadly-tested path), and short enough that operational metrics do not +// confuse this row with a healthy long-lived registration. +const pastExpiryDCRTTL = time.Second + // warnOnCleanupErr logs a warning when a best-effort cleanup operation fails. // // Secondary index cleanup in Redis (SRem from reverse-lookup sets, Del of orphaned @@ -1352,6 +1363,176 @@ func unmarshalUpstreamTokens(data []byte) (*UpstreamTokens, error) { return tokens, nil } +// ----------------------- +// DCR Credentials Storage +// ----------------------- + +// storedDCRCredentials is the on-the-wire JSON representation of DCRCredentials. +// Time fields use int64 Unix epoch; 0 is the sentinel meaning "not set", matching +// the storedUpstreamTokens convention. ClientSecretExpiresAt == 0 specifically +// encodes the RFC 7591 §3.2.1 "client_secret does not expire" semantics, in +// which case StoreDCRCredentials persists the entry without a Redis TTL. +type storedDCRCredentials struct { + // Embed the canonical key so a row recovered without its lookup key + // (e.g. via SCAN during diagnostics) still self-identifies. + KeyIssuer string `json:"key_issuer"` + KeyRedirectURI string `json:"key_redirect_uri"` + KeyScopesHash string `json:"key_scopes_hash"` + + ProviderName string `json:"provider_name,omitempty"` + + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` //nolint:gosec // G117: field legitimately holds sensitive data + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + + // Bearer token for the RFC 7592 management endpoint. + //nolint:gosec // G117: field legitimately holds sensitive data + RegistrationAccessToken string `json:"registration_access_token,omitempty"` + RegistrationClientURI string `json:"registration_client_uri,omitempty"` + + AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"` + TokenEndpoint string `json:"token_endpoint,omitempty"` + + CreatedAt int64 `json:"created_at"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at"` +} + +// toDCRCredentials decodes the stored form back into the public type, mirroring +// storedUpstreamTokens.toUpstreamTokens. Zero epoch values become the zero +// time.Time, preserving the "not set" sentinel. +func (s *storedDCRCredentials) toDCRCredentials() *DCRCredentials { + var createdAt time.Time + if s.CreatedAt != 0 { + createdAt = time.Unix(s.CreatedAt, 0) + } + var clientSecretExpiresAt time.Time + if s.ClientSecretExpiresAt != 0 { + clientSecretExpiresAt = time.Unix(s.ClientSecretExpiresAt, 0) + } + return &DCRCredentials{ + Key: DCRKey{ + Issuer: s.KeyIssuer, + RedirectURI: s.KeyRedirectURI, + ScopesHash: s.KeyScopesHash, + }, + ProviderName: s.ProviderName, + ClientID: s.ClientID, + ClientSecret: s.ClientSecret, + TokenEndpointAuthMethod: s.TokenEndpointAuthMethod, + RegistrationAccessToken: s.RegistrationAccessToken, + RegistrationClientURI: s.RegistrationClientURI, + AuthorizationEndpoint: s.AuthorizationEndpoint, + TokenEndpoint: s.TokenEndpoint, + CreatedAt: createdAt, + ClientSecretExpiresAt: clientSecretExpiresAt, + } +} + +// StoreDCRCredentials persists DCR credentials, overwriting any existing entry +// for the same Key. Defensive copy is provided implicitly by JSON serialisation — +// caller mutations after the call cannot reach the persisted bytes. +// +// # TTL +// +// When creds.ClientSecretExpiresAt is non-zero (the upstream advertised an +// RFC 7591 §3.2.1 client_secret_expires_at), the entry is stored with a Redis +// TTL derived from time.Until(ClientSecretExpiresAt) so the row evicts before +// the upstream rejects the secret at the token endpoint. When zero (RFC 7591 +// "never"), Set with TTL=0 is used and the entry is long-lived. +// +// If ClientSecretExpiresAt is already in the past at call time, the entry is +// written with the bounded TTL pastExpiryDCRTTL (1 second) rather than rejected +// or stored long-lived. This keeps the store's "already-expired secret" window +// narrow even if the resolver never re-reads the row, and matches the +// fail-loud-but-tolerant posture: the caller's expiry timestamp round-trips so +// a downstream reader can still observe it and trigger re-registration. +// +// Returns fosite.ErrInvalidRequest for nil creds, empty Issuer, or empty +// RedirectURI — the same fail-loud contract as MemoryStorage.StoreDCRCredentials. +func (s *RedisStorage) StoreDCRCredentials(ctx context.Context, creds *DCRCredentials) error { + if creds == nil { + return fosite.ErrInvalidRequest.WithHint("dcr credentials cannot be nil") + } + if creds.Key.Issuer == "" { + return fosite.ErrInvalidRequest.WithHint("dcr credentials key issuer cannot be empty") + } + if creds.Key.RedirectURI == "" { + return fosite.ErrInvalidRequest.WithHint("dcr credentials key redirect_uri cannot be empty") + } + + key := redisDCRKey(s.keyPrefix, creds.Key) + + stored := storedDCRCredentials{ + KeyIssuer: creds.Key.Issuer, + KeyRedirectURI: creds.Key.RedirectURI, + KeyScopesHash: creds.Key.ScopesHash, + ProviderName: creds.ProviderName, + ClientID: creds.ClientID, + ClientSecret: creds.ClientSecret, + TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod, + RegistrationAccessToken: creds.RegistrationAccessToken, + RegistrationClientURI: creds.RegistrationClientURI, + AuthorizationEndpoint: creds.AuthorizationEndpoint, + TokenEndpoint: creds.TokenEndpoint, + } + if !creds.CreatedAt.IsZero() { + stored.CreatedAt = creds.CreatedAt.Unix() + } + if !creds.ClientSecretExpiresAt.IsZero() { + stored.ClientSecretExpiresAt = creds.ClientSecretExpiresAt.Unix() + } + + data, err := json.Marshal(stored) //nolint:gosec // G117 - internal Redis storage serialization, not exposed to users + if err != nil { + return fmt.Errorf("failed to marshal dcr credentials: %w", err) + } + + // Derive Redis TTL from ClientSecretExpiresAt: + // * Zero (unset) -> TTL=0 (no expiration) per RFC 7591 §3.2.1 "never". + // * Future expiry -> TTL = time.Until(expiry). + // * Past expiry -> TTL = pastExpiryDCRTTL (bounded eviction window). + // See the function docstring for the past-expiry rationale. + ttl := time.Duration(0) + if !creds.ClientSecretExpiresAt.IsZero() { + if until := time.Until(creds.ClientSecretExpiresAt); until > 0 { + ttl = until + } else { + ttl = pastExpiryDCRTTL + } + } + + if err := s.client.Set(ctx, key, data, ttl).Err(); err != nil { + return fmt.Errorf("failed to store dcr credentials: %w", err) + } + return nil +} + +// GetDCRCredentials retrieves the credentials previously persisted under key. +// Returns ErrNotFound (wrapped) when no entry exists. The returned value is a +// fresh struct decoded from JSON, which acts as a defensive copy. +// +// An unpopulated key (empty Issuer or empty RedirectURI) cannot match any +// stored row because StoreDCRCredentials rejects such keys, so a Get against +// one is a normal miss — ErrNotFound — matching MemoryStorage.GetDCRCredentials +// and the DCRCredentialStore interface contract. +func (s *RedisStorage) GetDCRCredentials(ctx context.Context, key DCRKey) (*DCRCredentials, error) { + redisKey := redisDCRKey(s.keyPrefix, key) + data, err := s.client.Get(ctx, redisKey).Bytes() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("%w: %w", ErrNotFound, fosite.ErrNotFound.WithHint("DCR credentials not found")) + } + return nil, fmt.Errorf("failed to get dcr credentials: %w", err) + } + + var stored storedDCRCredentials + if err := json.Unmarshal(data, &stored); err != nil { + return nil, fmt.Errorf("failed to unmarshal dcr credentials: %w", err) + } + + return stored.toDCRCredentials(), nil +} + // ----------------------- // Pending Authorization Storage // ----------------------- @@ -1862,4 +2043,5 @@ var ( _ ClientRegistry = (*RedisStorage)(nil) _ UpstreamTokenStorage = (*RedisStorage)(nil) _ UserStorage = (*RedisStorage)(nil) + _ DCRCredentialStore = (*RedisStorage)(nil) ) diff --git a/pkg/authserver/storage/redis_integration_test.go b/pkg/authserver/storage/redis_integration_test.go index 9552bc3fa4..c87e97984c 100644 --- a/pkg/authserver/storage/redis_integration_test.go +++ b/pkg/authserver/storage/redis_integration_test.go @@ -1413,3 +1413,179 @@ func TestIntegration_MigrateLegacyUpstreamData(t *testing.T) { }) }) } + +// --- DCR Credentials --- +// +// dcrFixtureKey is defined in redis_test.go (no build tag) and is therefore +// visible here under the `integration` build tag as well — sharing a single +// source of truth for the canonical DCR fixture key across unit and +// integration tests. + +func TestIntegration_DCRCredentials_RoundTrip(t *testing.T) { + t.Parallel() + + t.Run("store and get all fields", func(t *testing.T) { + withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { + createdAt := time.Now().Truncate(time.Second) + expiresAt := createdAt.Add(24 * time.Hour) + + creds := &DCRCredentials{ + Key: dcrFixtureKey(), + ProviderName: "atlassian", + ClientID: "client-int-abc", + ClientSecret: "secret-int-xyz", + TokenEndpointAuthMethod: "client_secret_basic", + RegistrationAccessToken: "rat-int-123", + RegistrationClientURI: "https://idp.example.com/register/client-int-abc", + AuthorizationEndpoint: "https://idp.example.com/authorize", + TokenEndpoint: "https://idp.example.com/token", + CreatedAt: createdAt, + ClientSecretExpiresAt: expiresAt, + } + + require.NoError(t, s.StoreDCRCredentials(ctx, creds)) + + got, err := s.GetDCRCredentials(ctx, creds.Key) + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, *creds, *got) + }) + }) + + t.Run("get non-existent", func(t *testing.T) { + withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { + _, err := s.GetDCRCredentials(ctx, dcrFixtureKey()) + requireRedisNotFoundError(t, err) + }) + }) +} + +func TestIntegration_DCRCredentials_DistinctKeysCoexist(t *testing.T) { + t.Parallel() + + withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { + mkKey := func(issuer, redirect string, scopes []string) DCRKey { + return DCRKey{Issuer: issuer, RedirectURI: redirect, ScopesHash: ScopesHash(scopes)} + } + entries := []*DCRCredentials{ + {Key: mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid"}), ClientID: "a"}, + {Key: mkKey("https://idp-b.example.com", "https://x/cb", []string{"openid"}), ClientID: "b"}, + {Key: mkKey("https://idp-a.example.com", "https://y/cb", []string{"openid"}), ClientID: "c"}, + {Key: mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid", "email"}), ClientID: "d"}, + } + for _, e := range entries { + require.NoError(t, s.StoreDCRCredentials(ctx, e)) + } + + for _, want := range entries { + got, err := s.GetDCRCredentials(ctx, want.Key) + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, want.ClientID, got.ClientID) + } + }) +} + +func TestIntegration_DCRCredentials_OverwriteSemantics(t *testing.T) { + t.Parallel() + + withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { + key := dcrFixtureKey() + + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{Key: key, ClientID: "first"})) + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{Key: key, ClientID: "second"})) + + got, err := s.GetDCRCredentials(ctx, key) + require.NoError(t, err) + assert.Equal(t, "second", got.ClientID) + }) +} + +// TestIntegration_DCRCredentials_TTL pins the RFC 7591 §3.2.1 TTL contract +// against a real Redis Sentinel cluster: TTL command observes the expected +// state for both the expiring and the never-expires cases. The unit-level +// miniredis test pins the in-process behaviour; this test pins the wire +// behaviour against real Redis where TTL returns -1 for "no TTL" and -2 for +// "key does not exist". +func TestIntegration_DCRCredentials_TTL(t *testing.T) { + t.Parallel() + + t.Run("non-zero expiry sets observable TTL", func(t *testing.T) { + withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { + key := dcrFixtureKey() + expires := time.Now().Add(24 * time.Hour).Truncate(time.Second) + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ + Key: key, + ClientID: "client-with-expiry", + ClientSecretExpiresAt: expires, + })) + + ttl, err := s.client.TTL(ctx, redisDCRKey(s.keyPrefix, key)).Result() + require.NoError(t, err) + assert.Greater(t, ttl, time.Duration(0), "TTL must be positive when ClientSecretExpiresAt is in the future") + // Allow 1 second of slack: the wall-clock value is truncated to + // second precision and Redis itself reports TTL with second + // granularity, so an exact 24h bound would be off by up to a + // second on busy CI without changing the underlying behaviour. + assert.LessOrEqual(t, ttl, 24*time.Hour+time.Second, "TTL must not exceed the configured expiry window") + }) + }) + + t.Run("zero expiry means no TTL", func(t *testing.T) { + withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { + key := dcrFixtureKey() + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ + Key: key, + ClientID: "client-no-expiry", + // ClientSecretExpiresAt deliberately zero. + })) + + ttl, err := s.client.TTL(ctx, redisDCRKey(s.keyPrefix, key)).Result() + require.NoError(t, err) + // go-redis maps Redis "TTL -1" (no expiry) to time.Duration(-1). + assert.Equal(t, time.Duration(-1), ttl, "real Redis must report -1 for a row stored without a TTL") + }) + }) +} + +// TestIntegration_DCRCredentials_ConcurrentAccess pins race-freedom against +// real Redis for concurrent Put/Get on the same set of keys. Run with -race +// to validate the data-race detector is clean. +func TestIntegration_DCRCredentials_ConcurrentAccess(t *testing.T) { + t.Parallel() + + withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { + const goroutines = 8 + const iterations = 16 + + var wg sync.WaitGroup + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + gid := g + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + key := DCRKey{ + Issuer: fmt.Sprintf("https://idp-%d.example.com", gid), + RedirectURI: "https://x/cb", + ScopesHash: ScopesHash([]string{"openid"}), + } + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ + Key: key, + ClientID: fmt.Sprintf("client-%d-%d", gid, i), + })) + _, err := s.GetDCRCredentials(ctx, key) + require.NoError(t, err) + } + }() + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(30 * time.Second): + t.Fatal("timeout waiting for concurrent DCR access goroutines") + } + }) +} diff --git a/pkg/authserver/storage/redis_keys.go b/pkg/authserver/storage/redis_keys.go index 32f5855a3e..38249775a6 100644 --- a/pkg/authserver/storage/redis_keys.go +++ b/pkg/authserver/storage/redis_keys.go @@ -58,6 +58,12 @@ const ( // KeyTypeUserProviders is the key type for user to provider identity reverse lookups. KeyTypeUserProviders = "user:providers" + + // KeyTypeDCR is the key type for RFC 7591 Dynamic Client Registration credentials + // persisted by an authserver upstream-DCR resolver. Distinct from KeyTypeClient, + // which holds the authserver's *own* OAuth clients — DCR entries are credentials + // that *this* authserver registered against an *upstream* authorization server. + KeyTypeDCR = "dcr" ) // DeriveKeyPrefix creates the key prefix from the Kubernetes namespace and MCP server name. @@ -88,6 +94,24 @@ func redisProviderKey(prefix, providerID, providerSubject string) string { return fmt.Sprintf("%s%s:%d:%s:%s", prefix, KeyTypeProvider, len(providerID), providerID, providerSubject) } +// redisDCRKey generates a Redis key for a DCR credential entry, identifying the +// (Issuer, RedirectURI, ScopesHash) tuple that DCRKey canonicalises. +// +// Format: "{prefix}dcr:::::" +// +// The first two segments are length-prefixed to handle colons in RedirectURI +// (and, for symmetry, Issuer) without ambiguity, mirroring redisProviderKey. +// ScopesHash is a SHA-256 hex digest produced by storage.ScopesHash; it +// contains only [0-9a-f] and never contains a colon, so it can be appended +// without a length prefix. +func redisDCRKey(prefix string, key DCRKey) string { + return fmt.Sprintf("%s%s:%d:%s:%d:%s:%s", + prefix, KeyTypeDCR, + len(key.Issuer), key.Issuer, + len(key.RedirectURI), key.RedirectURI, + key.ScopesHash) +} + // redisUpstreamKey generates a Redis key for a per-provider upstream token entry. // Format: "{prefix}upstream:{sessionID}:{providerName}" // This enables storing tokens from multiple upstream providers per session. diff --git a/pkg/authserver/storage/redis_test.go b/pkg/authserver/storage/redis_test.go index bb7a75a6bf..9b3b140aed 100644 --- a/pkg/authserver/storage/redis_test.go +++ b/pkg/authserver/storage/redis_test.go @@ -1949,6 +1949,95 @@ func TestRedisKeyGeneration(t *testing.T) { result := redisSetKey("test:auth:", KeyTypeReqIDAccess, "req-123") assert.Equal(t, "test:auth:reqid:access:req-123", result) }) + + t.Run("redisDCRKey", func(t *testing.T) { + t.Parallel() + result := redisDCRKey("test:auth:", DCRKey{ + Issuer: "https://thv.example.com", + RedirectURI: "https://thv.example.com/oauth/callback", + ScopesHash: "abc123", + }) + // 23 = len("https://thv.example.com"), 38 = len("https://thv.example.com/oauth/callback") + assert.Equal(t, + "test:auth:dcr:23:https://thv.example.com:38:https://thv.example.com/oauth/callback:abc123", + result) + }) +} + +// TestRedisDCRKey_Distinct pins the colon-safe lookup contract: any pair of +// distinct DCRKey tuples must serialise to distinct Redis keys, even when one +// component contains the literal substring of another. This is the property +// the length-prefixed encoding exists to guarantee — a plain +// fmt.Sprintf("%s:%s:%s", ...) form would collide for these inputs. +func TestRedisDCRKey_Distinct(t *testing.T) { + t.Parallel() + + mk := func(issuer, redirect, scopes string) DCRKey { + return DCRKey{Issuer: issuer, RedirectURI: redirect, ScopesHash: scopes} + } + + tests := []struct { + name string + a, b DCRKey + }{ + { + name: "different issuer", + a: mk("https://idp-a.example.com", "https://x/cb", "h1"), + b: mk("https://idp-b.example.com", "https://x/cb", "h1"), + }, + { + name: "different redirect_uri", + a: mk("https://idp.example.com", "https://x/cb", "h1"), + b: mk("https://idp.example.com", "https://y/cb", "h1"), + }, + { + name: "different scopes hash", + a: mk("https://idp.example.com", "https://x/cb", "h1"), + b: mk("https://idp.example.com", "https://x/cb", "h2"), + }, + { + // Without length prefixing, ("ab", "cd") and ("a", "bcd") would + // both yield ":ab:cd:" as the issuer/redirect segment after a + // fmt.Sprintf collapse. The length prefix prevents that. + name: "redirect_uri-issuer boundary collision (length-prefix property)", + a: mk("ab", "cd", "h1"), + b: mk("a", "bcd", "h1"), + }, + { + // RedirectURI legitimately contains colons (e.g. ":443"). A plain + // "%s:%s:%s" key would be ambiguous; the length prefix is not. + name: "colons inside redirect_uri", + a: mk("https://idp.example.com", "https://x.example.com:443/cb", "h1"), + b: mk("https://idp.example.com", "https://x.example.com/cb:443", "h1"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ka := redisDCRKey("test:auth:", tc.a) + kb := redisDCRKey("test:auth:", tc.b) + assert.NotEqual(t, ka, kb, "distinct DCRKey tuples must produce distinct Redis keys") + }) + } +} + +// TestRedisDCRKey_Deterministic pins that the key helper is a pure function: +// the same DCRKey produces the same Redis key on every call, with no hidden +// state (e.g. accidental use of map iteration order). +func TestRedisDCRKey_Deterministic(t *testing.T) { + t.Parallel() + + key := DCRKey{ + Issuer: "https://idp.example.com", + RedirectURI: "https://thv.example.com/oauth/callback", + ScopesHash: ScopesHash([]string{"openid", "profile", "email"}), + } + + first := redisDCRKey("test:auth:", key) + for i := 0; i < 4; i++ { + assert.Equal(t, first, redisDCRKey("test:auth:", key)) + } } // --- Health Check Tests --- @@ -2181,3 +2270,287 @@ func TestRedisStorage_GetLatestUpstreamTokensForUser(t *testing.T) { }) }) } + +// --- DCR Credentials Storage --- + +// dcrFixtureKey returns a populated DCRKey for use in DCR tests. +func dcrFixtureKey() DCRKey { + return DCRKey{ + Issuer: "https://thv.example.com", + RedirectURI: "https://thv.example.com/oauth/callback", + ScopesHash: ScopesHash([]string{"openid", "profile"}), + } +} + +func TestRedisStorage_DCRCredentials_RoundTrip(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { + // Truncate to second precision: time fields are stored as int64 unix seconds. + createdAt := time.Now().Truncate(time.Second) + expiresAt := createdAt.Add(24 * time.Hour) + + creds := &DCRCredentials{ + Key: dcrFixtureKey(), + ProviderName: "atlassian", + ClientID: "client-abc", + ClientSecret: "secret-xyz", + TokenEndpointAuthMethod: "client_secret_basic", + RegistrationAccessToken: "rat-123", + RegistrationClientURI: "https://idp.example.com/register/client-abc", + AuthorizationEndpoint: "https://idp.example.com/authorize", + TokenEndpoint: "https://idp.example.com/token", + CreatedAt: createdAt, + ClientSecretExpiresAt: expiresAt, + } + + require.NoError(t, s.StoreDCRCredentials(ctx, creds)) + + got, err := s.GetDCRCredentials(ctx, creds.Key) + require.NoError(t, err) + require.NotNil(t, got) + // Every field round-trips, including the embedded Key. + assert.Equal(t, *creds, *got) + }) +} + +func TestRedisStorage_DCRCredentials_OverwriteSemantics(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { + key := dcrFixtureKey() + + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{Key: key, ClientID: "first"})) + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{Key: key, ClientID: "second"})) + + got, err := s.GetDCRCredentials(ctx, key) + require.NoError(t, err) + assert.Equal(t, "second", got.ClientID) + }) +} + +func TestRedisStorage_DCRCredentials_NotFound(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + key DCRKey + }{ + { + name: "populated key with no stored entry", + key: dcrFixtureKey(), + }, + // Unpopulated keys cannot match any stored row (Store rejects them), + // so a Get against one is a normal miss — not a separate error class. + // This pins consistency with MemoryStorage.GetDCRCredentials. + { + name: "empty issuer", + key: DCRKey{Issuer: "", RedirectURI: "https://x/cb"}, + }, + { + name: "empty redirect_uri", + key: DCRKey{Issuer: "https://idp.example.com", RedirectURI: ""}, + }, + { + name: "fully empty key", + key: DCRKey{}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { + _, err := s.GetDCRCredentials(ctx, tc.key) + requireRedisNotFoundError(t, err) + }) + }) + } +} + +func TestRedisStorage_DCRCredentials_DistinctKeysCoexist(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { + mkKey := func(issuer, redirect string, scopes []string) DCRKey { + return DCRKey{Issuer: issuer, RedirectURI: redirect, ScopesHash: ScopesHash(scopes)} + } + entries := []*DCRCredentials{ + {Key: mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid"}), ClientID: "a"}, + {Key: mkKey("https://idp-b.example.com", "https://x/cb", []string{"openid"}), ClientID: "b"}, + {Key: mkKey("https://idp-a.example.com", "https://y/cb", []string{"openid"}), ClientID: "c"}, + {Key: mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid", "email"}), ClientID: "d"}, + } + for _, e := range entries { + require.NoError(t, s.StoreDCRCredentials(ctx, e)) + } + + for _, want := range entries { + got, err := s.GetDCRCredentials(ctx, want.Key) + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, want.ClientID, got.ClientID) + } + }) +} + +func TestRedisStorage_DCRCredentials_StoreInvalidInputRejected(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + creds *DCRCredentials + }{ + { + name: "nil creds", + creds: nil, + }, + { + name: "empty issuer", + creds: &DCRCredentials{ + Key: DCRKey{Issuer: "", RedirectURI: "https://x/cb"}, + }, + }, + { + name: "empty redirect_uri", + creds: &DCRCredentials{ + Key: DCRKey{Issuer: "https://idp.example.com", RedirectURI: ""}, + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, mr *miniredis.Miniredis) { + err := s.StoreDCRCredentials(ctx, tc.creds) + assert.ErrorIs(t, err, fosite.ErrInvalidRequest) + // Pin the fail-loud contract: a rejected Store must not leave + // any row behind, even under a partially-populated key. This + // mirrors MemoryStorage's `s.Stats().DCRCredentials == 0` guard + // (see TestMemoryStorage_DCRCredentials_StoreInvalidInputRejected). + assert.Empty(t, mr.Keys(), "rejected Store must not leave any DCR row behind") + }) + }) + } +} + +// TestRedisStorage_DCRCredentials_GetReturnsDefensiveCopy pins the +// defensive-copy contract: mutating a returned value must not be visible to +// subsequent reads. The Redis backend gets this for free from JSON +// deserialisation, but the test pins the contract so a future change (e.g. +// caching) cannot silently break it. +func TestRedisStorage_DCRCredentials_GetReturnsDefensiveCopy(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { + key := dcrFixtureKey() + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{Key: key, ClientID: "orig"})) + + got, err := s.GetDCRCredentials(ctx, key) + require.NoError(t, err) + got.ClientID = "mutated" + + refetched, err := s.GetDCRCredentials(ctx, key) + require.NoError(t, err) + assert.Equal(t, "orig", refetched.ClientID) + }) +} + +// TestRedisStorage_DCRCredentials_TTL pins the RFC 7591 §3.2.1 +// client_secret_expires_at semantics: +// - When ClientSecretExpiresAt is non-zero, the Redis row carries a TTL +// so it evicts before the upstream rejects the secret. +// - When ClientSecretExpiresAt is zero ("never"), the row is persistent +// (Redis TTL of -1). +// - When ClientSecretExpiresAt is in the past at write time, the row is +// written without a TTL (resolver re-checks expiry on read; see +// StoreDCRCredentials docstring). +func TestRedisStorage_DCRCredentials_TTL(t *testing.T) { + t.Parallel() + + t.Run("non-zero expiry sets a TTL", func(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, mr *miniredis.Miniredis) { + key := dcrFixtureKey() + expires := time.Now().Add(24 * time.Hour).Truncate(time.Second) + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ + Key: key, + ClientID: "client-with-expiry", + ClientSecretExpiresAt: expires, + })) + + ttl := mr.TTL(redisDCRKey("test:auth:", key)) + assert.Greater(t, ttl, time.Duration(0), "TTL should be positive when ClientSecretExpiresAt is in the future") + // Allow some slack for elapsed time between Set and TTL read. + assert.LessOrEqual(t, ttl, 24*time.Hour) + }) + }) + + t.Run("zero expiry means no TTL", func(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, mr *miniredis.Miniredis) { + key := dcrFixtureKey() + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ + Key: key, + ClientID: "client-no-expiry", + // ClientSecretExpiresAt deliberately zero. + })) + + // miniredis returns 0 (not -1) for "no TTL"; the integration test + // asserts the real Redis -1 behaviour separately. + assert.Equal(t, time.Duration(0), mr.TTL(redisDCRKey("test:auth:", key)), + "row should be persistent when ClientSecretExpiresAt is zero") + }) + }) + + t.Run("past expiry uses bounded TTL", func(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, mr *miniredis.Miniredis) { + key := dcrFixtureKey() + past := time.Now().Add(-time.Hour).Truncate(time.Second) + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ + Key: key, + ClientID: "client-past-expiry", + ClientSecretExpiresAt: past, + })) + + // Pin the bounded-TTL contract for past-expiry writes: + // the row exists immediately after the write (so a resolver that + // re-reads can observe the expiry timestamp and re-register), the + // stored ClientSecretExpiresAt round-trips, and the TTL is exactly + // pastExpiryDCRTTL — not 0 (which would persist forever) and not + // the negative time.Until() value. + got, err := s.GetDCRCredentials(ctx, key) + require.NoError(t, err) + assert.Equal(t, past.Unix(), got.ClientSecretExpiresAt.Unix()) + assert.Equal(t, pastExpiryDCRTTL, mr.TTL(redisDCRKey("test:auth:", key)), + "past-expiry write must use the bounded pastExpiryDCRTTL, not TTL=0") + }) + }) +} + +// TestRedisStorage_DCRCredentials_ConcurrentAccess pins the race-freedom of +// concurrent Put/Get under -race. The store is a single Set/Get per call so the +// real test is that running it concurrently doesn't trip the race detector. +func TestRedisStorage_DCRCredentials_ConcurrentAccess(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { + const goroutines = 8 + const iterations = 16 + + var wg sync.WaitGroup + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + gid := g + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + key := DCRKey{ + Issuer: fmt.Sprintf("https://idp-%d.example.com", gid), + RedirectURI: "https://x/cb", + ScopesHash: ScopesHash([]string{"openid"}), + } + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ + Key: key, + ClientID: fmt.Sprintf("client-%d-%d", gid, i), + })) + _, err := s.GetDCRCredentials(ctx, key) + require.NoError(t, err) + } + }() + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for concurrent DCR access goroutines") + } + }) +} From df0ad8fa078da556e0f1350a5d3b0e07082835fe Mon Sep 17 00:00:00 2001 From: Trey Date: Thu, 7 May 2026 08:58:35 -0700 Subject: [PATCH 2/3] Tighten Redis DCR backend validation parity with Memory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses stacklok/toolhive#5195 review comments: - HIGH redis.go (3202731047): extract validateDCRCredentialsForStore as a shared free function in types.go; both MemoryStorage and RedisStorage delegate to it so the rejection set cannot drift across backends. Extends TestRedisStorage_DCRCredentials_StoreInvalidInputRejected to cover all 7 cases mirrored from the Memory baseline. - HIGH redis.go (3202731049): rewrite the StoreDCRCredentials docstring to point at the shared validator now that the rejection set matches. - MEDIUM redis_test.go (3202731051): apply the suggested replacement on TestRedisStorage_DCRCredentials_TTL — the third bullet now describes the bounded pastExpiryDCRTTL behaviour. - MEDIUM redis_test.go + redis_integration_test.go (3202731061, 3202731063): add an overlapping-key concurrent variant alongside the existing disjoint variant via a shared runDCRConcurrentAccess helper, matching the Memory baseline rationale. - MEDIUM redis_test.go + redis_integration_test.go (3202731065, 3202731066): the shared helper reports goroutine errors via atomic.AddInt32 checked from the test goroutine after wg.Wait(), so testing.T.FailNow is never invoked from a non-test goroutine. - LOW redis_keys.go (3202731071): soften the redisDCRKey docstring's ScopesHash claim and reference the Store-side validator. - LOW redis.go (3202731074): update GetDCRCredentials docstring to enumerate all rejected key components and add an "empty scopes_hash" subtest to TestRedisStorage_DCRCredentials_NotFound. --- pkg/authserver/storage/memory.go | 34 +-- pkg/authserver/storage/redis.go | 22 +- .../storage/redis_integration_test.go | 86 +++--- pkg/authserver/storage/redis_keys.go | 10 +- pkg/authserver/storage/redis_test.go | 277 ++++++++++++++---- pkg/authserver/storage/types.go | 38 +++ 6 files changed, 323 insertions(+), 144 deletions(-) diff --git a/pkg/authserver/storage/memory.go b/pkg/authserver/storage/memory.go index 584c4d3a8a..0d0e9bf905 100644 --- a/pkg/authserver/storage/memory.go +++ b/pkg/authserver/storage/memory.go @@ -1229,37 +1229,11 @@ func cloneDCRCredentials(c *DCRCredentials) *DCRCredentials { // retained verbatim for callers to re-check on read (see the interface // docstring's "TTL handling" section). // -// Validation rejects nil creds, an unpopulated Key (empty Issuer, -// RedirectURI, or ScopesHash), and missing RFC 7591 mandatory response -// fields (ClientID, AuthorizationEndpoint, TokenEndpoint). An empty -// ScopesHash is rejected because the canonical digest of any scope set — -// including the empty-scope set via ScopesHash(nil) — is non-empty, so an -// empty string can only be a caller bug; accepting it would silently -// route a forgotten-hash record to a different cache slot than a sibling -// caller that did compute ScopesHash. ClientSecret is left permissive -// because RFC 7591 §2 public clients (auth method "none") legitimately -// register without a secret. +// Validation is delegated to validateDCRCredentialsForStore so the rejection +// set stays in sync with sibling backends. func (s *MemoryStorage) StoreDCRCredentials(_ context.Context, creds *DCRCredentials) error { - if creds == nil { - return fosite.ErrInvalidRequest.WithHint("dcr credentials cannot be nil") - } - if creds.Key.Issuer == "" { - return fosite.ErrInvalidRequest.WithHint("dcr credentials key issuer cannot be empty") - } - if creds.Key.RedirectURI == "" { - return fosite.ErrInvalidRequest.WithHint("dcr credentials key redirect_uri cannot be empty") - } - if creds.Key.ScopesHash == "" { - return fosite.ErrInvalidRequest.WithHint("dcr credentials key scopes_hash cannot be empty") - } - if creds.ClientID == "" { - return fosite.ErrInvalidRequest.WithHint("dcr credentials client_id cannot be empty") - } - if creds.AuthorizationEndpoint == "" { - return fosite.ErrInvalidRequest.WithHint("dcr credentials authorization_endpoint cannot be empty") - } - if creds.TokenEndpoint == "" { - return fosite.ErrInvalidRequest.WithHint("dcr credentials token_endpoint cannot be empty") + if err := validateDCRCredentialsForStore(creds); err != nil { + return err } s.mu.Lock() diff --git a/pkg/authserver/storage/redis.go b/pkg/authserver/storage/redis.go index b9a3e3ab17..76651109eb 100644 --- a/pkg/authserver/storage/redis.go +++ b/pkg/authserver/storage/redis.go @@ -1447,17 +1447,11 @@ func (s *storedDCRCredentials) toDCRCredentials() *DCRCredentials { // fail-loud-but-tolerant posture: the caller's expiry timestamp round-trips so // a downstream reader can still observe it and trigger re-registration. // -// Returns fosite.ErrInvalidRequest for nil creds, empty Issuer, or empty -// RedirectURI — the same fail-loud contract as MemoryStorage.StoreDCRCredentials. +// Validation is delegated to validateDCRCredentialsForStore so the rejection +// set stays in sync with MemoryStorage and any future backend. func (s *RedisStorage) StoreDCRCredentials(ctx context.Context, creds *DCRCredentials) error { - if creds == nil { - return fosite.ErrInvalidRequest.WithHint("dcr credentials cannot be nil") - } - if creds.Key.Issuer == "" { - return fosite.ErrInvalidRequest.WithHint("dcr credentials key issuer cannot be empty") - } - if creds.Key.RedirectURI == "" { - return fosite.ErrInvalidRequest.WithHint("dcr credentials key redirect_uri cannot be empty") + if err := validateDCRCredentialsForStore(creds); err != nil { + return err } key := redisDCRKey(s.keyPrefix, creds.Key) @@ -1511,10 +1505,10 @@ func (s *RedisStorage) StoreDCRCredentials(ctx context.Context, creds *DCRCreden // Returns ErrNotFound (wrapped) when no entry exists. The returned value is a // fresh struct decoded from JSON, which acts as a defensive copy. // -// An unpopulated key (empty Issuer or empty RedirectURI) cannot match any -// stored row because StoreDCRCredentials rejects such keys, so a Get against -// one is a normal miss — ErrNotFound — matching MemoryStorage.GetDCRCredentials -// and the DCRCredentialStore interface contract. +// An unpopulated key (empty Issuer, RedirectURI, or ScopesHash) cannot match +// any stored row because StoreDCRCredentials rejects such keys, so a Get +// against one is a normal miss — ErrNotFound — matching +// MemoryStorage.GetDCRCredentials and the DCRCredentialStore interface contract. func (s *RedisStorage) GetDCRCredentials(ctx context.Context, key DCRKey) (*DCRCredentials, error) { redisKey := redisDCRKey(s.keyPrefix, key) data, err := s.client.Get(ctx, redisKey).Bytes() diff --git a/pkg/authserver/storage/redis_integration_test.go b/pkg/authserver/storage/redis_integration_test.go index c87e97984c..9518564818 100644 --- a/pkg/authserver/storage/redis_integration_test.go +++ b/pkg/authserver/storage/redis_integration_test.go @@ -1467,11 +1467,19 @@ func TestIntegration_DCRCredentials_DistinctKeysCoexist(t *testing.T) { mkKey := func(issuer, redirect string, scopes []string) DCRKey { return DCRKey{Issuer: issuer, RedirectURI: redirect, ScopesHash: ScopesHash(scopes)} } + mk := func(key DCRKey, clientID string) *DCRCredentials { + return &DCRCredentials{ + Key: key, + ClientID: clientID, + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", + } + } entries := []*DCRCredentials{ - {Key: mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid"}), ClientID: "a"}, - {Key: mkKey("https://idp-b.example.com", "https://x/cb", []string{"openid"}), ClientID: "b"}, - {Key: mkKey("https://idp-a.example.com", "https://y/cb", []string{"openid"}), ClientID: "c"}, - {Key: mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid", "email"}), ClientID: "d"}, + mk(mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid"}), "a"), + mk(mkKey("https://idp-b.example.com", "https://x/cb", []string{"openid"}), "b"), + mk(mkKey("https://idp-a.example.com", "https://y/cb", []string{"openid"}), "c"), + mk(mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid", "email"}), "d"), } for _, e := range entries { require.NoError(t, s.StoreDCRCredentials(ctx, e)) @@ -1491,9 +1499,17 @@ func TestIntegration_DCRCredentials_OverwriteSemantics(t *testing.T) { withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { key := dcrFixtureKey() + mk := func(clientID string) *DCRCredentials { + return &DCRCredentials{ + Key: key, + ClientID: clientID, + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", + } + } - require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{Key: key, ClientID: "first"})) - require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{Key: key, ClientID: "second"})) + require.NoError(t, s.StoreDCRCredentials(ctx, mk("first"))) + require.NoError(t, s.StoreDCRCredentials(ctx, mk("second"))) got, err := s.GetDCRCredentials(ctx, key) require.NoError(t, err) @@ -1517,6 +1533,8 @@ func TestIntegration_DCRCredentials_TTL(t *testing.T) { require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ Key: key, ClientID: "client-with-expiry", + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", ClientSecretExpiresAt: expires, })) @@ -1535,8 +1553,10 @@ func TestIntegration_DCRCredentials_TTL(t *testing.T) { withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { key := dcrFixtureKey() require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ - Key: key, - ClientID: "client-no-expiry", + Key: key, + ClientID: "client-no-expiry", + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", // ClientSecretExpiresAt deliberately zero. })) @@ -1549,43 +1569,25 @@ func TestIntegration_DCRCredentials_TTL(t *testing.T) { } // TestIntegration_DCRCredentials_ConcurrentAccess pins race-freedom against -// real Redis for concurrent Put/Get on the same set of keys. Run with -race -// to validate the data-race detector is clean. +// real Redis. Mirrors the unit-test sibling (overlapping + disjoint +// keyspaces) using the shared runDCRConcurrentAccess helper from +// redis_test.go (visible across build tags), with a longer timeout to +// absorb real-Redis network latency. Run with -race to validate the +// data-race detector is clean. func TestIntegration_DCRCredentials_ConcurrentAccess(t *testing.T) { t.Parallel() - withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { - const goroutines = 8 - const iterations = 16 - - var wg sync.WaitGroup - wg.Add(goroutines) - for g := 0; g < goroutines; g++ { - gid := g - go func() { - defer wg.Done() - for i := 0; i < iterations; i++ { - key := DCRKey{ - Issuer: fmt.Sprintf("https://idp-%d.example.com", gid), - RedirectURI: "https://x/cb", - ScopesHash: ScopesHash([]string{"openid"}), - } - require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ - Key: key, - ClientID: fmt.Sprintf("client-%d-%d", gid, i), - })) - _, err := s.GetDCRCredentials(ctx, key) - require.NoError(t, err) - } - }() - } + t.Run("overlapping_key", func(t *testing.T) { + t.Parallel() + withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { + runDCRConcurrentAccess(ctx, t, s, dcrConcurrentOverlappingKey, 30*time.Second) + }) + }) - done := make(chan struct{}) - go func() { wg.Wait(); close(done) }() - select { - case <-done: - case <-time.After(30 * time.Second): - t.Fatal("timeout waiting for concurrent DCR access goroutines") - } + t.Run("disjoint_keys", func(t *testing.T) { + t.Parallel() + withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { + runDCRConcurrentAccess(ctx, t, s, dcrConcurrentDisjointKeys, 30*time.Second) + }) }) } diff --git a/pkg/authserver/storage/redis_keys.go b/pkg/authserver/storage/redis_keys.go index 38249775a6..b7f3a5a625 100644 --- a/pkg/authserver/storage/redis_keys.go +++ b/pkg/authserver/storage/redis_keys.go @@ -101,9 +101,13 @@ func redisProviderKey(prefix, providerID, providerSubject string) string { // // The first two segments are length-prefixed to handle colons in RedirectURI // (and, for symmetry, Issuer) without ambiguity, mirroring redisProviderKey. -// ScopesHash is a SHA-256 hex digest produced by storage.ScopesHash; it -// contains only [0-9a-f] and never contains a colon, so it can be appended -// without a length prefix. +// ScopesHash is expected to be a SHA-256 hex digest produced by +// storage.ScopesHash — only [0-9a-f] and never colon-bearing — so it is +// appended without a length prefix. The format is robust for that domain; +// validateDCRCredentialsForStore (called by every Store path) already +// rejects an empty ScopesHash, and callers are required to compute the hash +// via storage.ScopesHash. Length-prefix collision-safety is preserved on +// the leading segments either way. func redisDCRKey(prefix string, key DCRKey) string { return fmt.Sprintf("%s%s:%d:%s:%d:%s:%s", prefix, KeyTypeDCR, diff --git a/pkg/authserver/storage/redis_test.go b/pkg/authserver/storage/redis_test.go index 9b3b140aed..a48a6f6eff 100644 --- a/pkg/authserver/storage/redis_test.go +++ b/pkg/authserver/storage/redis_test.go @@ -12,6 +12,7 @@ import ( "fmt" "net/url" "sync" + "sync/atomic" "testing" "time" @@ -2315,9 +2316,17 @@ func TestRedisStorage_DCRCredentials_RoundTrip(t *testing.T) { func TestRedisStorage_DCRCredentials_OverwriteSemantics(t *testing.T) { withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { key := dcrFixtureKey() + mk := func(clientID string) *DCRCredentials { + return &DCRCredentials{ + Key: key, + ClientID: clientID, + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", + } + } - require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{Key: key, ClientID: "first"})) - require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{Key: key, ClientID: "second"})) + require.NoError(t, s.StoreDCRCredentials(ctx, mk("first"))) + require.NoError(t, s.StoreDCRCredentials(ctx, mk("second"))) got, err := s.GetDCRCredentials(ctx, key) require.NoError(t, err) @@ -2347,6 +2356,10 @@ func TestRedisStorage_DCRCredentials_NotFound(t *testing.T) { name: "empty redirect_uri", key: DCRKey{Issuer: "https://idp.example.com", RedirectURI: ""}, }, + { + name: "empty scopes_hash", + key: DCRKey{Issuer: "https://idp.example.com", RedirectURI: "https://x/cb", ScopesHash: ""}, + }, { name: "fully empty key", key: DCRKey{}, @@ -2367,11 +2380,19 @@ func TestRedisStorage_DCRCredentials_DistinctKeysCoexist(t *testing.T) { mkKey := func(issuer, redirect string, scopes []string) DCRKey { return DCRKey{Issuer: issuer, RedirectURI: redirect, ScopesHash: ScopesHash(scopes)} } + mk := func(key DCRKey, clientID string) *DCRCredentials { + return &DCRCredentials{ + Key: key, + ClientID: clientID, + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", + } + } entries := []*DCRCredentials{ - {Key: mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid"}), ClientID: "a"}, - {Key: mkKey("https://idp-b.example.com", "https://x/cb", []string{"openid"}), ClientID: "b"}, - {Key: mkKey("https://idp-a.example.com", "https://y/cb", []string{"openid"}), ClientID: "c"}, - {Key: mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid", "email"}), ClientID: "d"}, + mk(mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid"}), "a"), + mk(mkKey("https://idp-b.example.com", "https://x/cb", []string{"openid"}), "b"), + mk(mkKey("https://idp-a.example.com", "https://y/cb", []string{"openid"}), "c"), + mk(mkKey("https://idp-a.example.com", "https://x/cb", []string{"openid", "email"}), "d"), } for _, e := range entries { require.NoError(t, s.StoreDCRCredentials(ctx, e)) @@ -2386,34 +2407,84 @@ func TestRedisStorage_DCRCredentials_DistinctKeysCoexist(t *testing.T) { }) } +// TestRedisStorage_DCRCredentials_StoreInvalidInputRejected mirrors +// TestMemoryStorage_DCRCredentials_StoreInvalidInputRejected: every input +// rejected by validateDCRCredentialsForStore must produce +// fosite.ErrInvalidRequest and leave no row behind in Redis. func TestRedisStorage_DCRCredentials_StoreInvalidInputRejected(t *testing.T) { t.Parallel() + // validCreds returns a fully-populated DCRCredentials that subtests + // mutate to isolate a single missing field. Keeping every other field + // valid ensures the assertion proves which field was rejected. + validCreds := func() *DCRCredentials { + return &DCRCredentials{ + Key: DCRKey{ + Issuer: "https://idp.example.com", + RedirectURI: "https://x/cb", + ScopesHash: ScopesHash([]string{"openid"}), + }, + ClientID: "abc", + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", + } + } + tests := []struct { - name string - creds *DCRCredentials + name string + mutator func(*DCRCredentials) *DCRCredentials }{ { - name: "nil creds", - creds: nil, + name: "nil creds", + mutator: func(*DCRCredentials) *DCRCredentials { return nil }, }, { name: "empty issuer", - creds: &DCRCredentials{ - Key: DCRKey{Issuer: "", RedirectURI: "https://x/cb"}, + mutator: func(c *DCRCredentials) *DCRCredentials { + c.Key.Issuer = "" + return c }, }, { name: "empty redirect_uri", - creds: &DCRCredentials{ - Key: DCRKey{Issuer: "https://idp.example.com", RedirectURI: ""}, + mutator: func(c *DCRCredentials) *DCRCredentials { + c.Key.RedirectURI = "" + return c + }, + }, + { + name: "empty scopes_hash", + mutator: func(c *DCRCredentials) *DCRCredentials { + c.Key.ScopesHash = "" + return c + }, + }, + { + name: "empty client_id", + mutator: func(c *DCRCredentials) *DCRCredentials { + c.ClientID = "" + return c + }, + }, + { + name: "empty authorization_endpoint", + mutator: func(c *DCRCredentials) *DCRCredentials { + c.AuthorizationEndpoint = "" + return c + }, + }, + { + name: "empty token_endpoint", + mutator: func(c *DCRCredentials) *DCRCredentials { + c.TokenEndpoint = "" + return c }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { withRedisStorage(t, func(ctx context.Context, s *RedisStorage, mr *miniredis.Miniredis) { - err := s.StoreDCRCredentials(ctx, tc.creds) + err := s.StoreDCRCredentials(ctx, tc.mutator(validCreds())) assert.ErrorIs(t, err, fosite.ErrInvalidRequest) // Pin the fail-loud contract: a rejected Store must not leave // any row behind, even under a partially-populated key. This @@ -2433,7 +2504,12 @@ func TestRedisStorage_DCRCredentials_StoreInvalidInputRejected(t *testing.T) { func TestRedisStorage_DCRCredentials_GetReturnsDefensiveCopy(t *testing.T) { withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { key := dcrFixtureKey() - require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{Key: key, ClientID: "orig"})) + require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ + Key: key, + ClientID: "orig", + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", + })) got, err := s.GetDCRCredentials(ctx, key) require.NoError(t, err) @@ -2451,9 +2527,10 @@ func TestRedisStorage_DCRCredentials_GetReturnsDefensiveCopy(t *testing.T) { // so it evicts before the upstream rejects the secret. // - When ClientSecretExpiresAt is zero ("never"), the row is persistent // (Redis TTL of -1). -// - When ClientSecretExpiresAt is in the past at write time, the row is -// written without a TTL (resolver re-checks expiry on read; see -// StoreDCRCredentials docstring). +// - When ClientSecretExpiresAt is in the past at write time, the row +// is written with the bounded `pastExpiryDCRTTL` (1 second) so an +// already-expired secret self-evicts almost immediately rather than +// persisting forever (see StoreDCRCredentials docstring). func TestRedisStorage_DCRCredentials_TTL(t *testing.T) { t.Parallel() @@ -2464,6 +2541,8 @@ func TestRedisStorage_DCRCredentials_TTL(t *testing.T) { require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ Key: key, ClientID: "client-with-expiry", + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", ClientSecretExpiresAt: expires, })) @@ -2478,8 +2557,10 @@ func TestRedisStorage_DCRCredentials_TTL(t *testing.T) { withRedisStorage(t, func(ctx context.Context, s *RedisStorage, mr *miniredis.Miniredis) { key := dcrFixtureKey() require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ - Key: key, - ClientID: "client-no-expiry", + Key: key, + ClientID: "client-no-expiry", + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", // ClientSecretExpiresAt deliberately zero. })) @@ -2497,6 +2578,8 @@ func TestRedisStorage_DCRCredentials_TTL(t *testing.T) { require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ Key: key, ClientID: "client-past-expiry", + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", ClientSecretExpiresAt: past, })) @@ -2515,42 +2598,126 @@ func TestRedisStorage_DCRCredentials_TTL(t *testing.T) { }) } -// TestRedisStorage_DCRCredentials_ConcurrentAccess pins the race-freedom of -// concurrent Put/Get under -race. The store is a single Set/Get per call so the -// real test is that running it concurrently doesn't trip the race detector. +// TestRedisStorage_DCRCredentials_ConcurrentAccess pins race-freedom of +// concurrent Put/Get under -race. Mirrors the Memory baseline by exercising +// both an overlapping keyspace (every goroutine hammers the same key, so +// reads can observe any goroutine's last write) and a disjoint keyspace +// (per-goroutine key, so each goroutine's Get must always hit). With go +// test -race this catches a future change that drops the lock or returns +// an internal pointer instead of a defensive copy. +// +// Errors from spawned goroutines are reported via an atomic counter checked +// from the test goroutine after wg.Wait() — calling require.NoError / +// FailNow from a goroutine other than the one running the test function +// is undefined behaviour per the testing.T docs. func TestRedisStorage_DCRCredentials_ConcurrentAccess(t *testing.T) { - withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { - const goroutines = 8 - const iterations = 16 - - var wg sync.WaitGroup - wg.Add(goroutines) - for g := 0; g < goroutines; g++ { - gid := g - go func() { - defer wg.Done() - for i := 0; i < iterations; i++ { - key := DCRKey{ - Issuer: fmt.Sprintf("https://idp-%d.example.com", gid), - RedirectURI: "https://x/cb", - ScopesHash: ScopesHash([]string{"openid"}), - } - require.NoError(t, s.StoreDCRCredentials(ctx, &DCRCredentials{ - Key: key, - ClientID: fmt.Sprintf("client-%d-%d", gid, i), - })) - _, err := s.GetDCRCredentials(ctx, key) - require.NoError(t, err) - } - }() + t.Parallel() + + t.Run("overlapping_key", func(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { + runDCRConcurrentAccess(ctx, t, s, dcrConcurrentOverlappingKey, 10*time.Second) + }) + }) + + t.Run("disjoint_keys", func(t *testing.T) { + withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { + runDCRConcurrentAccess(ctx, t, s, dcrConcurrentDisjointKeys, 10*time.Second) + }) + }) +} + +// dcrConcurrentMode selects the keyspace strategy used by +// runDCRConcurrentAccess. Two strategies — overlapping (every goroutine +// writes/reads the same key) and disjoint (each goroutine has its own key) — +// mirror the Memory baseline rationale at +// TestMemoryStorage_DCRCredentials_ConcurrentAccess. +type dcrConcurrentMode int + +const ( + dcrConcurrentOverlappingKey dcrConcurrentMode = iota + dcrConcurrentDisjointKeys +) + +// runDCRConcurrentAccess fans out goroutines doing alternating +// StoreDCRCredentials / GetDCRCredentials and asserts no Store errored and, +// when the keyspace is disjoint, that every Get hit. Shared between the +// unit-test (miniredis) and integration-test (real Redis) suites — the +// integration suite passes a longer deadline. +func runDCRConcurrentAccess( + ctx context.Context, + t *testing.T, + s *RedisStorage, + mode dcrConcurrentMode, + deadline time.Duration, +) { + t.Helper() + + const ( + goroutines = 8 + iterations = 16 + ) + + keyFor := func(gid, _ int) DCRKey { + switch mode { + case dcrConcurrentOverlappingKey: + return dcrFixtureKey() + case dcrConcurrentDisjointKeys: + return DCRKey{ + Issuer: fmt.Sprintf("https://idp-%d.example.com", gid), + RedirectURI: "https://x/cb", + ScopesHash: ScopesHash([]string{"openid"}), + } } + t.Fatalf("unknown dcrConcurrentMode %d", mode) + return DCRKey{} + } - done := make(chan struct{}) - go func() { wg.Wait(); close(done) }() - select { - case <-done: - case <-time.After(10 * time.Second): - t.Fatal("timeout waiting for concurrent DCR access goroutines") + mkCreds := func(key DCRKey, gid, i int) *DCRCredentials { + return &DCRCredentials{ + Key: key, + ClientID: fmt.Sprintf("client-%d-%d", gid, i), + AuthorizationEndpoint: "https://idp.example.com/auth", + TokenEndpoint: "https://idp.example.com/token", } - }) + } + + var ( + storeErrCount int32 + getErrCount int32 + ) + var wg sync.WaitGroup + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + gid := g + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + key := keyFor(gid, i) + if err := s.StoreDCRCredentials(ctx, mkCreds(key, gid, i)); err != nil { + atomic.AddInt32(&storeErrCount, 1) + continue + } + if _, err := s.GetDCRCredentials(ctx, key); err != nil { + // In the disjoint keyspace, every goroutine just wrote its own + // key; a miss is a real error. In the overlapping keyspace, + // the immediate Get can race with another goroutine's + // rewrite-then-evict only if a TTL expires mid-test, which + // none of these credentials use, so a miss there is also an + // error to track. + atomic.AddInt32(&getErrCount, 1) + } + } + }() + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(deadline): + t.Fatal("timeout waiting for concurrent DCR access goroutines") + } + + assert.Zero(t, atomic.LoadInt32(&storeErrCount), "no concurrent Store should have errored") + assert.Zero(t, atomic.LoadInt32(&getErrCount), "no concurrent Get should have errored") } diff --git a/pkg/authserver/storage/types.go b/pkg/authserver/storage/types.go index 19231212f3..595a80101e 100644 --- a/pkg/authserver/storage/types.go +++ b/pkg/authserver/storage/types.go @@ -171,6 +171,44 @@ func ScopesHash(scopes []string) string { return hex.EncodeToString(h.Sum(nil)) } +// validateDCRCredentialsForStore enforces the rejection contract that every +// DCRCredentialStore implementation must apply before persisting. Extracting +// it as a free function called by every backend prevents the validation set +// from drifting across implementations: a record that fails loud against one +// backend cannot silently persist against another. +// +// Rejected inputs: nil creds, an unpopulated Key (empty Issuer, RedirectURI, +// or ScopesHash), and missing RFC 7591 mandatory response fields (ClientID, +// AuthorizationEndpoint, TokenEndpoint). An empty ScopesHash is rejected +// because the canonical digest of any scope set — including the empty-scope +// set via ScopesHash(nil) — is non-empty, so an empty string can only be a +// caller bug. ClientSecret is left permissive because RFC 7591 §2 public +// clients (auth method "none") legitimately register without a secret. +func validateDCRCredentialsForStore(creds *DCRCredentials) error { + if creds == nil { + return fosite.ErrInvalidRequest.WithHint("dcr credentials cannot be nil") + } + if creds.Key.Issuer == "" { + return fosite.ErrInvalidRequest.WithHint("dcr credentials key issuer cannot be empty") + } + if creds.Key.RedirectURI == "" { + return fosite.ErrInvalidRequest.WithHint("dcr credentials key redirect_uri cannot be empty") + } + if creds.Key.ScopesHash == "" { + return fosite.ErrInvalidRequest.WithHint("dcr credentials key scopes_hash cannot be empty") + } + if creds.ClientID == "" { + return fosite.ErrInvalidRequest.WithHint("dcr credentials client_id cannot be empty") + } + if creds.AuthorizationEndpoint == "" { + return fosite.ErrInvalidRequest.WithHint("dcr credentials authorization_endpoint cannot be empty") + } + if creds.TokenEndpoint == "" { + return fosite.ErrInvalidRequest.WithHint("dcr credentials token_endpoint cannot be empty") + } + return nil +} + // DCRCredentials is the persisted form of an RFC 7591 Dynamic Client // Registration result. All fields are populated from the upstream's DCR // response. The RFC 7592 management fields (RegistrationAccessToken, From 670ea89d3583c2a606d1cecdd7e80080380dd3e8 Mon Sep 17 00:00:00 2001 From: Trey Date: Fri, 8 May 2026 08:01:59 -0700 Subject: [PATCH 3/3] Fix DCR integration parallel-test panic and stale type docstrings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses stacklok/toolhive#5195 review comments: - HIGH redis_integration_test.go (3204538839, 3204538845, 3204538852, 3204538858): drop the explicit t.Parallel() at the outer DistinctKeysCoexist / OverwriteSemantics tests and inside the ConcurrentAccess overlapping_key / disjoint_keys subtests. Each one was followed directly by withIntegrationStorage(t, ...), whose own t.Parallel() panicked with "t.Parallel called multiple times" on the same *testing.T. - MEDIUM types.go (review body): drop the "future sub-issue" framing from the DCRCredentials Lifetime section, the ClientSecretExpiresAt field doc, and the DCRCredentialStore interface doc now that RedisStorage exists. Also correct the "Redis SetEX" wording to "SET with a duration" in three places — go-redis emits SET ... EX/PX for a duration-bearing Set, never SETEX. --- .../storage/redis_integration_test.go | 6 ------ pkg/authserver/storage/types.go | 21 ++++++++++--------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/pkg/authserver/storage/redis_integration_test.go b/pkg/authserver/storage/redis_integration_test.go index 9518564818..5cd08f5e19 100644 --- a/pkg/authserver/storage/redis_integration_test.go +++ b/pkg/authserver/storage/redis_integration_test.go @@ -1461,8 +1461,6 @@ func TestIntegration_DCRCredentials_RoundTrip(t *testing.T) { } func TestIntegration_DCRCredentials_DistinctKeysCoexist(t *testing.T) { - t.Parallel() - withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { mkKey := func(issuer, redirect string, scopes []string) DCRKey { return DCRKey{Issuer: issuer, RedirectURI: redirect, ScopesHash: ScopesHash(scopes)} @@ -1495,8 +1493,6 @@ func TestIntegration_DCRCredentials_DistinctKeysCoexist(t *testing.T) { } func TestIntegration_DCRCredentials_OverwriteSemantics(t *testing.T) { - t.Parallel() - withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { key := dcrFixtureKey() mk := func(clientID string) *DCRCredentials { @@ -1578,14 +1574,12 @@ func TestIntegration_DCRCredentials_ConcurrentAccess(t *testing.T) { t.Parallel() t.Run("overlapping_key", func(t *testing.T) { - t.Parallel() withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { runDCRConcurrentAccess(ctx, t, s, dcrConcurrentOverlappingKey, 30*time.Second) }) }) t.Run("disjoint_keys", func(t *testing.T) { - t.Parallel() withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { runDCRConcurrentAccess(ctx, t, s, dcrConcurrentDisjointKeys, 30*time.Second) }) diff --git a/pkg/authserver/storage/types.go b/pkg/authserver/storage/types.go index 595a80101e..1e70da53d2 100644 --- a/pkg/authserver/storage/types.go +++ b/pkg/authserver/storage/types.go @@ -227,8 +227,8 @@ func validateDCRCredentialsForStore(creds *DCRCredentials) error { // Entries are long-lived — RFC 7591 client registrations do not expire unless // the upstream asserts client_secret_expires_at. The in-memory backend // retains entries for the process lifetime and is intentionally excluded from -// the periodic cleanup loop. The Redis backend (future sub-issue) applies -// TTL via SetEX when ClientSecretExpiresAt is non-zero. +// the periodic cleanup loop. The Redis backend applies TTL via SET with a +// duration when ClientSecretExpiresAt is non-zero. type DCRCredentials struct { // Key is the canonical cache key: (Issuer, RedirectURI, ScopesHash). Key DCRKey @@ -263,8 +263,8 @@ type DCRCredentials struct { // rather than special-casing 0. // // When non-zero, this is the authoritative signal a backend uses to TTL - // the persisted entry: the Redis backend (sub-issue 2) plumbs it through - // SetEX so the row evicts before the upstream rejects the secret at the + // the persisted entry: the Redis backend plumbs it through SET with a + // duration so the row evicts before the upstream rejects the secret at the // token endpoint. The in-memory backend ignores this field — entries // persist for the process lifetime and the resolver re-checks the // expiry on read. @@ -272,8 +272,8 @@ type DCRCredentials struct { } // DCRCredentialStore is a narrow, segregated interface for persisting -// dynamic-client-registration credentials. Both MemoryStorage and a future -// Redis-backed store implement it; an authserver backed by Redis shares DCR +// dynamic-client-registration credentials. Both MemoryStorage and +// RedisStorage implement it; an authserver backed by Redis shares DCR // credentials across replicas and restarts. // // # Cross-replica limitation @@ -292,10 +292,11 @@ type DCRCredentials struct { // // Implementations SHOULD honor a non-zero DCRCredentials.ClientSecretExpiresAt // as a backend-level TTL when the underlying store supports one (e.g. Redis -// SetEX) so an entry evicts before the upstream rejects the secret at the -// token endpoint. Backends without a native TTL (e.g. the in-memory backend) -// retain the field verbatim and rely on the caller — typically the runner's -// resolver — to re-check expiry on read; see MemoryStorage.GetDCRCredentials. +// SET with a duration) so an entry evicts before the upstream rejects the +// secret at the token endpoint. Backends without a native TTL (e.g. the +// in-memory backend) retain the field verbatim and rely on the caller — +// typically the runner's resolver — to re-check expiry on read; see +// MemoryStorage.GetDCRCredentials. // A zero ClientSecretExpiresAt means the upstream did not assert an expiry // and no TTL is applied. //