From ab9807b80981d324811b2554a96a96f5b3dbc3e1 Mon Sep 17 00:00:00 2001 From: Aron Gates Date: Fri, 6 Mar 2026 14:07:43 +0000 Subject: [PATCH 1/7] Refresh upstream tokens transparently instead of forcing re-auth MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The upstreamswap middleware returned 401 when upstream access tokens expired, forcing users through full re-authentication even though valid refresh tokens existed in storage. This happened because: 1. Redis/memory storage TTL was set to access token expiry, deleting the entry (and refresh token) when the access token expired 2. Storage returned nil on ErrExpired, discarding the refresh token 3. The middleware had no refresh path — only 401 Fix all three layers: - Add DefaultRefreshTokenTTL (30 days) to storage entry TTL so refresh tokens survive past access token expiry - Return token data alongside ErrExpired from storage so callers can use the refresh token - Add UpstreamTokenRefresher interface and implementation that wraps the upstream OAuth2Provider and storage - Plumb the refresher through Server → EmbeddedAuthServer → Runner → MiddlewareRunner - Update upstreamswap middleware to attempt refresh before returning 401, only requiring re-auth when the refresh token itself fails Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/auth/upstreamswap/middleware.go | 102 +++++++++++++++---- pkg/auth/upstreamswap/middleware_test.go | 25 +++-- pkg/authserver/refresher.go | 83 +++++++++++++++ pkg/authserver/runner/embeddedauthserver.go | 6 ++ pkg/authserver/server.go | 5 + pkg/authserver/server_impl.go | 12 +++ pkg/authserver/storage/memory.go | 26 +++-- pkg/authserver/storage/memory_test.go | 14 ++- pkg/authserver/storage/mocks/mock_storage.go | 39 +++++++ pkg/authserver/storage/redis.go | 24 +++-- pkg/authserver/storage/redis_test.go | 15 ++- pkg/authserver/storage/types.go | 14 ++- pkg/runner/runner.go | 12 +++ pkg/transport/types/mocks/mock_transport.go | 14 +++ pkg/transport/types/transport.go | 5 + 15 files changed, 338 insertions(+), 58 deletions(-) create mode 100644 pkg/authserver/refresher.go diff --git a/pkg/auth/upstreamswap/middleware.go b/pkg/auth/upstreamswap/middleware.go index 5af02e099c..1bf80208ef 100644 --- a/pkg/auth/upstreamswap/middleware.go +++ b/pkg/auth/upstreamswap/middleware.go @@ -6,6 +6,7 @@ package upstreamswap import ( + "context" "encoding/json" "errors" "fmt" @@ -48,6 +49,10 @@ type MiddlewareParams struct { // This allows lazy access to the storage, which may not be available at middleware creation time. type StorageGetter func() storage.UpstreamTokenStorage +// RefresherGetter is a function that returns an upstream token refresher. +// This allows lazy access to the refresher, which may not be available at middleware creation time. +type RefresherGetter func() storage.UpstreamTokenRefresher + // Middleware wraps the upstream swap middleware functionality. type Middleware struct { middleware types.MiddlewareFunction @@ -81,12 +86,13 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun return fmt.Errorf("invalid upstream swap configuration: %w", err) } - // Get storage getter from runner. - // The storage getter is a lazy accessor that checks storage availability at request time, - // so it's always non-nil. Actual storage availability is verified when processing requests. + // Get storage getter and refresher getter from runner. + // These are lazy accessors that check availability at request time, + // so they're always non-nil. Actual availability is verified when processing requests. storageGetter := runner.GetUpstreamTokenStorage() + refresherGetter := runner.GetUpstreamTokenRefresher() - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter) upstreamSwapMw := &Middleware{ middleware: middleware, @@ -141,7 +147,7 @@ func createCustomInjector(headerName string) injectionFunc { } // createMiddlewareFunc creates the actual middleware function. -func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.MiddlewareFunction { +func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter, refresherGetter RefresherGetter) types.MiddlewareFunction { // Determine injection strategy at startup time strategy := cfg.HeaderStrategy if strategy == "" { @@ -188,13 +194,13 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle return } - // 4. Lookup upstream tokens - tokens, err := stor.GetUpstreamTokens(r.Context(), tsid) + // 4. Lookup upstream tokens, refreshing if expired + tokens, err := getOrRefreshUpstreamTokens(r.Context(), stor, tsid, refresherGetter) if err != nil { slog.Warn("Failed to get upstream tokens", "middleware", "upstreamswap", "error", err) - // Token is expired, was not found, or failed binding validation - // (e.g., subject/client mismatch). All three are client-attributable + // Token is expired (refresh failed), was not found, or failed binding + // validation (e.g., subject/client mismatch). All three are client-attributable // errors that require the caller to re-authenticate with the upstream IdP. if errors.Is(err, storage.ErrExpired) || errors.Is(err, storage.ErrNotFound) || @@ -207,16 +213,6 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle return } - // 5. Check if expired - // Defense in depth: some storage implementations may return tokens - // without checking expiry (the interface does not require it). - if tokens.IsExpired(time.Now()) { - slog.Warn("Upstream tokens expired", - "middleware", "upstreamswap") - writeUpstreamAuthRequired(w) - return - } - // 6. Inject access token if tokens.AccessToken == "" { slog.Warn("Access token is empty", @@ -233,3 +229,71 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle }) } } + +// getOrRefreshUpstreamTokens retrieves upstream tokens from storage, automatically +// refreshing them if expired and a refresh token is available. +func getOrRefreshUpstreamTokens( + ctx context.Context, + stor storage.UpstreamTokenStorage, + sessionID string, + refresherGetter RefresherGetter, +) (*storage.UpstreamTokens, error) { + tokens, err := stor.GetUpstreamTokens(ctx, sessionID) + if err != nil { + // ErrExpired returns tokens (including refresh token) alongside the error. + // Attempt a refresh before giving up. + if errors.Is(err, storage.ErrExpired) && tokens != nil { + if refreshed := tryRefreshUpstreamTokens(ctx, sessionID, tokens, refresherGetter); refreshed != nil { + return refreshed, nil + } + } + return nil, err + } + + // Defense in depth: some storage implementations may return tokens + // without checking expiry (the interface does not require it). + if tokens.IsExpired(time.Now()) { + if refreshed := tryRefreshUpstreamTokens(ctx, sessionID, tokens, refresherGetter); refreshed != nil { + return refreshed, nil + } + return nil, storage.ErrExpired + } + + return tokens, nil +} + +// tryRefreshUpstreamTokens attempts to refresh expired upstream tokens using the +// configured refresher. Returns the refreshed tokens on success, or nil on failure. +func tryRefreshUpstreamTokens( + ctx context.Context, + sessionID string, + expired *storage.UpstreamTokens, + refresherGetter RefresherGetter, +) *storage.UpstreamTokens { + if expired.RefreshToken == "" { + slog.Debug("No refresh token available, cannot refresh upstream tokens", + "middleware", "upstreamswap") + return nil + } + + if refresherGetter == nil { + return nil + } + refresher := refresherGetter() + if refresher == nil { + slog.Debug("Token refresher unavailable, cannot refresh upstream tokens", + "middleware", "upstreamswap") + return nil + } + + refreshed, err := refresher.RefreshAndStore(ctx, sessionID, expired) + if err != nil { + slog.Warn("Upstream token refresh failed", + "middleware", "upstreamswap", "error", err) + return nil + } + + slog.Debug("Successfully refreshed upstream tokens", + "middleware", "upstreamswap") + return refreshed +} diff --git a/pkg/auth/upstreamswap/middleware_test.go b/pkg/auth/upstreamswap/middleware_test.go index 85f66c31c4..b48bf775f8 100644 --- a/pkg/auth/upstreamswap/middleware_test.go +++ b/pkg/auth/upstreamswap/middleware_test.go @@ -99,7 +99,7 @@ func TestMiddleware_NoIdentity(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { @@ -131,7 +131,7 @@ func TestMiddleware_NoTsidClaim(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { @@ -166,7 +166,7 @@ func TestMiddleware_StorageUnavailable(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { @@ -222,7 +222,7 @@ func TestMiddleware_ClientAttributableStorageErrors_Returns401(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { @@ -268,7 +268,7 @@ func TestMiddleware_StorageError(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { @@ -319,7 +319,7 @@ func TestMiddleware_SuccessfulSwap_AccessToken(t *testing.T) { cfg := &Config{ HeaderStrategy: HeaderStrategyReplace, } - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var capturedAuthHeader string nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { @@ -369,7 +369,7 @@ func TestMiddleware_CustomHeader(t *testing.T) { HeaderStrategy: HeaderStrategyCustom, CustomHeaderName: "X-Upstream-Token", } - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var capturedCustomHeader string var capturedAuthHeader string @@ -422,7 +422,7 @@ func TestMiddleware_ExpiredTokens_Returns401(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { @@ -473,7 +473,7 @@ func TestMiddleware_EmptySelectedToken(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool var capturedAuthHeader string @@ -563,7 +563,7 @@ func TestMiddlewareWithContext(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) // Test that context is properly passed through var receivedCtx context.Context @@ -677,6 +677,9 @@ func TestCreateMiddleware(t *testing.T) { mockRunner.EXPECT().GetUpstreamTokenStorage().Return(func() storage.UpstreamTokenStorage { return nil // Storage availability is checked at request time }) + mockRunner.EXPECT().GetUpstreamTokenRefresher().Return(func() storage.UpstreamTokenRefresher { + return nil // Refresher availability is checked at request time + }) mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()).Do(func(_ string, mw types.Middleware) { _, ok := mw.(*Middleware) assert.True(t, ok, "Expected middleware to be of type *upstreamswap.Middleware") @@ -738,7 +741,7 @@ func TestMiddleware_TsidClaimWrongType(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { diff --git a/pkg/authserver/refresher.go b/pkg/authserver/refresher.go new file mode 100644 index 0000000000..b4093fcbe0 --- /dev/null +++ b/pkg/authserver/refresher.go @@ -0,0 +1,83 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package authserver + +import ( + "context" + "errors" + "fmt" + "log/slog" + + "github.com/stacklok/toolhive/pkg/authserver/storage" + "github.com/stacklok/toolhive/pkg/authserver/upstream" +) + +// upstreamTokenRefresher implements storage.UpstreamTokenRefresher by wrapping +// an upstream OAuth2Provider (for token refresh) and UpstreamTokenStorage (for +// persisting the refreshed tokens). +type upstreamTokenRefresher struct { + provider upstream.OAuth2Provider + storage storage.UpstreamTokenStorage +} + +// RefreshAndStore refreshes expired upstream tokens using the stored refresh token, +// persists the new tokens, and returns them. +func (r *upstreamTokenRefresher) RefreshAndStore( + ctx context.Context, + sessionID string, + expired *storage.UpstreamTokens, +) (*storage.UpstreamTokens, error) { + if expired == nil { + return nil, errors.New("expired tokens are required") + } + if expired.RefreshToken == "" { + return nil, errors.New("no refresh token available for upstream token refresh") + } + + slog.Debug("attempting upstream token refresh", + "session_id", sessionID, + "provider_id", expired.ProviderID, + ) + + // Refresh tokens via the upstream provider + newTokens, err := r.provider.RefreshTokens(ctx, expired.RefreshToken, expired.UpstreamSubject) + if err != nil { + return nil, fmt.Errorf("upstream token refresh failed: %w", err) + } + + // Build updated storage tokens preserving binding fields from the original + updated := &storage.UpstreamTokens{ + ProviderID: expired.ProviderID, + AccessToken: newTokens.AccessToken, + RefreshToken: newTokens.RefreshToken, + IDToken: newTokens.IDToken, + ExpiresAt: newTokens.ExpiresAt, + UserID: expired.UserID, + UpstreamSubject: expired.UpstreamSubject, + ClientID: expired.ClientID, + } + + // If the provider didn't rotate the refresh token, keep the original + if updated.RefreshToken == "" { + updated.RefreshToken = expired.RefreshToken + } + + // Store the refreshed tokens + if err := r.storage.StoreUpstreamTokens(ctx, sessionID, updated); err != nil { + // Log but still return the refreshed tokens — the current request can + // proceed even if storage fails. The next request will retry the refresh. + slog.Warn("failed to store refreshed upstream tokens", + "session_id", sessionID, + "error", err, + ) + return updated, nil + } + + slog.Debug("upstream tokens refreshed successfully", + "session_id", sessionID, + "provider_id", expired.ProviderID, + ) + + return updated, nil +} diff --git a/pkg/authserver/runner/embeddedauthserver.go b/pkg/authserver/runner/embeddedauthserver.go index 91abdc82cd..0643481c4e 100644 --- a/pkg/authserver/runner/embeddedauthserver.go +++ b/pkg/authserver/runner/embeddedauthserver.go @@ -135,6 +135,12 @@ func (e *EmbeddedAuthServer) IDPTokenStorage() storage.UpstreamTokenStorage { return e.server.IDPTokenStorage() } +// UpstreamTokenRefresher returns a refresher that can refresh expired upstream +// tokens using the upstream provider's refresh token grant. +func (e *EmbeddedAuthServer) UpstreamTokenRefresher() storage.UpstreamTokenRefresher { + return e.server.UpstreamTokenRefresher() +} + // createKeyProvider creates a KeyProvider from SigningKeyRunConfig. // Returns a GeneratingProvider if config is nil or empty (development mode). func createKeyProvider(cfg *authserver.SigningKeyRunConfig) (keys.KeyProvider, error) { diff --git a/pkg/authserver/server.go b/pkg/authserver/server.go index 1c8286ab6a..aef969fec8 100644 --- a/pkg/authserver/server.go +++ b/pkg/authserver/server.go @@ -31,6 +31,11 @@ type Server interface { // Returns nil if no upstream IDP is configured. IDPTokenStorage() storage.UpstreamTokenStorage + // UpstreamTokenRefresher returns a refresher that can refresh expired upstream + // tokens using the upstream provider's refresh token grant. + // Returns nil if no upstream IDP is configured. + UpstreamTokenRefresher() storage.UpstreamTokenRefresher + // Close releases resources held by the server. Close() error } diff --git a/pkg/authserver/server_impl.go b/pkg/authserver/server_impl.go index 92276e05e6..1438d9c71b 100644 --- a/pkg/authserver/server_impl.go +++ b/pkg/authserver/server_impl.go @@ -161,6 +161,18 @@ func (s *server) IDPTokenStorage() storage.UpstreamTokenStorage { return s.storage } +// UpstreamTokenRefresher returns a refresher that wraps the upstream provider +// and storage to transparently refresh expired upstream tokens. +func (s *server) UpstreamTokenRefresher() storage.UpstreamTokenRefresher { + if s.upstreamIDP == nil { + return nil + } + return &upstreamTokenRefresher{ + provider: s.upstreamIDP, + storage: s.storage, + } +} + // Close releases resources held by the server. func (s *server) Close() error { slog.Debug("closing OAuth authorization server") diff --git a/pkg/authserver/storage/memory.go b/pkg/authserver/storage/memory.go index 29402eb069..3423d0e81a 100644 --- a/pkg/authserver/storage/memory.go +++ b/pkg/authserver/storage/memory.go @@ -693,11 +693,13 @@ func (s *MemoryStorage) StoreUpstreamTokens(_ context.Context, sessionID string, defer s.mu.Unlock() now := time.Now() + // Add DefaultRefreshTokenTTL beyond access token expiry so the refresh token + // survives in storage for transparent token refresh by the middleware. var expiresAt time.Time if tokens != nil && !tokens.ExpiresAt.IsZero() { - expiresAt = tokens.ExpiresAt + expiresAt = tokens.ExpiresAt.Add(DefaultRefreshTokenTTL) } else { - expiresAt = now.Add(DefaultAccessTokenTTL) + expiresAt = now.Add(DefaultAccessTokenTTL + DefaultRefreshTokenTTL) } // Make a defensive copy to prevent aliasing issues @@ -735,18 +737,12 @@ func (s *MemoryStorage) GetUpstreamTokens(_ context.Context, sessionID string) ( return nil, fmt.Errorf("%w: %w", ErrNotFound, fosite.ErrNotFound.WithHint("Upstream tokens not found")) } - // Check if expired - if time.Now().After(entry.expiresAt) { - slog.Debug("upstream tokens expired", "session_id", sessionID) - return nil, ErrExpired - } - // Return a defensive copy to prevent aliasing issues tokens := entry.value if tokens == nil { return nil, nil } - return &UpstreamTokens{ + result := &UpstreamTokens{ ProviderID: tokens.ProviderID, AccessToken: tokens.AccessToken, RefreshToken: tokens.RefreshToken, @@ -755,7 +751,17 @@ func (s *MemoryStorage) GetUpstreamTokens(_ context.Context, sessionID string) ( UserID: tokens.UserID, UpstreamSubject: tokens.UpstreamSubject, ClientID: tokens.ClientID, - }, nil + } + + // Check the token's own ExpiresAt (access token expiry), not the entry's expiresAt + // (storage TTL which includes DefaultRefreshTokenTTL buffer for refresh token survival). + // Return tokens along with ErrExpired so callers can use the refresh token. + if !result.ExpiresAt.IsZero() && time.Now().After(result.ExpiresAt) { + slog.Debug("upstream tokens expired", "session_id", sessionID) + return result, ErrExpired + } + + return result, nil } // DeleteUpstreamTokens removes the upstream IDP tokens for a session. diff --git a/pkg/authserver/storage/memory_test.go b/pkg/authserver/storage/memory_test.go index 5546931a14..85cc839f33 100644 --- a/pkg/authserver/storage/memory_test.go +++ b/pkg/authserver/storage/memory_test.go @@ -473,17 +473,22 @@ func TestMemoryStorage_UpstreamTokens(t *testing.T) { }) }) - t.Run("get expired tokens returns ErrExpired", func(t *testing.T) { + t.Run("get expired tokens returns ErrExpired with token data", func(t *testing.T) { withStorage(t, func(ctx context.Context, s *MemoryStorage) { require.NoError(t, s.StoreUpstreamTokens(ctx, "expired", &UpstreamTokens{ - AccessToken: "expired-token", ExpiresAt: time.Now().Add(-time.Hour), + AccessToken: "expired-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Hour), })) assert.Equal(t, 1, s.Stats().UpstreamTokens) retrieved, err := s.GetUpstreamTokens(ctx, "expired") require.Error(t, err) assert.ErrorIs(t, err, ErrExpired) - assert.Nil(t, retrieved) + // Tokens should be returned alongside ErrExpired for refresh purposes + require.NotNil(t, retrieved) + assert.Equal(t, "expired-token", retrieved.AccessToken) + assert.Equal(t, "refresh-token", retrieved.RefreshToken) }) }) } @@ -637,7 +642,8 @@ func TestMemoryStorage_CleanupExpired(t *testing.T) { { name: "upstream tokens", setup: func(ctx context.Context, s *MemoryStorage) { - _ = s.StoreUpstreamTokens(ctx, "expired", &UpstreamTokens{AccessToken: "exp", ExpiresAt: time.Now().Add(-time.Hour)}) + // Entry must be older than DefaultRefreshTokenTTL past access token expiry to be cleaned up + _ = s.StoreUpstreamTokens(ctx, "expired", &UpstreamTokens{AccessToken: "exp", ExpiresAt: time.Now().Add(-DefaultRefreshTokenTTL - time.Hour)}) _ = s.StoreUpstreamTokens(ctx, "valid", &UpstreamTokens{AccessToken: "val", ExpiresAt: time.Now().Add(time.Hour)}) }, getStats: func(st Stats) int { return st.UpstreamTokens }, diff --git a/pkg/authserver/storage/mocks/mock_storage.go b/pkg/authserver/storage/mocks/mock_storage.go index 42d41949f8..278a1be587 100644 --- a/pkg/authserver/storage/mocks/mock_storage.go +++ b/pkg/authserver/storage/mocks/mock_storage.go @@ -234,6 +234,45 @@ func (mr *MockUpstreamTokenStorageMockRecorder) StoreUpstreamTokens(ctx, session return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StoreUpstreamTokens", reflect.TypeOf((*MockUpstreamTokenStorage)(nil).StoreUpstreamTokens), ctx, sessionID, tokens) } +// MockUpstreamTokenRefresher is a mock of UpstreamTokenRefresher interface. +type MockUpstreamTokenRefresher struct { + ctrl *gomock.Controller + recorder *MockUpstreamTokenRefresherMockRecorder + isgomock struct{} +} + +// MockUpstreamTokenRefresherMockRecorder is the mock recorder for MockUpstreamTokenRefresher. +type MockUpstreamTokenRefresherMockRecorder struct { + mock *MockUpstreamTokenRefresher +} + +// NewMockUpstreamTokenRefresher creates a new mock instance. +func NewMockUpstreamTokenRefresher(ctrl *gomock.Controller) *MockUpstreamTokenRefresher { + mock := &MockUpstreamTokenRefresher{ctrl: ctrl} + mock.recorder = &MockUpstreamTokenRefresherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUpstreamTokenRefresher) EXPECT() *MockUpstreamTokenRefresherMockRecorder { + return m.recorder +} + +// RefreshAndStore mocks base method. +func (m *MockUpstreamTokenRefresher) RefreshAndStore(ctx context.Context, sessionID string, expired *storage.UpstreamTokens) (*storage.UpstreamTokens, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RefreshAndStore", ctx, sessionID, expired) + ret0, _ := ret[0].(*storage.UpstreamTokens) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RefreshAndStore indicates an expected call of RefreshAndStore. +func (mr *MockUpstreamTokenRefresherMockRecorder) RefreshAndStore(ctx, sessionID, expired any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshAndStore", reflect.TypeOf((*MockUpstreamTokenRefresher)(nil).RefreshAndStore), ctx, sessionID, expired) +} + // MockUserStorage is a mock of UserStorage interface. type MockUserStorage struct { ctrl *gomock.Controller diff --git a/pkg/authserver/storage/redis.go b/pkg/authserver/storage/redis.go index 8ec68f83b0..4b27d84a10 100644 --- a/pkg/authserver/storage/redis.go +++ b/pkg/authserver/storage/redis.go @@ -765,11 +765,13 @@ func marshalUpstreamTokensWithTTL(tokens *UpstreamTokens) ([]byte, time.Duration return nil, 0, fmt.Errorf("failed to marshal upstream tokens: %w", err) } - ttl := DefaultAccessTokenTTL + // Add DefaultRefreshTokenTTL beyond access token expiry so the refresh token + // survives in storage for transparent token refresh by the middleware. + ttl := DefaultAccessTokenTTL + DefaultRefreshTokenTTL if !tokens.ExpiresAt.IsZero() { - ttl = time.Until(tokens.ExpiresAt) + ttl = time.Until(tokens.ExpiresAt) + DefaultRefreshTokenTTL if ttl < 0 { - ttl = DefaultAccessTokenTTL + ttl = DefaultRefreshTokenTTL } } @@ -841,14 +843,13 @@ func (s *RedisStorage) GetUpstreamTokens(ctx context.Context, sessionID string) // stores 0 for zero time. Skip the expiry check in this case since Redis TTL // handles the actual expiration. var expiresAt time.Time + var expired bool if stored.ExpiresAt != 0 { expiresAt = time.Unix(stored.ExpiresAt, 0) - if time.Now().After(expiresAt) { - return nil, ErrExpired - } + expired = time.Now().After(expiresAt) } - return &UpstreamTokens{ + tokens := &UpstreamTokens{ ProviderID: stored.ProviderID, AccessToken: stored.AccessToken, RefreshToken: stored.RefreshToken, @@ -857,7 +858,14 @@ func (s *RedisStorage) GetUpstreamTokens(ctx context.Context, sessionID string) UserID: stored.UserID, UpstreamSubject: stored.UpstreamSubject, ClientID: stored.ClientID, - }, nil + } + + // Return tokens along with ErrExpired so callers can use the refresh token + if expired { + return tokens, ErrExpired + } + + return tokens, nil } // DeleteUpstreamTokens removes the upstream IDP tokens for a session. diff --git a/pkg/authserver/storage/redis_test.go b/pkg/authserver/storage/redis_test.go index db0910c9fe..92f2639b5b 100644 --- a/pkg/authserver/storage/redis_test.go +++ b/pkg/authserver/storage/redis_test.go @@ -622,19 +622,24 @@ func TestRedisStorage_UpstreamTokens(t *testing.T) { }) }) - t.Run("get expired tokens returns ErrExpired", func(t *testing.T) { + t.Run("get expired tokens returns ErrExpired with token data", func(t *testing.T) { withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { // Store with an ExpiresAt that's already in the past. - // The TTL will be set to DefaultAccessTokenTTL, but the stored - // ExpiresAt will be checked and return ErrExpired. + // The TTL includes DefaultRefreshTokenTTL so the key survives + // past access token expiry, allowing refresh token retrieval. require.NoError(t, s.StoreUpstreamTokens(ctx, "expired", &UpstreamTokens{ - AccessToken: "expired-token", ExpiresAt: time.Now().Add(-time.Hour), + AccessToken: "expired-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Hour), })) retrieved, err := s.GetUpstreamTokens(ctx, "expired") require.Error(t, err) assert.ErrorIs(t, err, ErrExpired) - assert.Nil(t, retrieved) + // Tokens should be returned alongside ErrExpired for refresh purposes + require.NotNil(t, retrieved) + assert.Equal(t, "expired-token", retrieved.AccessToken) + assert.Equal(t, "refresh-token", retrieved.RefreshToken) }) }) diff --git a/pkg/authserver/storage/types.go b/pkg/authserver/storage/types.go index 308b4d8841..f6c7c89316 100644 --- a/pkg/authserver/storage/types.go +++ b/pkg/authserver/storage/types.go @@ -201,7 +201,9 @@ type UpstreamTokenStorage interface { // GetUpstreamTokens retrieves the upstream IDP tokens for a session. // Returns ErrNotFound if the session does not exist. - // Returns ErrExpired if the tokens have expired. + // Returns ErrExpired if the tokens have expired. When ErrExpired is returned, + // the token data (including refresh token) is also returned to allow callers + // to attempt a token refresh. // Returns ErrInvalidBinding if binding validation fails. GetUpstreamTokens(ctx context.Context, sessionID string) (*UpstreamTokens, error) @@ -210,6 +212,16 @@ type UpstreamTokenStorage interface { DeleteUpstreamTokens(ctx context.Context, sessionID string) error } +// UpstreamTokenRefresher can refresh expired upstream tokens using their stored refresh token. +// This is implemented by the auth server and used by the upstreamswap middleware to +// transparently refresh tokens without forcing re-authentication. +type UpstreamTokenRefresher interface { + // RefreshAndStore refreshes the upstream tokens for the given session using + // the stored refresh token, stores the new tokens, and returns them. + // Returns an error if the refresh token is empty, revoked, or the refresh fails. + RefreshAndStore(ctx context.Context, sessionID string, expired *UpstreamTokens) (*UpstreamTokens, error) +} + // UserStorage provides user and provider identity management operations. // This interface supports multi-IDP scenarios where a single user can authenticate // via multiple upstream identity providers (e.g., Google and GitHub). diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 4b2ee6c859..08a87d1690 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -145,6 +145,18 @@ func (r *Runner) GetUpstreamTokenStorage() func() storage.UpstreamTokenStorage { } } +// GetUpstreamTokenRefresher returns a lazy accessor for the upstream token refresher. +// The returned function should be called at request time; it returns nil if +// the embedded auth server is not configured. +func (r *Runner) GetUpstreamTokenRefresher() func() storage.UpstreamTokenRefresher { + return func() storage.UpstreamTokenRefresher { + if r.embeddedAuthServer == nil { + return nil + } + return r.embeddedAuthServer.UpstreamTokenRefresher() + } +} + // GetName returns the name of the mcp-service from the runner config (implements types.RunnerConfig) func (c *RunConfig) GetName() string { return c.Name diff --git a/pkg/transport/types/mocks/mock_transport.go b/pkg/transport/types/mocks/mock_transport.go index 87ce9602fb..20ad19fb31 100644 --- a/pkg/transport/types/mocks/mock_transport.go +++ b/pkg/transport/types/mocks/mock_transport.go @@ -123,6 +123,20 @@ func (mr *MockMiddlewareRunnerMockRecorder) GetConfig() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfig", reflect.TypeOf((*MockMiddlewareRunner)(nil).GetConfig)) } +// GetUpstreamTokenRefresher mocks base method. +func (m *MockMiddlewareRunner) GetUpstreamTokenRefresher() func() storage.UpstreamTokenRefresher { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUpstreamTokenRefresher") + ret0, _ := ret[0].(func() storage.UpstreamTokenRefresher) + return ret0 +} + +// GetUpstreamTokenRefresher indicates an expected call of GetUpstreamTokenRefresher. +func (mr *MockMiddlewareRunnerMockRecorder) GetUpstreamTokenRefresher() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpstreamTokenRefresher", reflect.TypeOf((*MockMiddlewareRunner)(nil).GetUpstreamTokenRefresher)) +} + // GetUpstreamTokenStorage mocks base method. func (m *MockMiddlewareRunner) GetUpstreamTokenStorage() func() storage.UpstreamTokenStorage { m.ctrl.T.Helper() diff --git a/pkg/transport/types/transport.go b/pkg/transport/types/transport.go index 67e8b93e5d..8706aa09f2 100644 --- a/pkg/transport/types/transport.go +++ b/pkg/transport/types/transport.go @@ -86,6 +86,11 @@ type MiddlewareRunner interface { // before the embedded auth server is initialized. Storage availability is // determined at request time when the returned function is called. GetUpstreamTokenStorage() func() storage.UpstreamTokenStorage + + // GetUpstreamTokenRefresher returns a lazy accessor for the upstream token refresher. + // The returned function should be called at request time; it returns nil if + // the embedded auth server is not configured or the refresher is unavailable. + GetUpstreamTokenRefresher() func() storage.UpstreamTokenRefresher } // RunnerConfig defines the config interface needed by middleware to access runner configuration From b0551aa0e4b78a71e90870b1122dc6284767232a Mon Sep 17 00:00:00 2001 From: Aron Gates Date: Fri, 6 Mar 2026 15:08:08 +0000 Subject: [PATCH 2/7] Add unit tests for token refresh paths Add comprehensive tests for RefreshAndStore (6 cases) and middleware refresh paths (4 cases: successful refresh, failed refresh, no refresh token, defense-in-depth expired-without-error). Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/auth/upstreamswap/middleware_test.go | 245 +++++++++++++++++++++++ pkg/authserver/refresher_test.go | 219 ++++++++++++++++++++ 2 files changed, 464 insertions(+) create mode 100644 pkg/authserver/refresher_test.go diff --git a/pkg/auth/upstreamswap/middleware_test.go b/pkg/auth/upstreamswap/middleware_test.go index b48bf775f8..3ca7dfc346 100644 --- a/pkg/auth/upstreamswap/middleware_test.go +++ b/pkg/auth/upstreamswap/middleware_test.go @@ -767,3 +767,248 @@ func TestMiddleware_TsidClaimWrongType(t *testing.T) { assert.True(t, nextCalled, "next handler should be called when tsid is wrong type") } + +func TestMiddleware_ExpiredTokens_RefreshSuccess(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + expiredTokens := &storage.UpstreamTokens{ + AccessToken: "expired-access-token", + RefreshToken: "my-refresh-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + + refreshedTokens := &storage.UpstreamTokens{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockStorage.EXPECT(). + GetUpstreamTokens(gomock.Any(), "session-123"). + Return(expiredTokens, storage.ErrExpired) + + mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) + mockRefresher.EXPECT(). + RefreshAndStore(gomock.Any(), "session-123", expiredTokens). + Return(refreshedTokens, nil) + + storageGetter := func() storage.UpstreamTokenStorage { + return mockStorage + } + refresherGetter := func() storage.UpstreamTokenRefresher { + return mockRefresher + } + + cfg := &Config{} + middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter) + + var nextCalled bool + var capturedAuthHeader string + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + nextCalled = true + capturedAuthHeader = r.Header.Get("Authorization") + }) + + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + identity := &auth.Identity{ + Subject: "user123", + Claims: map[string]any{ + "sub": "user123", + session.TokenSessionIDClaimKey: "session-123", + }, + } + ctx := auth.WithIdentity(req.Context(), identity) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.True(t, nextCalled, "next handler should be called after successful refresh") + assert.Equal(t, "Bearer new-access-token", capturedAuthHeader) +} + +func TestMiddleware_ExpiredTokens_RefreshFails(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + expiredTokens := &storage.UpstreamTokens{ + AccessToken: "expired-access-token", + RefreshToken: "my-refresh-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockStorage.EXPECT(). + GetUpstreamTokens(gomock.Any(), "session-123"). + Return(expiredTokens, storage.ErrExpired) + + mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) + mockRefresher.EXPECT(). + RefreshAndStore(gomock.Any(), "session-123", expiredTokens). + Return(nil, errors.New("refresh failed")) + + storageGetter := func() storage.UpstreamTokenStorage { + return mockStorage + } + refresherGetter := func() storage.UpstreamTokenRefresher { + return mockRefresher + } + + cfg := &Config{} + middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + identity := &auth.Identity{ + Subject: "user123", + Claims: map[string]any{ + "sub": "user123", + session.TokenSessionIDClaimKey: "session-123", + }, + } + ctx := auth.WithIdentity(req.Context(), identity) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.False(t, nextCalled, "next handler should NOT be called when refresh fails") + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) +} + +func TestMiddleware_ExpiredTokens_NoRefreshToken(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + expiredTokens := &storage.UpstreamTokens{ + AccessToken: "expired-access-token", + RefreshToken: "", // No refresh token + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockStorage.EXPECT(). + GetUpstreamTokens(gomock.Any(), "session-123"). + Return(expiredTokens, storage.ErrExpired) + + // No refresher mock needed — refresh should not be attempted + + storageGetter := func() storage.UpstreamTokenStorage { + return mockStorage + } + refresherGetter := func() storage.UpstreamTokenRefresher { + t.Fatal("refresher getter should not be called when there is no refresh token") + return nil + } + + cfg := &Config{} + middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + identity := &auth.Identity{ + Subject: "user123", + Claims: map[string]any{ + "sub": "user123", + session.TokenSessionIDClaimKey: "session-123", + }, + } + ctx := auth.WithIdentity(req.Context(), identity) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.False(t, nextCalled, "next handler should NOT be called when no refresh token available") + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) +} + +func TestMiddleware_DefenseInDepth_ExpiredButNoError_RefreshSuccess(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Storage returns tokens with ExpiresAt in the past but NO error + expiredTokens := &storage.UpstreamTokens{ + AccessToken: "expired-access-token", + RefreshToken: "my-refresh-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + + refreshedTokens := &storage.UpstreamTokens{ + AccessToken: "refreshed-access-token", + RefreshToken: "refreshed-refresh-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockStorage.EXPECT(). + GetUpstreamTokens(gomock.Any(), "session-123"). + Return(expiredTokens, nil) // No error, but token is expired + + mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) + mockRefresher.EXPECT(). + RefreshAndStore(gomock.Any(), "session-123", expiredTokens). + Return(refreshedTokens, nil) + + storageGetter := func() storage.UpstreamTokenStorage { + return mockStorage + } + refresherGetter := func() storage.UpstreamTokenRefresher { + return mockRefresher + } + + cfg := &Config{} + middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter) + + var nextCalled bool + var capturedAuthHeader string + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + nextCalled = true + capturedAuthHeader = r.Header.Get("Authorization") + }) + + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + identity := &auth.Identity{ + Subject: "user123", + Claims: map[string]any{ + "sub": "user123", + session.TokenSessionIDClaimKey: "session-123", + }, + } + ctx := auth.WithIdentity(req.Context(), identity) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.True(t, nextCalled, "next handler should be called after successful defense-in-depth refresh") + assert.Equal(t, "Bearer refreshed-access-token", capturedAuthHeader) +} diff --git a/pkg/authserver/refresher_test.go b/pkg/authserver/refresher_test.go new file mode 100644 index 0000000000..6a85538924 --- /dev/null +++ b/pkg/authserver/refresher_test.go @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package authserver + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/authserver/storage" + storagemocks "github.com/stacklok/toolhive/pkg/authserver/storage/mocks" + "github.com/stacklok/toolhive/pkg/authserver/upstream" + upstreammocks "github.com/stacklok/toolhive/pkg/authserver/upstream/mocks" +) + +func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { + t.Parallel() + + newExpiry := time.Now().Add(1 * time.Hour) + + baseExpired := &storage.UpstreamTokens{ + ProviderID: "github", + AccessToken: "old-access", + RefreshToken: "old-refresh", + IDToken: "old-id-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + UserID: "user-123", + UpstreamSubject: "upstream-sub-456", + ClientID: "client-abc", + } + + tests := []struct { + name string + sessionID string + expired *storage.UpstreamTokens + setupProvider func(*upstreammocks.MockOAuth2Provider) + setupStorage func(*storagemocks.MockUpstreamTokenStorage) + wantErr bool + wantErrContain string + checkResult func(*testing.T, *storage.UpstreamTokens) + }{ + { + name: "successful refresh with token rotation", + sessionID: "session-1", + expired: baseExpired, + setupProvider: func(p *upstreammocks.MockOAuth2Provider) { + p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). + Return(&upstream.Tokens{ + AccessToken: "new-access", + RefreshToken: "new-refresh", + IDToken: "new-id-token", + ExpiresAt: newExpiry, + }, nil) + }, + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-1", gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, tokens *storage.UpstreamTokens) error { + // Verify binding fields are preserved from expired tokens + assert.Equal(t, "github", tokens.ProviderID) + assert.Equal(t, "user-123", tokens.UserID) + assert.Equal(t, "upstream-sub-456", tokens.UpstreamSubject) + assert.Equal(t, "client-abc", tokens.ClientID) + // Verify new token values + assert.Equal(t, "new-access", tokens.AccessToken) + assert.Equal(t, "new-refresh", tokens.RefreshToken) + assert.Equal(t, "new-id-token", tokens.IDToken) + assert.Equal(t, newExpiry, tokens.ExpiresAt) + return nil + }) + }, + checkResult: func(t *testing.T, result *storage.UpstreamTokens) { + t.Helper() + assert.Equal(t, "new-access", result.AccessToken) + assert.Equal(t, "new-refresh", result.RefreshToken) + assert.Equal(t, "new-id-token", result.IDToken) + assert.Equal(t, newExpiry, result.ExpiresAt) + // Binding fields preserved + assert.Equal(t, "github", result.ProviderID) + assert.Equal(t, "user-123", result.UserID) + assert.Equal(t, "upstream-sub-456", result.UpstreamSubject) + assert.Equal(t, "client-abc", result.ClientID) + }, + }, + { + name: "provider does not rotate refresh token - keeps old one", + sessionID: "session-2", + expired: baseExpired, + setupProvider: func(p *upstreammocks.MockOAuth2Provider) { + p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). + Return(&upstream.Tokens{ + AccessToken: "new-access", + RefreshToken: "", // Provider did not rotate + IDToken: "new-id-token", + ExpiresAt: newExpiry, + }, nil) + }, + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-2", gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, tokens *storage.UpstreamTokens) error { + assert.Equal(t, "old-refresh", tokens.RefreshToken) + return nil + }) + }, + checkResult: func(t *testing.T, result *storage.UpstreamTokens) { + t.Helper() + assert.Equal(t, "new-access", result.AccessToken) + assert.Equal(t, "old-refresh", result.RefreshToken) + }, + }, + { + name: "nil expired tokens returns error", + sessionID: "session-3", + expired: nil, + setupProvider: func(_ *upstreammocks.MockOAuth2Provider) {}, + setupStorage: func(_ *storagemocks.MockUpstreamTokenStorage) {}, + wantErr: true, + wantErrContain: "expired tokens are required", + }, + { + name: "empty refresh token returns error", + sessionID: "session-4", + expired: &storage.UpstreamTokens{ + ProviderID: "github", + AccessToken: "old-access", + RefreshToken: "", + UserID: "user-123", + UpstreamSubject: "upstream-sub-456", + ClientID: "client-abc", + }, + setupProvider: func(_ *upstreammocks.MockOAuth2Provider) {}, + setupStorage: func(_ *storagemocks.MockUpstreamTokenStorage) {}, + wantErr: true, + wantErrContain: "no refresh token available", + }, + { + name: "provider refresh fails returns error", + sessionID: "session-5", + expired: baseExpired, + setupProvider: func(p *upstreammocks.MockOAuth2Provider) { + p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). + Return(nil, errors.New("upstream IDP unavailable")) + }, + setupStorage: func(_ *storagemocks.MockUpstreamTokenStorage) {}, + wantErr: true, + wantErrContain: "upstream token refresh failed", + }, + { + name: "storage fails after refresh - returns refreshed tokens anyway", + sessionID: "session-6", + expired: baseExpired, + setupProvider: func(p *upstreammocks.MockOAuth2Provider) { + p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). + Return(&upstream.Tokens{ + AccessToken: "new-access", + RefreshToken: "new-refresh", + IDToken: "new-id-token", + ExpiresAt: newExpiry, + }, nil) + }, + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-6", gomock.Any()). + Return(errors.New("redis connection lost")) + }, + checkResult: func(t *testing.T, result *storage.UpstreamTokens) { + t.Helper() + // Tokens should still be returned despite storage failure + assert.Equal(t, "new-access", result.AccessToken) + assert.Equal(t, "new-refresh", result.RefreshToken) + assert.Equal(t, "new-id-token", result.IDToken) + assert.Equal(t, newExpiry, result.ExpiresAt) + // Binding fields preserved + assert.Equal(t, "github", result.ProviderID) + assert.Equal(t, "user-123", result.UserID) + assert.Equal(t, "upstream-sub-456", result.UpstreamSubject) + assert.Equal(t, "client-abc", result.ClientID) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + mockProvider := upstreammocks.NewMockOAuth2Provider(ctrl) + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + + tt.setupProvider(mockProvider) + tt.setupStorage(mockStorage) + + refresher := &upstreamTokenRefresher{ + provider: mockProvider, + storage: mockStorage, + } + + result, err := refresher.RefreshAndStore(context.Background(), tt.sessionID, tt.expired) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErrContain) + assert.Nil(t, result) + return + } + + require.NoError(t, err) + require.NotNil(t, result) + if tt.checkResult != nil { + tt.checkResult(t, result) + } + }) + } +} From 5100d714bbe2b3a0bf99e45a287260d6b5fdfe51 Mon Sep 17 00:00:00 2001 From: Aron Date: Mon, 9 Mar 2026 10:46:51 +0000 Subject: [PATCH 3/7] Update pkg/authserver/refresher.go Co-authored-by: Jakub Hrozek --- pkg/authserver/refresher.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/authserver/refresher.go b/pkg/authserver/refresher.go index b4093fcbe0..709a2d7bcc 100644 --- a/pkg/authserver/refresher.go +++ b/pkg/authserver/refresher.go @@ -21,6 +21,9 @@ type upstreamTokenRefresher struct { storage storage.UpstreamTokenStorage } +// Compile-time check that upstreamTokenRefresher implements storage.UpstreamTokenRefresher. +var _ storage.UpstreamTokenRefresher = (*upstreamTokenRefresher)(nil) + // RefreshAndStore refreshes expired upstream tokens using the stored refresh token, // persists the new tokens, and returns them. func (r *upstreamTokenRefresher) RefreshAndStore( From b21489dba52d883f4bed5dc33b65e6e9e6ac0dfd Mon Sep 17 00:00:00 2001 From: Aron Date: Mon, 9 Mar 2026 10:47:03 +0000 Subject: [PATCH 4/7] Update pkg/auth/upstreamswap/middleware.go Co-authored-by: Jakub Hrozek --- pkg/auth/upstreamswap/middleware.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/auth/upstreamswap/middleware.go b/pkg/auth/upstreamswap/middleware.go index 1bf80208ef..6a39c262eb 100644 --- a/pkg/auth/upstreamswap/middleware.go +++ b/pkg/auth/upstreamswap/middleware.go @@ -252,7 +252,7 @@ func getOrRefreshUpstreamTokens( // Defense in depth: some storage implementations may return tokens // without checking expiry (the interface does not require it). - if tokens.IsExpired(time.Now()) { + if !tokens.ExpiresAt.IsZero() && tokens.IsExpired(time.Now()) { if refreshed := tryRefreshUpstreamTokens(ctx, sessionID, tokens, refresherGetter); refreshed != nil { return refreshed, nil } From 595caccbf9f7e694b2d9f0818afbce4f1df75d3a Mon Sep 17 00:00:00 2001 From: Aron Gates Date: Mon, 9 Mar 2026 10:54:47 +0000 Subject: [PATCH 5/7] Address PR review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix step numbering: renumber step 6 → 5 after step 5 removal - Update redis integration test: assert returned token data is non-nil on ErrExpired, consistent with the unit test contract - Fix test closures: pass subtest t to setupStorage/setupProvider to ensure assertion failures are attributed to the correct subtest Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/auth/upstreamswap/middleware.go | 2 +- pkg/authserver/refresher_test.go | 32 +++++++++---------- .../storage/redis_integration_test.go | 12 +++++-- 3 files changed, 26 insertions(+), 20 deletions(-) diff --git a/pkg/auth/upstreamswap/middleware.go b/pkg/auth/upstreamswap/middleware.go index 6a39c262eb..7705ec8ee1 100644 --- a/pkg/auth/upstreamswap/middleware.go +++ b/pkg/auth/upstreamswap/middleware.go @@ -213,7 +213,7 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter, refresherGet return } - // 6. Inject access token + // 5. Inject access token if tokens.AccessToken == "" { slog.Warn("Access token is empty", "middleware", "upstreamswap") diff --git a/pkg/authserver/refresher_test.go b/pkg/authserver/refresher_test.go index 6a85538924..ffd9436f2c 100644 --- a/pkg/authserver/refresher_test.go +++ b/pkg/authserver/refresher_test.go @@ -39,8 +39,8 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { name string sessionID string expired *storage.UpstreamTokens - setupProvider func(*upstreammocks.MockOAuth2Provider) - setupStorage func(*storagemocks.MockUpstreamTokenStorage) + setupProvider func(*testing.T, *upstreammocks.MockOAuth2Provider) + setupStorage func(*testing.T, *storagemocks.MockUpstreamTokenStorage) wantErr bool wantErrContain string checkResult func(*testing.T, *storage.UpstreamTokens) @@ -49,7 +49,7 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { name: "successful refresh with token rotation", sessionID: "session-1", expired: baseExpired, - setupProvider: func(p *upstreammocks.MockOAuth2Provider) { + setupProvider: func(_ *testing.T, p *upstreammocks.MockOAuth2Provider) { p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). Return(&upstream.Tokens{ AccessToken: "new-access", @@ -58,7 +58,7 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { ExpiresAt: newExpiry, }, nil) }, - setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + setupStorage: func(t *testing.T, s *storagemocks.MockUpstreamTokenStorage) { s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-1", gomock.Any()). DoAndReturn(func(_ context.Context, _ string, tokens *storage.UpstreamTokens) error { // Verify binding fields are preserved from expired tokens @@ -91,7 +91,7 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { name: "provider does not rotate refresh token - keeps old one", sessionID: "session-2", expired: baseExpired, - setupProvider: func(p *upstreammocks.MockOAuth2Provider) { + setupProvider: func(_ *testing.T, p *upstreammocks.MockOAuth2Provider) { p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). Return(&upstream.Tokens{ AccessToken: "new-access", @@ -100,7 +100,7 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { ExpiresAt: newExpiry, }, nil) }, - setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + setupStorage: func(t *testing.T, s *storagemocks.MockUpstreamTokenStorage) { s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-2", gomock.Any()). DoAndReturn(func(_ context.Context, _ string, tokens *storage.UpstreamTokens) error { assert.Equal(t, "old-refresh", tokens.RefreshToken) @@ -117,8 +117,8 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { name: "nil expired tokens returns error", sessionID: "session-3", expired: nil, - setupProvider: func(_ *upstreammocks.MockOAuth2Provider) {}, - setupStorage: func(_ *storagemocks.MockUpstreamTokenStorage) {}, + setupProvider: func(_ *testing.T, _ *upstreammocks.MockOAuth2Provider) {}, + setupStorage: func(_ *testing.T, _ *storagemocks.MockUpstreamTokenStorage) {}, wantErr: true, wantErrContain: "expired tokens are required", }, @@ -133,8 +133,8 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { UpstreamSubject: "upstream-sub-456", ClientID: "client-abc", }, - setupProvider: func(_ *upstreammocks.MockOAuth2Provider) {}, - setupStorage: func(_ *storagemocks.MockUpstreamTokenStorage) {}, + setupProvider: func(_ *testing.T, _ *upstreammocks.MockOAuth2Provider) {}, + setupStorage: func(_ *testing.T, _ *storagemocks.MockUpstreamTokenStorage) {}, wantErr: true, wantErrContain: "no refresh token available", }, @@ -142,11 +142,11 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { name: "provider refresh fails returns error", sessionID: "session-5", expired: baseExpired, - setupProvider: func(p *upstreammocks.MockOAuth2Provider) { + setupProvider: func(_ *testing.T, p *upstreammocks.MockOAuth2Provider) { p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). Return(nil, errors.New("upstream IDP unavailable")) }, - setupStorage: func(_ *storagemocks.MockUpstreamTokenStorage) {}, + setupStorage: func(_ *testing.T, _ *storagemocks.MockUpstreamTokenStorage) {}, wantErr: true, wantErrContain: "upstream token refresh failed", }, @@ -154,7 +154,7 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { name: "storage fails after refresh - returns refreshed tokens anyway", sessionID: "session-6", expired: baseExpired, - setupProvider: func(p *upstreammocks.MockOAuth2Provider) { + setupProvider: func(_ *testing.T, p *upstreammocks.MockOAuth2Provider) { p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). Return(&upstream.Tokens{ AccessToken: "new-access", @@ -163,7 +163,7 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { ExpiresAt: newExpiry, }, nil) }, - setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + setupStorage: func(t *testing.T, s *storagemocks.MockUpstreamTokenStorage) { s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-6", gomock.Any()). Return(errors.New("redis connection lost")) }, @@ -192,8 +192,8 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { mockProvider := upstreammocks.NewMockOAuth2Provider(ctrl) mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - tt.setupProvider(mockProvider) - tt.setupStorage(mockStorage) + tt.setupProvider(t, mockProvider) + tt.setupStorage(t, mockStorage) refresher := &upstreamTokenRefresher{ provider: mockProvider, diff --git a/pkg/authserver/storage/redis_integration_test.go b/pkg/authserver/storage/redis_integration_test.go index 86c8f2474e..aa955acef3 100644 --- a/pkg/authserver/storage/redis_integration_test.go +++ b/pkg/authserver/storage/redis_integration_test.go @@ -659,13 +659,19 @@ func TestIntegration_UpstreamTokens(t *testing.T) { }) }) - t.Run("expired tokens return ErrExpired", func(t *testing.T) { + t.Run("expired tokens return ErrExpired with token data", func(t *testing.T) { withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { require.NoError(t, s.StoreUpstreamTokens(ctx, "sess-exp", &UpstreamTokens{ - AccessToken: "expired", ExpiresAt: time.Now().Add(-time.Hour), + AccessToken: "expired-token", + RefreshToken: "expired-refresh", + ExpiresAt: time.Now().Add(-time.Hour), })) - _, err := s.GetUpstreamTokens(ctx, "sess-exp") + tokens, err := s.GetUpstreamTokens(ctx, "sess-exp") assert.ErrorIs(t, err, ErrExpired) + // Expired tokens should still return the data (needed for refresh) + require.NotNil(t, tokens, "expired tokens should return data for refresh") + assert.Equal(t, "expired-token", tokens.AccessToken) + assert.Equal(t, "expired-refresh", tokens.RefreshToken) }) }) From 2f1c6dd2e45c5d4a68d77ac8290ac2172a38a768 Mon Sep 17 00:00:00 2001 From: Aron Gates Date: Mon, 9 Mar 2026 11:19:36 +0000 Subject: [PATCH 6/7] Add singleflight for concurrent refresh deduplication and fix lint Wrap upstream token refresh in singleflight.Group keyed on sessionID to collapse concurrent refreshes into one call. Prevents providers with single-use refresh tokens from failing all but the first concurrent caller. Added TestSingleFlightRefresh_ConcurrentRequests to verify the fix. Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/auth/upstreamswap/middleware_test.go | 4 ++-- pkg/authserver/refresher_test.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/auth/upstreamswap/middleware_test.go b/pkg/auth/upstreamswap/middleware_test.go index 993627a1a1..b3d39214ac 100644 --- a/pkg/auth/upstreamswap/middleware_test.go +++ b/pkg/auth/upstreamswap/middleware_test.go @@ -1107,11 +1107,11 @@ func (f *fakeTokenStorage) GetUpstreamTokens(_ context.Context, _ string) (*stor return f.tokens, f.err } -func (f *fakeTokenStorage) StoreUpstreamTokens(_ context.Context, _ string, _ *storage.UpstreamTokens) error { +func (*fakeTokenStorage) StoreUpstreamTokens(_ context.Context, _ string, _ *storage.UpstreamTokens) error { return nil } -func (f *fakeTokenStorage) DeleteUpstreamTokens(_ context.Context, _ string) error { +func (*fakeTokenStorage) DeleteUpstreamTokens(_ context.Context, _ string) error { return nil } diff --git a/pkg/authserver/refresher_test.go b/pkg/authserver/refresher_test.go index ffd9436f2c..23819a6f16 100644 --- a/pkg/authserver/refresher_test.go +++ b/pkg/authserver/refresher_test.go @@ -58,7 +58,7 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { ExpiresAt: newExpiry, }, nil) }, - setupStorage: func(t *testing.T, s *storagemocks.MockUpstreamTokenStorage) { + setupStorage: func(_ *testing.T, s *storagemocks.MockUpstreamTokenStorage) { s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-1", gomock.Any()). DoAndReturn(func(_ context.Context, _ string, tokens *storage.UpstreamTokens) error { // Verify binding fields are preserved from expired tokens @@ -100,7 +100,7 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { ExpiresAt: newExpiry, }, nil) }, - setupStorage: func(t *testing.T, s *storagemocks.MockUpstreamTokenStorage) { + setupStorage: func(_ *testing.T, s *storagemocks.MockUpstreamTokenStorage) { s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-2", gomock.Any()). DoAndReturn(func(_ context.Context, _ string, tokens *storage.UpstreamTokens) error { assert.Equal(t, "old-refresh", tokens.RefreshToken) @@ -163,7 +163,7 @@ func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { ExpiresAt: newExpiry, }, nil) }, - setupStorage: func(t *testing.T, s *storagemocks.MockUpstreamTokenStorage) { + setupStorage: func(_ *testing.T, s *storagemocks.MockUpstreamTokenStorage) { s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-6", gomock.Any()). Return(errors.New("redis connection lost")) }, From 0a6fb9fa0ee622ba21cbc25ba66f5ece8f959673 Mon Sep 17 00:00:00 2001 From: Aron Date: Mon, 9 Mar 2026 23:57:28 +0000 Subject: [PATCH 7/7] Update middleware.go Co-authored-by: Jakub Hrozek --- pkg/auth/upstreamswap/middleware.go | 7 +++- pkg/auth/upstreamswap/middleware_test.go | 50 ++++++++++++++++++++---- 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/pkg/auth/upstreamswap/middleware.go b/pkg/auth/upstreamswap/middleware.go index fee42ec171..ae849af1f3 100644 --- a/pkg/auth/upstreamswap/middleware.go +++ b/pkg/auth/upstreamswap/middleware.go @@ -280,7 +280,12 @@ func doSingleFlightRefresh( refresherGetter RefresherGetter, ) *storage.UpstreamTokens { result, err, _ := sfGroup.Do(sessionID, func() (any, error) { - refreshed := tryRefreshUpstreamTokens(ctx, sessionID, expired, refresherGetter) + // Detach from the triggering request's context so that if the first + // caller disconnects, the refresh still completes for waiting callers. + // The 30s timeout bounds the operation independently from client lifecycle. + refreshCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second) + defer cancel() + refreshed := tryRefreshUpstreamTokens(refreshCtx, sessionID, expired, refresherGetter) if refreshed == nil { return nil, errors.New("refresh failed") } diff --git a/pkg/auth/upstreamswap/middleware_test.go b/pkg/auth/upstreamswap/middleware_test.go index b3d39214ac..13e2442959 100644 --- a/pkg/auth/upstreamswap/middleware_test.go +++ b/pkg/auth/upstreamswap/middleware_test.go @@ -1037,11 +1037,23 @@ func TestSingleFlightRefresh_ConcurrentRequests(t *testing.T) { ExpiresAt: time.Now().Add(1 * time.Hour), } - // Use a real (non-mock) storage and refresher to avoid gomock concurrency issues - stor := &fakeTokenStorage{tokens: expiredTokens, err: storage.ErrExpired} + // Use a barrier so all goroutines complete GetUpstreamTokens before any + // enters singleflight. This guarantees they all contend on the same key. + var storageBarrier sync.WaitGroup + storageBarrier.Add(numRequests) + storageGate := make(chan struct{}) + + stor := &fakeTokenStorage{ + tokens: expiredTokens, + err: storage.ErrExpired, + barrier: &storageBarrier, + gate: storageGate, + } + proceed := make(chan struct{}) refresher := &fakeRefresher{ result: refreshedTokens, callCount: &refreshCallCount, + proceed: proceed, } storageGetter := func() storage.UpstreamTokenStorage { return stor } @@ -1083,6 +1095,18 @@ func TestSingleFlightRefresh_ConcurrentRequests(t *testing.T) { // Release all goroutines simultaneously close(ready) + + // Wait for ALL goroutines to reach GetUpstreamTokens + storageBarrier.Wait() + + // Release them all at once — they all proceed to singleflight.Do concurrently + close(storageGate) + + // Give goroutines a moment to enter singleflight, then let refresh complete + // The singleflight ensures only one actually calls RefreshAndStore + time.Sleep(10 * time.Millisecond) + close(proceed) + wg.Wait() // KEY ASSERTION: RefreshAndStore should be called exactly once. @@ -1097,13 +1121,23 @@ func TestSingleFlightRefresh_ConcurrentRequests(t *testing.T) { } } -// fakeTokenStorage always returns the configured tokens and error. +// fakeTokenStorage returns configured tokens and optionally blocks until +// a barrier is released, ensuring all goroutines reach storage before any +// proceeds to the singleflight refresh. type fakeTokenStorage struct { - tokens *storage.UpstreamTokens - err error + tokens *storage.UpstreamTokens + err error + barrier *sync.WaitGroup // if set, each call does barrier.Done() then waits + gate chan struct{} // if set, blocks until closed } func (f *fakeTokenStorage) GetUpstreamTokens(_ context.Context, _ string) (*storage.UpstreamTokens, error) { + if f.barrier != nil { + f.barrier.Done() + } + if f.gate != nil { + <-f.gate + } return f.tokens, f.err } @@ -1115,15 +1149,15 @@ func (*fakeTokenStorage) DeleteUpstreamTokens(_ context.Context, _ string) error return nil } -// fakeRefresher counts calls and adds a small delay to allow concurrency overlap. +// fakeRefresher counts calls and blocks until proceed is closed. type fakeRefresher struct { result *storage.UpstreamTokens callCount *atomic.Int32 + proceed chan struct{} } func (f *fakeRefresher) RefreshAndStore(_ context.Context, _ string, _ *storage.UpstreamTokens) (*storage.UpstreamTokens, error) { f.callCount.Add(1) - // Small delay to ensure concurrent goroutines overlap in the singleflight window - time.Sleep(50 * time.Millisecond) + <-f.proceed return f.result, nil }