diff --git a/pkg/auth/upstreamswap/middleware.go b/pkg/auth/upstreamswap/middleware.go index 5af02e099c..ae849af1f3 100644 --- a/pkg/auth/upstreamswap/middleware.go +++ b/pkg/auth/upstreamswap/middleware.go @@ -6,6 +6,7 @@ package upstreamswap import ( + "context" "encoding/json" "errors" "fmt" @@ -13,6 +14,8 @@ import ( "net/http" "time" + "golang.org/x/sync/singleflight" + "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/authserver/server/session" "github.com/stacklok/toolhive/pkg/authserver/storage" @@ -48,6 +51,10 @@ type MiddlewareParams struct { // This allows lazy access to the storage, which may not be available at middleware creation time. type StorageGetter func() storage.UpstreamTokenStorage +// RefresherGetter is a function that returns an upstream token refresher. +// This allows lazy access to the refresher, which may not be available at middleware creation time. +type RefresherGetter func() storage.UpstreamTokenRefresher + // Middleware wraps the upstream swap middleware functionality. type Middleware struct { middleware types.MiddlewareFunction @@ -81,12 +88,13 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun return fmt.Errorf("invalid upstream swap configuration: %w", err) } - // Get storage getter from runner. - // The storage getter is a lazy accessor that checks storage availability at request time, - // so it's always non-nil. Actual storage availability is verified when processing requests. + // Get storage getter and refresher getter from runner. + // These are lazy accessors that check availability at request time, + // so they're always non-nil. Actual availability is verified when processing requests. storageGetter := runner.GetUpstreamTokenStorage() + refresherGetter := runner.GetUpstreamTokenRefresher() - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter) upstreamSwapMw := &Middleware{ middleware: middleware, @@ -141,13 +149,18 @@ func createCustomInjector(headerName string) injectionFunc { } // createMiddlewareFunc creates the actual middleware function. -func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.MiddlewareFunction { +func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter, refresherGetter RefresherGetter) types.MiddlewareFunction { // Determine injection strategy at startup time strategy := cfg.HeaderStrategy if strategy == "" { strategy = HeaderStrategyReplace } + // Deduplicate concurrent upstream token refresh attempts for the same session. + // Providers that rotate refresh tokens (single-use) would fail all but the + // first concurrent caller without this. + var sfGroup singleflight.Group + var injectToken injectionFunc switch strategy { case HeaderStrategyReplace: @@ -188,13 +201,13 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle return } - // 4. Lookup upstream tokens - tokens, err := stor.GetUpstreamTokens(r.Context(), tsid) + // 4. Lookup upstream tokens, refreshing if expired + tokens, err := getOrRefreshUpstreamTokens(r.Context(), &sfGroup, stor, tsid, refresherGetter) if err != nil { slog.Warn("Failed to get upstream tokens", "middleware", "upstreamswap", "error", err) - // Token is expired, was not found, or failed binding validation - // (e.g., subject/client mismatch). All three are client-attributable + // Token is expired (refresh failed), was not found, or failed binding + // validation (e.g., subject/client mismatch). All three are client-attributable // errors that require the caller to re-authenticate with the upstream IdP. if errors.Is(err, storage.ErrExpired) || errors.Is(err, storage.ErrNotFound) || @@ -207,17 +220,7 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle return } - // 5. Check if expired - // Defense in depth: some storage implementations may return tokens - // without checking expiry (the interface does not require it). - if tokens.IsExpired(time.Now()) { - slog.Warn("Upstream tokens expired", - "middleware", "upstreamswap") - writeUpstreamAuthRequired(w) - return - } - - // 6. Inject access token + // 5. Inject access token if tokens.AccessToken == "" { slog.Warn("Access token is empty", "middleware", "upstreamswap") @@ -233,3 +236,100 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle }) } } + +// getOrRefreshUpstreamTokens retrieves upstream tokens from storage, automatically +// refreshing them if expired and a refresh token is available. +func getOrRefreshUpstreamTokens( + ctx context.Context, + sfGroup *singleflight.Group, + stor storage.UpstreamTokenStorage, + sessionID string, + refresherGetter RefresherGetter, +) (*storage.UpstreamTokens, error) { + tokens, err := stor.GetUpstreamTokens(ctx, sessionID) + if err != nil { + // ErrExpired returns tokens (including refresh token) alongside the error. + // Attempt a refresh before giving up. + if errors.Is(err, storage.ErrExpired) && tokens != nil { + if refreshed := doSingleFlightRefresh(ctx, sfGroup, sessionID, tokens, refresherGetter); refreshed != nil { + return refreshed, nil + } + } + return nil, err + } + + // Defense in depth: some storage implementations may return tokens + // without checking expiry (the interface does not require it). + if !tokens.ExpiresAt.IsZero() && tokens.IsExpired(time.Now()) { + if refreshed := doSingleFlightRefresh(ctx, sfGroup, sessionID, tokens, refresherGetter); refreshed != nil { + return refreshed, nil + } + return nil, storage.ErrExpired + } + + return tokens, nil +} + +// doSingleFlightRefresh wraps tryRefreshUpstreamTokens in a singleflight.Group +// to deduplicate concurrent refresh attempts for the same session. +func doSingleFlightRefresh( + ctx context.Context, + sfGroup *singleflight.Group, + sessionID string, + expired *storage.UpstreamTokens, + refresherGetter RefresherGetter, +) *storage.UpstreamTokens { + result, err, _ := sfGroup.Do(sessionID, func() (any, error) { + // Detach from the triggering request's context so that if the first + // caller disconnects, the refresh still completes for waiting callers. + // The 30s timeout bounds the operation independently from client lifecycle. + refreshCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second) + defer cancel() + refreshed := tryRefreshUpstreamTokens(refreshCtx, sessionID, expired, refresherGetter) + if refreshed == nil { + return nil, errors.New("refresh failed") + } + return refreshed, nil + }) + if err != nil { + return nil + } + tokens, _ := result.(*storage.UpstreamTokens) + return tokens +} + +// tryRefreshUpstreamTokens attempts to refresh expired upstream tokens using the +// configured refresher. Returns the refreshed tokens on success, or nil on failure. +func tryRefreshUpstreamTokens( + ctx context.Context, + sessionID string, + expired *storage.UpstreamTokens, + refresherGetter RefresherGetter, +) *storage.UpstreamTokens { + if expired.RefreshToken == "" { + slog.Debug("No refresh token available, cannot refresh upstream tokens", + "middleware", "upstreamswap") + return nil + } + + if refresherGetter == nil { + return nil + } + refresher := refresherGetter() + if refresher == nil { + slog.Debug("Token refresher unavailable, cannot refresh upstream tokens", + "middleware", "upstreamswap") + return nil + } + + refreshed, err := refresher.RefreshAndStore(ctx, sessionID, expired) + if err != nil { + slog.Warn("Upstream token refresh failed", + "middleware", "upstreamswap", "error", err) + return nil + } + + slog.Debug("Successfully refreshed upstream tokens", + "middleware", "upstreamswap") + return refreshed +} diff --git a/pkg/auth/upstreamswap/middleware_test.go b/pkg/auth/upstreamswap/middleware_test.go index 85f66c31c4..13e2442959 100644 --- a/pkg/auth/upstreamswap/middleware_test.go +++ b/pkg/auth/upstreamswap/middleware_test.go @@ -9,6 +9,8 @@ import ( "errors" "net/http" "net/http/httptest" + "sync" + "sync/atomic" "testing" "time" @@ -99,7 +101,7 @@ func TestMiddleware_NoIdentity(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { @@ -131,7 +133,7 @@ func TestMiddleware_NoTsidClaim(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { @@ -166,7 +168,7 @@ func TestMiddleware_StorageUnavailable(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { @@ -222,7 +224,7 @@ func TestMiddleware_ClientAttributableStorageErrors_Returns401(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { @@ -268,7 +270,7 @@ func TestMiddleware_StorageError(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { @@ -319,7 +321,7 @@ func TestMiddleware_SuccessfulSwap_AccessToken(t *testing.T) { cfg := &Config{ HeaderStrategy: HeaderStrategyReplace, } - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var capturedAuthHeader string nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { @@ -369,7 +371,7 @@ func TestMiddleware_CustomHeader(t *testing.T) { HeaderStrategy: HeaderStrategyCustom, CustomHeaderName: "X-Upstream-Token", } - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var capturedCustomHeader string var capturedAuthHeader string @@ -422,7 +424,7 @@ func TestMiddleware_ExpiredTokens_Returns401(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { @@ -473,7 +475,7 @@ func TestMiddleware_EmptySelectedToken(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool var capturedAuthHeader string @@ -563,7 +565,7 @@ func TestMiddlewareWithContext(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) // Test that context is properly passed through var receivedCtx context.Context @@ -677,6 +679,9 @@ func TestCreateMiddleware(t *testing.T) { mockRunner.EXPECT().GetUpstreamTokenStorage().Return(func() storage.UpstreamTokenStorage { return nil // Storage availability is checked at request time }) + mockRunner.EXPECT().GetUpstreamTokenRefresher().Return(func() storage.UpstreamTokenRefresher { + return nil // Refresher availability is checked at request time + }) mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()).Do(func(_ string, mw types.Middleware) { _, ok := mw.(*Middleware) assert.True(t, ok, "Expected middleware to be of type *upstreamswap.Middleware") @@ -738,7 +743,7 @@ func TestMiddleware_TsidClaimWrongType(t *testing.T) { } cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + middleware := createMiddlewareFunc(cfg, storageGetter, nil) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { @@ -764,3 +769,395 @@ func TestMiddleware_TsidClaimWrongType(t *testing.T) { assert.True(t, nextCalled, "next handler should be called when tsid is wrong type") } + +func TestMiddleware_ExpiredTokens_RefreshSuccess(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + expiredTokens := &storage.UpstreamTokens{ + AccessToken: "expired-access-token", + RefreshToken: "my-refresh-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + + refreshedTokens := &storage.UpstreamTokens{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockStorage.EXPECT(). + GetUpstreamTokens(gomock.Any(), "session-123"). + Return(expiredTokens, storage.ErrExpired) + + mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) + mockRefresher.EXPECT(). + RefreshAndStore(gomock.Any(), "session-123", expiredTokens). + Return(refreshedTokens, nil) + + storageGetter := func() storage.UpstreamTokenStorage { + return mockStorage + } + refresherGetter := func() storage.UpstreamTokenRefresher { + return mockRefresher + } + + cfg := &Config{} + middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter) + + var nextCalled bool + var capturedAuthHeader string + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + nextCalled = true + capturedAuthHeader = r.Header.Get("Authorization") + }) + + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + identity := &auth.Identity{ + Subject: "user123", + Claims: map[string]any{ + "sub": "user123", + session.TokenSessionIDClaimKey: "session-123", + }, + } + ctx := auth.WithIdentity(req.Context(), identity) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.True(t, nextCalled, "next handler should be called after successful refresh") + assert.Equal(t, "Bearer new-access-token", capturedAuthHeader) +} + +func TestMiddleware_ExpiredTokens_RefreshFails(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + expiredTokens := &storage.UpstreamTokens{ + AccessToken: "expired-access-token", + RefreshToken: "my-refresh-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockStorage.EXPECT(). + GetUpstreamTokens(gomock.Any(), "session-123"). + Return(expiredTokens, storage.ErrExpired) + + mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) + mockRefresher.EXPECT(). + RefreshAndStore(gomock.Any(), "session-123", expiredTokens). + Return(nil, errors.New("refresh failed")) + + storageGetter := func() storage.UpstreamTokenStorage { + return mockStorage + } + refresherGetter := func() storage.UpstreamTokenRefresher { + return mockRefresher + } + + cfg := &Config{} + middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + identity := &auth.Identity{ + Subject: "user123", + Claims: map[string]any{ + "sub": "user123", + session.TokenSessionIDClaimKey: "session-123", + }, + } + ctx := auth.WithIdentity(req.Context(), identity) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.False(t, nextCalled, "next handler should NOT be called when refresh fails") + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) +} + +func TestMiddleware_ExpiredTokens_NoRefreshToken(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + expiredTokens := &storage.UpstreamTokens{ + AccessToken: "expired-access-token", + RefreshToken: "", // No refresh token + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockStorage.EXPECT(). + GetUpstreamTokens(gomock.Any(), "session-123"). + Return(expiredTokens, storage.ErrExpired) + + // No refresher mock needed — refresh should not be attempted + + storageGetter := func() storage.UpstreamTokenStorage { + return mockStorage + } + refresherGetter := func() storage.UpstreamTokenRefresher { + t.Fatal("refresher getter should not be called when there is no refresh token") + return nil + } + + cfg := &Config{} + middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + identity := &auth.Identity{ + Subject: "user123", + Claims: map[string]any{ + "sub": "user123", + session.TokenSessionIDClaimKey: "session-123", + }, + } + ctx := auth.WithIdentity(req.Context(), identity) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.False(t, nextCalled, "next handler should NOT be called when no refresh token available") + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) +} + +func TestMiddleware_DefenseInDepth_ExpiredButNoError_RefreshSuccess(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Storage returns tokens with ExpiresAt in the past but NO error + expiredTokens := &storage.UpstreamTokens{ + AccessToken: "expired-access-token", + RefreshToken: "my-refresh-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + + refreshedTokens := &storage.UpstreamTokens{ + AccessToken: "refreshed-access-token", + RefreshToken: "refreshed-refresh-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockStorage.EXPECT(). + GetUpstreamTokens(gomock.Any(), "session-123"). + Return(expiredTokens, nil) // No error, but token is expired + + mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) + mockRefresher.EXPECT(). + RefreshAndStore(gomock.Any(), "session-123", expiredTokens). + Return(refreshedTokens, nil) + + storageGetter := func() storage.UpstreamTokenStorage { + return mockStorage + } + refresherGetter := func() storage.UpstreamTokenRefresher { + return mockRefresher + } + + cfg := &Config{} + middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter) + + var nextCalled bool + var capturedAuthHeader string + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + nextCalled = true + capturedAuthHeader = r.Header.Get("Authorization") + }) + + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + identity := &auth.Identity{ + Subject: "user123", + Claims: map[string]any{ + "sub": "user123", + session.TokenSessionIDClaimKey: "session-123", + }, + } + ctx := auth.WithIdentity(req.Context(), identity) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.True(t, nextCalled, "next handler should be called after successful defense-in-depth refresh") + assert.Equal(t, "Bearer refreshed-access-token", capturedAuthHeader) +} + +// TestSingleFlightRefresh_ConcurrentRequests verifies that concurrent requests +// with the same expired session only trigger a single upstream refresh call. +// Without singleflight, providers that rotate refresh tokens (single-use) +// would fail all but the first concurrent caller. +func TestSingleFlightRefresh_ConcurrentRequests(t *testing.T) { + t.Parallel() + + const numRequests = 10 + var refreshCallCount atomic.Int32 + + expiredTokens := &storage.UpstreamTokens{ + AccessToken: "expired-access", + RefreshToken: "one-time-refresh", + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + + refreshedTokens := &storage.UpstreamTokens{ + AccessToken: "fresh-access", + RefreshToken: "new-refresh", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + // Use a barrier so all goroutines complete GetUpstreamTokens before any + // enters singleflight. This guarantees they all contend on the same key. + var storageBarrier sync.WaitGroup + storageBarrier.Add(numRequests) + storageGate := make(chan struct{}) + + stor := &fakeTokenStorage{ + tokens: expiredTokens, + err: storage.ErrExpired, + barrier: &storageBarrier, + gate: storageGate, + } + proceed := make(chan struct{}) + refresher := &fakeRefresher{ + result: refreshedTokens, + callCount: &refreshCallCount, + proceed: proceed, + } + + storageGetter := func() storage.UpstreamTokenStorage { return stor } + refresherGetter := func() storage.UpstreamTokenRefresher { return refresher } + + cfg := &Config{} + middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter) + + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}) + handler := middleware(nextHandler) + + // Use a barrier to ensure all goroutines start at the same time + ready := make(chan struct{}) + var wg sync.WaitGroup + results := make([]int, numRequests) + + for i := range numRequests { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-ready // Wait for all goroutines to be ready + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + identity := &auth.Identity{ + Subject: "user123", + Claims: map[string]any{ + "sub": "user123", + session.TokenSessionIDClaimKey: "sf-concurrent-session", + }, + } + ctx := auth.WithIdentity(req.Context(), identity) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + results[idx] = rr.Code + }(i) + } + + // Release all goroutines simultaneously + close(ready) + + // Wait for ALL goroutines to reach GetUpstreamTokens + storageBarrier.Wait() + + // Release them all at once — they all proceed to singleflight.Do concurrently + close(storageGate) + + // Give goroutines a moment to enter singleflight, then let refresh complete + // The singleflight ensures only one actually calls RefreshAndStore + time.Sleep(10 * time.Millisecond) + close(proceed) + + wg.Wait() + + // KEY ASSERTION: RefreshAndStore should be called exactly once. + // Without singleflight, all 10 goroutines would call it independently. + assert.Equal(t, int32(1), refreshCallCount.Load(), + "RefreshAndStore should be called exactly once — singleflight deduplicates concurrent refreshes") + + // All requests should succeed + for i, code := range results { + assert.Equal(t, http.StatusOK, code, + "request %d should succeed", i) + } +} + +// fakeTokenStorage returns configured tokens and optionally blocks until +// a barrier is released, ensuring all goroutines reach storage before any +// proceeds to the singleflight refresh. +type fakeTokenStorage struct { + tokens *storage.UpstreamTokens + err error + barrier *sync.WaitGroup // if set, each call does barrier.Done() then waits + gate chan struct{} // if set, blocks until closed +} + +func (f *fakeTokenStorage) GetUpstreamTokens(_ context.Context, _ string) (*storage.UpstreamTokens, error) { + if f.barrier != nil { + f.barrier.Done() + } + if f.gate != nil { + <-f.gate + } + return f.tokens, f.err +} + +func (*fakeTokenStorage) StoreUpstreamTokens(_ context.Context, _ string, _ *storage.UpstreamTokens) error { + return nil +} + +func (*fakeTokenStorage) DeleteUpstreamTokens(_ context.Context, _ string) error { + return nil +} + +// fakeRefresher counts calls and blocks until proceed is closed. +type fakeRefresher struct { + result *storage.UpstreamTokens + callCount *atomic.Int32 + proceed chan struct{} +} + +func (f *fakeRefresher) RefreshAndStore(_ context.Context, _ string, _ *storage.UpstreamTokens) (*storage.UpstreamTokens, error) { + f.callCount.Add(1) + <-f.proceed + return f.result, nil +} diff --git a/pkg/authserver/refresher.go b/pkg/authserver/refresher.go new file mode 100644 index 0000000000..709a2d7bcc --- /dev/null +++ b/pkg/authserver/refresher.go @@ -0,0 +1,86 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package authserver + +import ( + "context" + "errors" + "fmt" + "log/slog" + + "github.com/stacklok/toolhive/pkg/authserver/storage" + "github.com/stacklok/toolhive/pkg/authserver/upstream" +) + +// upstreamTokenRefresher implements storage.UpstreamTokenRefresher by wrapping +// an upstream OAuth2Provider (for token refresh) and UpstreamTokenStorage (for +// persisting the refreshed tokens). +type upstreamTokenRefresher struct { + provider upstream.OAuth2Provider + storage storage.UpstreamTokenStorage +} + +// Compile-time check that upstreamTokenRefresher implements storage.UpstreamTokenRefresher. +var _ storage.UpstreamTokenRefresher = (*upstreamTokenRefresher)(nil) + +// RefreshAndStore refreshes expired upstream tokens using the stored refresh token, +// persists the new tokens, and returns them. +func (r *upstreamTokenRefresher) RefreshAndStore( + ctx context.Context, + sessionID string, + expired *storage.UpstreamTokens, +) (*storage.UpstreamTokens, error) { + if expired == nil { + return nil, errors.New("expired tokens are required") + } + if expired.RefreshToken == "" { + return nil, errors.New("no refresh token available for upstream token refresh") + } + + slog.Debug("attempting upstream token refresh", + "session_id", sessionID, + "provider_id", expired.ProviderID, + ) + + // Refresh tokens via the upstream provider + newTokens, err := r.provider.RefreshTokens(ctx, expired.RefreshToken, expired.UpstreamSubject) + if err != nil { + return nil, fmt.Errorf("upstream token refresh failed: %w", err) + } + + // Build updated storage tokens preserving binding fields from the original + updated := &storage.UpstreamTokens{ + ProviderID: expired.ProviderID, + AccessToken: newTokens.AccessToken, + RefreshToken: newTokens.RefreshToken, + IDToken: newTokens.IDToken, + ExpiresAt: newTokens.ExpiresAt, + UserID: expired.UserID, + UpstreamSubject: expired.UpstreamSubject, + ClientID: expired.ClientID, + } + + // If the provider didn't rotate the refresh token, keep the original + if updated.RefreshToken == "" { + updated.RefreshToken = expired.RefreshToken + } + + // Store the refreshed tokens + if err := r.storage.StoreUpstreamTokens(ctx, sessionID, updated); err != nil { + // Log but still return the refreshed tokens — the current request can + // proceed even if storage fails. The next request will retry the refresh. + slog.Warn("failed to store refreshed upstream tokens", + "session_id", sessionID, + "error", err, + ) + return updated, nil + } + + slog.Debug("upstream tokens refreshed successfully", + "session_id", sessionID, + "provider_id", expired.ProviderID, + ) + + return updated, nil +} diff --git a/pkg/authserver/refresher_test.go b/pkg/authserver/refresher_test.go new file mode 100644 index 0000000000..23819a6f16 --- /dev/null +++ b/pkg/authserver/refresher_test.go @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package authserver + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/authserver/storage" + storagemocks "github.com/stacklok/toolhive/pkg/authserver/storage/mocks" + "github.com/stacklok/toolhive/pkg/authserver/upstream" + upstreammocks "github.com/stacklok/toolhive/pkg/authserver/upstream/mocks" +) + +func TestUpstreamTokenRefresher_RefreshAndStore(t *testing.T) { + t.Parallel() + + newExpiry := time.Now().Add(1 * time.Hour) + + baseExpired := &storage.UpstreamTokens{ + ProviderID: "github", + AccessToken: "old-access", + RefreshToken: "old-refresh", + IDToken: "old-id-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + UserID: "user-123", + UpstreamSubject: "upstream-sub-456", + ClientID: "client-abc", + } + + tests := []struct { + name string + sessionID string + expired *storage.UpstreamTokens + setupProvider func(*testing.T, *upstreammocks.MockOAuth2Provider) + setupStorage func(*testing.T, *storagemocks.MockUpstreamTokenStorage) + wantErr bool + wantErrContain string + checkResult func(*testing.T, *storage.UpstreamTokens) + }{ + { + name: "successful refresh with token rotation", + sessionID: "session-1", + expired: baseExpired, + setupProvider: func(_ *testing.T, p *upstreammocks.MockOAuth2Provider) { + p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). + Return(&upstream.Tokens{ + AccessToken: "new-access", + RefreshToken: "new-refresh", + IDToken: "new-id-token", + ExpiresAt: newExpiry, + }, nil) + }, + setupStorage: func(_ *testing.T, s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-1", gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, tokens *storage.UpstreamTokens) error { + // Verify binding fields are preserved from expired tokens + assert.Equal(t, "github", tokens.ProviderID) + assert.Equal(t, "user-123", tokens.UserID) + assert.Equal(t, "upstream-sub-456", tokens.UpstreamSubject) + assert.Equal(t, "client-abc", tokens.ClientID) + // Verify new token values + assert.Equal(t, "new-access", tokens.AccessToken) + assert.Equal(t, "new-refresh", tokens.RefreshToken) + assert.Equal(t, "new-id-token", tokens.IDToken) + assert.Equal(t, newExpiry, tokens.ExpiresAt) + return nil + }) + }, + checkResult: func(t *testing.T, result *storage.UpstreamTokens) { + t.Helper() + assert.Equal(t, "new-access", result.AccessToken) + assert.Equal(t, "new-refresh", result.RefreshToken) + assert.Equal(t, "new-id-token", result.IDToken) + assert.Equal(t, newExpiry, result.ExpiresAt) + // Binding fields preserved + assert.Equal(t, "github", result.ProviderID) + assert.Equal(t, "user-123", result.UserID) + assert.Equal(t, "upstream-sub-456", result.UpstreamSubject) + assert.Equal(t, "client-abc", result.ClientID) + }, + }, + { + name: "provider does not rotate refresh token - keeps old one", + sessionID: "session-2", + expired: baseExpired, + setupProvider: func(_ *testing.T, p *upstreammocks.MockOAuth2Provider) { + p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). + Return(&upstream.Tokens{ + AccessToken: "new-access", + RefreshToken: "", // Provider did not rotate + IDToken: "new-id-token", + ExpiresAt: newExpiry, + }, nil) + }, + setupStorage: func(_ *testing.T, s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-2", gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, tokens *storage.UpstreamTokens) error { + assert.Equal(t, "old-refresh", tokens.RefreshToken) + return nil + }) + }, + checkResult: func(t *testing.T, result *storage.UpstreamTokens) { + t.Helper() + assert.Equal(t, "new-access", result.AccessToken) + assert.Equal(t, "old-refresh", result.RefreshToken) + }, + }, + { + name: "nil expired tokens returns error", + sessionID: "session-3", + expired: nil, + setupProvider: func(_ *testing.T, _ *upstreammocks.MockOAuth2Provider) {}, + setupStorage: func(_ *testing.T, _ *storagemocks.MockUpstreamTokenStorage) {}, + wantErr: true, + wantErrContain: "expired tokens are required", + }, + { + name: "empty refresh token returns error", + sessionID: "session-4", + expired: &storage.UpstreamTokens{ + ProviderID: "github", + AccessToken: "old-access", + RefreshToken: "", + UserID: "user-123", + UpstreamSubject: "upstream-sub-456", + ClientID: "client-abc", + }, + setupProvider: func(_ *testing.T, _ *upstreammocks.MockOAuth2Provider) {}, + setupStorage: func(_ *testing.T, _ *storagemocks.MockUpstreamTokenStorage) {}, + wantErr: true, + wantErrContain: "no refresh token available", + }, + { + name: "provider refresh fails returns error", + sessionID: "session-5", + expired: baseExpired, + setupProvider: func(_ *testing.T, p *upstreammocks.MockOAuth2Provider) { + p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). + Return(nil, errors.New("upstream IDP unavailable")) + }, + setupStorage: func(_ *testing.T, _ *storagemocks.MockUpstreamTokenStorage) {}, + wantErr: true, + wantErrContain: "upstream token refresh failed", + }, + { + name: "storage fails after refresh - returns refreshed tokens anyway", + sessionID: "session-6", + expired: baseExpired, + setupProvider: func(_ *testing.T, p *upstreammocks.MockOAuth2Provider) { + p.EXPECT().RefreshTokens(gomock.Any(), "old-refresh", "upstream-sub-456"). + Return(&upstream.Tokens{ + AccessToken: "new-access", + RefreshToken: "new-refresh", + IDToken: "new-id-token", + ExpiresAt: newExpiry, + }, nil) + }, + setupStorage: func(_ *testing.T, s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().StoreUpstreamTokens(gomock.Any(), "session-6", gomock.Any()). + Return(errors.New("redis connection lost")) + }, + checkResult: func(t *testing.T, result *storage.UpstreamTokens) { + t.Helper() + // Tokens should still be returned despite storage failure + assert.Equal(t, "new-access", result.AccessToken) + assert.Equal(t, "new-refresh", result.RefreshToken) + assert.Equal(t, "new-id-token", result.IDToken) + assert.Equal(t, newExpiry, result.ExpiresAt) + // Binding fields preserved + assert.Equal(t, "github", result.ProviderID) + assert.Equal(t, "user-123", result.UserID) + assert.Equal(t, "upstream-sub-456", result.UpstreamSubject) + assert.Equal(t, "client-abc", result.ClientID) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + mockProvider := upstreammocks.NewMockOAuth2Provider(ctrl) + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + + tt.setupProvider(t, mockProvider) + tt.setupStorage(t, mockStorage) + + refresher := &upstreamTokenRefresher{ + provider: mockProvider, + storage: mockStorage, + } + + result, err := refresher.RefreshAndStore(context.Background(), tt.sessionID, tt.expired) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErrContain) + assert.Nil(t, result) + return + } + + require.NoError(t, err) + require.NotNil(t, result) + if tt.checkResult != nil { + tt.checkResult(t, result) + } + }) + } +} diff --git a/pkg/authserver/runner/embeddedauthserver.go b/pkg/authserver/runner/embeddedauthserver.go index 91abdc82cd..0643481c4e 100644 --- a/pkg/authserver/runner/embeddedauthserver.go +++ b/pkg/authserver/runner/embeddedauthserver.go @@ -135,6 +135,12 @@ func (e *EmbeddedAuthServer) IDPTokenStorage() storage.UpstreamTokenStorage { return e.server.IDPTokenStorage() } +// UpstreamTokenRefresher returns a refresher that can refresh expired upstream +// tokens using the upstream provider's refresh token grant. +func (e *EmbeddedAuthServer) UpstreamTokenRefresher() storage.UpstreamTokenRefresher { + return e.server.UpstreamTokenRefresher() +} + // createKeyProvider creates a KeyProvider from SigningKeyRunConfig. // Returns a GeneratingProvider if config is nil or empty (development mode). func createKeyProvider(cfg *authserver.SigningKeyRunConfig) (keys.KeyProvider, error) { diff --git a/pkg/authserver/server.go b/pkg/authserver/server.go index 1c8286ab6a..aef969fec8 100644 --- a/pkg/authserver/server.go +++ b/pkg/authserver/server.go @@ -31,6 +31,11 @@ type Server interface { // Returns nil if no upstream IDP is configured. IDPTokenStorage() storage.UpstreamTokenStorage + // UpstreamTokenRefresher returns a refresher that can refresh expired upstream + // tokens using the upstream provider's refresh token grant. + // Returns nil if no upstream IDP is configured. + UpstreamTokenRefresher() storage.UpstreamTokenRefresher + // Close releases resources held by the server. Close() error } diff --git a/pkg/authserver/server_impl.go b/pkg/authserver/server_impl.go index 92276e05e6..1438d9c71b 100644 --- a/pkg/authserver/server_impl.go +++ b/pkg/authserver/server_impl.go @@ -161,6 +161,18 @@ func (s *server) IDPTokenStorage() storage.UpstreamTokenStorage { return s.storage } +// UpstreamTokenRefresher returns a refresher that wraps the upstream provider +// and storage to transparently refresh expired upstream tokens. +func (s *server) UpstreamTokenRefresher() storage.UpstreamTokenRefresher { + if s.upstreamIDP == nil { + return nil + } + return &upstreamTokenRefresher{ + provider: s.upstreamIDP, + storage: s.storage, + } +} + // Close releases resources held by the server. func (s *server) Close() error { slog.Debug("closing OAuth authorization server") diff --git a/pkg/authserver/storage/memory.go b/pkg/authserver/storage/memory.go index 29402eb069..3423d0e81a 100644 --- a/pkg/authserver/storage/memory.go +++ b/pkg/authserver/storage/memory.go @@ -693,11 +693,13 @@ func (s *MemoryStorage) StoreUpstreamTokens(_ context.Context, sessionID string, defer s.mu.Unlock() now := time.Now() + // Add DefaultRefreshTokenTTL beyond access token expiry so the refresh token + // survives in storage for transparent token refresh by the middleware. var expiresAt time.Time if tokens != nil && !tokens.ExpiresAt.IsZero() { - expiresAt = tokens.ExpiresAt + expiresAt = tokens.ExpiresAt.Add(DefaultRefreshTokenTTL) } else { - expiresAt = now.Add(DefaultAccessTokenTTL) + expiresAt = now.Add(DefaultAccessTokenTTL + DefaultRefreshTokenTTL) } // Make a defensive copy to prevent aliasing issues @@ -735,18 +737,12 @@ func (s *MemoryStorage) GetUpstreamTokens(_ context.Context, sessionID string) ( return nil, fmt.Errorf("%w: %w", ErrNotFound, fosite.ErrNotFound.WithHint("Upstream tokens not found")) } - // Check if expired - if time.Now().After(entry.expiresAt) { - slog.Debug("upstream tokens expired", "session_id", sessionID) - return nil, ErrExpired - } - // Return a defensive copy to prevent aliasing issues tokens := entry.value if tokens == nil { return nil, nil } - return &UpstreamTokens{ + result := &UpstreamTokens{ ProviderID: tokens.ProviderID, AccessToken: tokens.AccessToken, RefreshToken: tokens.RefreshToken, @@ -755,7 +751,17 @@ func (s *MemoryStorage) GetUpstreamTokens(_ context.Context, sessionID string) ( UserID: tokens.UserID, UpstreamSubject: tokens.UpstreamSubject, ClientID: tokens.ClientID, - }, nil + } + + // Check the token's own ExpiresAt (access token expiry), not the entry's expiresAt + // (storage TTL which includes DefaultRefreshTokenTTL buffer for refresh token survival). + // Return tokens along with ErrExpired so callers can use the refresh token. + if !result.ExpiresAt.IsZero() && time.Now().After(result.ExpiresAt) { + slog.Debug("upstream tokens expired", "session_id", sessionID) + return result, ErrExpired + } + + return result, nil } // DeleteUpstreamTokens removes the upstream IDP tokens for a session. diff --git a/pkg/authserver/storage/memory_test.go b/pkg/authserver/storage/memory_test.go index 5546931a14..85cc839f33 100644 --- a/pkg/authserver/storage/memory_test.go +++ b/pkg/authserver/storage/memory_test.go @@ -473,17 +473,22 @@ func TestMemoryStorage_UpstreamTokens(t *testing.T) { }) }) - t.Run("get expired tokens returns ErrExpired", func(t *testing.T) { + t.Run("get expired tokens returns ErrExpired with token data", func(t *testing.T) { withStorage(t, func(ctx context.Context, s *MemoryStorage) { require.NoError(t, s.StoreUpstreamTokens(ctx, "expired", &UpstreamTokens{ - AccessToken: "expired-token", ExpiresAt: time.Now().Add(-time.Hour), + AccessToken: "expired-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Hour), })) assert.Equal(t, 1, s.Stats().UpstreamTokens) retrieved, err := s.GetUpstreamTokens(ctx, "expired") require.Error(t, err) assert.ErrorIs(t, err, ErrExpired) - assert.Nil(t, retrieved) + // Tokens should be returned alongside ErrExpired for refresh purposes + require.NotNil(t, retrieved) + assert.Equal(t, "expired-token", retrieved.AccessToken) + assert.Equal(t, "refresh-token", retrieved.RefreshToken) }) }) } @@ -637,7 +642,8 @@ func TestMemoryStorage_CleanupExpired(t *testing.T) { { name: "upstream tokens", setup: func(ctx context.Context, s *MemoryStorage) { - _ = s.StoreUpstreamTokens(ctx, "expired", &UpstreamTokens{AccessToken: "exp", ExpiresAt: time.Now().Add(-time.Hour)}) + // Entry must be older than DefaultRefreshTokenTTL past access token expiry to be cleaned up + _ = s.StoreUpstreamTokens(ctx, "expired", &UpstreamTokens{AccessToken: "exp", ExpiresAt: time.Now().Add(-DefaultRefreshTokenTTL - time.Hour)}) _ = s.StoreUpstreamTokens(ctx, "valid", &UpstreamTokens{AccessToken: "val", ExpiresAt: time.Now().Add(time.Hour)}) }, getStats: func(st Stats) int { return st.UpstreamTokens }, diff --git a/pkg/authserver/storage/mocks/mock_storage.go b/pkg/authserver/storage/mocks/mock_storage.go index 42d41949f8..278a1be587 100644 --- a/pkg/authserver/storage/mocks/mock_storage.go +++ b/pkg/authserver/storage/mocks/mock_storage.go @@ -234,6 +234,45 @@ func (mr *MockUpstreamTokenStorageMockRecorder) StoreUpstreamTokens(ctx, session return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StoreUpstreamTokens", reflect.TypeOf((*MockUpstreamTokenStorage)(nil).StoreUpstreamTokens), ctx, sessionID, tokens) } +// MockUpstreamTokenRefresher is a mock of UpstreamTokenRefresher interface. +type MockUpstreamTokenRefresher struct { + ctrl *gomock.Controller + recorder *MockUpstreamTokenRefresherMockRecorder + isgomock struct{} +} + +// MockUpstreamTokenRefresherMockRecorder is the mock recorder for MockUpstreamTokenRefresher. +type MockUpstreamTokenRefresherMockRecorder struct { + mock *MockUpstreamTokenRefresher +} + +// NewMockUpstreamTokenRefresher creates a new mock instance. +func NewMockUpstreamTokenRefresher(ctrl *gomock.Controller) *MockUpstreamTokenRefresher { + mock := &MockUpstreamTokenRefresher{ctrl: ctrl} + mock.recorder = &MockUpstreamTokenRefresherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUpstreamTokenRefresher) EXPECT() *MockUpstreamTokenRefresherMockRecorder { + return m.recorder +} + +// RefreshAndStore mocks base method. +func (m *MockUpstreamTokenRefresher) RefreshAndStore(ctx context.Context, sessionID string, expired *storage.UpstreamTokens) (*storage.UpstreamTokens, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RefreshAndStore", ctx, sessionID, expired) + ret0, _ := ret[0].(*storage.UpstreamTokens) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RefreshAndStore indicates an expected call of RefreshAndStore. +func (mr *MockUpstreamTokenRefresherMockRecorder) RefreshAndStore(ctx, sessionID, expired any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshAndStore", reflect.TypeOf((*MockUpstreamTokenRefresher)(nil).RefreshAndStore), ctx, sessionID, expired) +} + // MockUserStorage is a mock of UserStorage interface. type MockUserStorage struct { ctrl *gomock.Controller diff --git a/pkg/authserver/storage/redis.go b/pkg/authserver/storage/redis.go index 01b35fd106..08e87ef835 100644 --- a/pkg/authserver/storage/redis.go +++ b/pkg/authserver/storage/redis.go @@ -765,11 +765,13 @@ func marshalUpstreamTokensWithTTL(tokens *UpstreamTokens) ([]byte, time.Duration return nil, 0, fmt.Errorf("failed to marshal upstream tokens: %w", err) } - ttl := DefaultAccessTokenTTL + // Add DefaultRefreshTokenTTL beyond access token expiry so the refresh token + // survives in storage for transparent token refresh by the middleware. + ttl := DefaultAccessTokenTTL + DefaultRefreshTokenTTL if !tokens.ExpiresAt.IsZero() { - ttl = time.Until(tokens.ExpiresAt) + ttl = time.Until(tokens.ExpiresAt) + DefaultRefreshTokenTTL if ttl < 0 { - ttl = DefaultAccessTokenTTL + ttl = DefaultRefreshTokenTTL } } @@ -841,14 +843,13 @@ func (s *RedisStorage) GetUpstreamTokens(ctx context.Context, sessionID string) // stores 0 for zero time. Skip the expiry check in this case since Redis TTL // handles the actual expiration. var expiresAt time.Time + var expired bool if stored.ExpiresAt != 0 { expiresAt = time.Unix(stored.ExpiresAt, 0) - if time.Now().After(expiresAt) { - return nil, ErrExpired - } + expired = time.Now().After(expiresAt) } - return &UpstreamTokens{ + tokens := &UpstreamTokens{ ProviderID: stored.ProviderID, AccessToken: stored.AccessToken, RefreshToken: stored.RefreshToken, @@ -857,7 +858,14 @@ func (s *RedisStorage) GetUpstreamTokens(ctx context.Context, sessionID string) UserID: stored.UserID, UpstreamSubject: stored.UpstreamSubject, ClientID: stored.ClientID, - }, nil + } + + // Return tokens along with ErrExpired so callers can use the refresh token + if expired { + return tokens, ErrExpired + } + + return tokens, nil } // DeleteUpstreamTokens removes the upstream IDP tokens for a session. diff --git a/pkg/authserver/storage/redis_integration_test.go b/pkg/authserver/storage/redis_integration_test.go index 86c8f2474e..aa955acef3 100644 --- a/pkg/authserver/storage/redis_integration_test.go +++ b/pkg/authserver/storage/redis_integration_test.go @@ -659,13 +659,19 @@ func TestIntegration_UpstreamTokens(t *testing.T) { }) }) - t.Run("expired tokens return ErrExpired", func(t *testing.T) { + t.Run("expired tokens return ErrExpired with token data", func(t *testing.T) { withIntegrationStorage(t, func(ctx context.Context, s *RedisStorage) { require.NoError(t, s.StoreUpstreamTokens(ctx, "sess-exp", &UpstreamTokens{ - AccessToken: "expired", ExpiresAt: time.Now().Add(-time.Hour), + AccessToken: "expired-token", + RefreshToken: "expired-refresh", + ExpiresAt: time.Now().Add(-time.Hour), })) - _, err := s.GetUpstreamTokens(ctx, "sess-exp") + tokens, err := s.GetUpstreamTokens(ctx, "sess-exp") assert.ErrorIs(t, err, ErrExpired) + // Expired tokens should still return the data (needed for refresh) + require.NotNil(t, tokens, "expired tokens should return data for refresh") + assert.Equal(t, "expired-token", tokens.AccessToken) + assert.Equal(t, "expired-refresh", tokens.RefreshToken) }) }) diff --git a/pkg/authserver/storage/redis_test.go b/pkg/authserver/storage/redis_test.go index db0910c9fe..92f2639b5b 100644 --- a/pkg/authserver/storage/redis_test.go +++ b/pkg/authserver/storage/redis_test.go @@ -622,19 +622,24 @@ func TestRedisStorage_UpstreamTokens(t *testing.T) { }) }) - t.Run("get expired tokens returns ErrExpired", func(t *testing.T) { + t.Run("get expired tokens returns ErrExpired with token data", func(t *testing.T) { withRedisStorage(t, func(ctx context.Context, s *RedisStorage, _ *miniredis.Miniredis) { // Store with an ExpiresAt that's already in the past. - // The TTL will be set to DefaultAccessTokenTTL, but the stored - // ExpiresAt will be checked and return ErrExpired. + // The TTL includes DefaultRefreshTokenTTL so the key survives + // past access token expiry, allowing refresh token retrieval. require.NoError(t, s.StoreUpstreamTokens(ctx, "expired", &UpstreamTokens{ - AccessToken: "expired-token", ExpiresAt: time.Now().Add(-time.Hour), + AccessToken: "expired-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Hour), })) retrieved, err := s.GetUpstreamTokens(ctx, "expired") require.Error(t, err) assert.ErrorIs(t, err, ErrExpired) - assert.Nil(t, retrieved) + // Tokens should be returned alongside ErrExpired for refresh purposes + require.NotNil(t, retrieved) + assert.Equal(t, "expired-token", retrieved.AccessToken) + assert.Equal(t, "refresh-token", retrieved.RefreshToken) }) }) diff --git a/pkg/authserver/storage/types.go b/pkg/authserver/storage/types.go index 308b4d8841..f6c7c89316 100644 --- a/pkg/authserver/storage/types.go +++ b/pkg/authserver/storage/types.go @@ -201,7 +201,9 @@ type UpstreamTokenStorage interface { // GetUpstreamTokens retrieves the upstream IDP tokens for a session. // Returns ErrNotFound if the session does not exist. - // Returns ErrExpired if the tokens have expired. + // Returns ErrExpired if the tokens have expired. When ErrExpired is returned, + // the token data (including refresh token) is also returned to allow callers + // to attempt a token refresh. // Returns ErrInvalidBinding if binding validation fails. GetUpstreamTokens(ctx context.Context, sessionID string) (*UpstreamTokens, error) @@ -210,6 +212,16 @@ type UpstreamTokenStorage interface { DeleteUpstreamTokens(ctx context.Context, sessionID string) error } +// UpstreamTokenRefresher can refresh expired upstream tokens using their stored refresh token. +// This is implemented by the auth server and used by the upstreamswap middleware to +// transparently refresh tokens without forcing re-authentication. +type UpstreamTokenRefresher interface { + // RefreshAndStore refreshes the upstream tokens for the given session using + // the stored refresh token, stores the new tokens, and returns them. + // Returns an error if the refresh token is empty, revoked, or the refresh fails. + RefreshAndStore(ctx context.Context, sessionID string, expired *UpstreamTokens) (*UpstreamTokens, error) +} + // UserStorage provides user and provider identity management operations. // This interface supports multi-IDP scenarios where a single user can authenticate // via multiple upstream identity providers (e.g., Google and GitHub). diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 4b2ee6c859..08a87d1690 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -145,6 +145,18 @@ func (r *Runner) GetUpstreamTokenStorage() func() storage.UpstreamTokenStorage { } } +// GetUpstreamTokenRefresher returns a lazy accessor for the upstream token refresher. +// The returned function should be called at request time; it returns nil if +// the embedded auth server is not configured. +func (r *Runner) GetUpstreamTokenRefresher() func() storage.UpstreamTokenRefresher { + return func() storage.UpstreamTokenRefresher { + if r.embeddedAuthServer == nil { + return nil + } + return r.embeddedAuthServer.UpstreamTokenRefresher() + } +} + // GetName returns the name of the mcp-service from the runner config (implements types.RunnerConfig) func (c *RunConfig) GetName() string { return c.Name diff --git a/pkg/transport/types/mocks/mock_transport.go b/pkg/transport/types/mocks/mock_transport.go index 87ce9602fb..20ad19fb31 100644 --- a/pkg/transport/types/mocks/mock_transport.go +++ b/pkg/transport/types/mocks/mock_transport.go @@ -123,6 +123,20 @@ func (mr *MockMiddlewareRunnerMockRecorder) GetConfig() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfig", reflect.TypeOf((*MockMiddlewareRunner)(nil).GetConfig)) } +// GetUpstreamTokenRefresher mocks base method. +func (m *MockMiddlewareRunner) GetUpstreamTokenRefresher() func() storage.UpstreamTokenRefresher { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUpstreamTokenRefresher") + ret0, _ := ret[0].(func() storage.UpstreamTokenRefresher) + return ret0 +} + +// GetUpstreamTokenRefresher indicates an expected call of GetUpstreamTokenRefresher. +func (mr *MockMiddlewareRunnerMockRecorder) GetUpstreamTokenRefresher() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpstreamTokenRefresher", reflect.TypeOf((*MockMiddlewareRunner)(nil).GetUpstreamTokenRefresher)) +} + // GetUpstreamTokenStorage mocks base method. func (m *MockMiddlewareRunner) GetUpstreamTokenStorage() func() storage.UpstreamTokenStorage { m.ctrl.T.Helper() diff --git a/pkg/transport/types/transport.go b/pkg/transport/types/transport.go index 67e8b93e5d..8706aa09f2 100644 --- a/pkg/transport/types/transport.go +++ b/pkg/transport/types/transport.go @@ -86,6 +86,11 @@ type MiddlewareRunner interface { // before the embedded auth server is initialized. Storage availability is // determined at request time when the returned function is called. GetUpstreamTokenStorage() func() storage.UpstreamTokenStorage + + // GetUpstreamTokenRefresher returns a lazy accessor for the upstream token refresher. + // The returned function should be called at request time; it returns nil if + // the embedded auth server is not configured or the refresher is unavailable. + GetUpstreamTokenRefresher() func() storage.UpstreamTokenRefresher } // RunnerConfig defines the config interface needed by middleware to access runner configuration