Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 120 additions & 20 deletions pkg/auth/upstreamswap/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
package upstreamswap

import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"time"

"golang.org/x/sync/singleflight"

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/authserver/server/session"
"github.com/stacklok/toolhive/pkg/authserver/storage"
Expand Down Expand Up @@ -48,6 +51,10 @@ type MiddlewareParams struct {
// This allows lazy access to the storage, which may not be available at middleware creation time.
type StorageGetter func() storage.UpstreamTokenStorage

// RefresherGetter is a function that returns an upstream token refresher.
// This allows lazy access to the refresher, which may not be available at middleware creation time.
type RefresherGetter func() storage.UpstreamTokenRefresher

// Middleware wraps the upstream swap middleware functionality.
type Middleware struct {
middleware types.MiddlewareFunction
Expand Down Expand Up @@ -81,12 +88,13 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
return fmt.Errorf("invalid upstream swap configuration: %w", err)
}

// Get storage getter from runner.
// The storage getter is a lazy accessor that checks storage availability at request time,
// so it's always non-nil. Actual storage availability is verified when processing requests.
// Get storage getter and refresher getter from runner.
// These are lazy accessors that check availability at request time,
// so they're always non-nil. Actual availability is verified when processing requests.
storageGetter := runner.GetUpstreamTokenStorage()
refresherGetter := runner.GetUpstreamTokenRefresher()

middleware := createMiddlewareFunc(cfg, storageGetter)
middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter)

upstreamSwapMw := &Middleware{
middleware: middleware,
Expand Down Expand Up @@ -141,13 +149,18 @@ func createCustomInjector(headerName string) injectionFunc {
}

// createMiddlewareFunc creates the actual middleware function.
func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.MiddlewareFunction {
func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter, refresherGetter RefresherGetter) types.MiddlewareFunction {
// Determine injection strategy at startup time
strategy := cfg.HeaderStrategy
if strategy == "" {
strategy = HeaderStrategyReplace
}

// Deduplicate concurrent upstream token refresh attempts for the same session.
// Providers that rotate refresh tokens (single-use) would fail all but the
// first concurrent caller without this.
var sfGroup singleflight.Group

var injectToken injectionFunc
switch strategy {
case HeaderStrategyReplace:
Expand Down Expand Up @@ -188,13 +201,13 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle
return
}

// 4. Lookup upstream tokens
tokens, err := stor.GetUpstreamTokens(r.Context(), tsid)
// 4. Lookup upstream tokens, refreshing if expired
tokens, err := getOrRefreshUpstreamTokens(r.Context(), &sfGroup, stor, tsid, refresherGetter)
if err != nil {
slog.Warn("Failed to get upstream tokens",
"middleware", "upstreamswap", "error", err)
// Token is expired, was not found, or failed binding validation
// (e.g., subject/client mismatch). All three are client-attributable
// Token is expired (refresh failed), was not found, or failed binding
// validation (e.g., subject/client mismatch). All three are client-attributable
// errors that require the caller to re-authenticate with the upstream IdP.
if errors.Is(err, storage.ErrExpired) ||
errors.Is(err, storage.ErrNotFound) ||
Expand All @@ -207,17 +220,7 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle
return
}

// 5. Check if expired
// Defense in depth: some storage implementations may return tokens
// without checking expiry (the interface does not require it).
if tokens.IsExpired(time.Now()) {
slog.Warn("Upstream tokens expired",
"middleware", "upstreamswap")
writeUpstreamAuthRequired(w)
return
}

// 6. Inject access token
// 5. Inject access token
if tokens.AccessToken == "" {
slog.Warn("Access token is empty",
"middleware", "upstreamswap")
Expand All @@ -233,3 +236,100 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle
})
}
}

// getOrRefreshUpstreamTokens retrieves upstream tokens from storage, automatically
// refreshing them if expired and a refresh token is available.
func getOrRefreshUpstreamTokens(
ctx context.Context,
sfGroup *singleflight.Group,
stor storage.UpstreamTokenStorage,
sessionID string,
refresherGetter RefresherGetter,
) (*storage.UpstreamTokens, error) {
tokens, err := stor.GetUpstreamTokens(ctx, sessionID)
if err != nil {
// ErrExpired returns tokens (including refresh token) alongside the error.
// Attempt a refresh before giving up.
if errors.Is(err, storage.ErrExpired) && tokens != nil {
if refreshed := doSingleFlightRefresh(ctx, sfGroup, sessionID, tokens, refresherGetter); refreshed != nil {
return refreshed, nil
}
}
return nil, err
}

// Defense in depth: some storage implementations may return tokens
// without checking expiry (the interface does not require it).
if !tokens.ExpiresAt.IsZero() && tokens.IsExpired(time.Now()) {
if refreshed := doSingleFlightRefresh(ctx, sfGroup, sessionID, tokens, refresherGetter); refreshed != nil {
return refreshed, nil
}
return nil, storage.ErrExpired
}

return tokens, nil
}

// doSingleFlightRefresh wraps tryRefreshUpstreamTokens in a singleflight.Group
// to deduplicate concurrent refresh attempts for the same session.
func doSingleFlightRefresh(
ctx context.Context,
sfGroup *singleflight.Group,
sessionID string,
expired *storage.UpstreamTokens,
refresherGetter RefresherGetter,
) *storage.UpstreamTokens {
result, err, _ := sfGroup.Do(sessionID, func() (any, error) {
// Detach from the triggering request's context so that if the first
// caller disconnects, the refresh still completes for waiting callers.
// The 30s timeout bounds the operation independently from client lifecycle.
refreshCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
defer cancel()
refreshed := tryRefreshUpstreamTokens(refreshCtx, sessionID, expired, refresherGetter)
if refreshed == nil {
return nil, errors.New("refresh failed")
}
return refreshed, nil
})
if err != nil {
return nil
}
tokens, _ := result.(*storage.UpstreamTokens)
return tokens
}

// tryRefreshUpstreamTokens attempts to refresh expired upstream tokens using the
// configured refresher. Returns the refreshed tokens on success, or nil on failure.
func tryRefreshUpstreamTokens(
ctx context.Context,
sessionID string,
expired *storage.UpstreamTokens,
refresherGetter RefresherGetter,
) *storage.UpstreamTokens {
if expired.RefreshToken == "" {
slog.Debug("No refresh token available, cannot refresh upstream tokens",
"middleware", "upstreamswap")
return nil
}

if refresherGetter == nil {
return nil
}
refresher := refresherGetter()
if refresher == nil {
slog.Debug("Token refresher unavailable, cannot refresh upstream tokens",
"middleware", "upstreamswap")
return nil
}

refreshed, err := refresher.RefreshAndStore(ctx, sessionID, expired)
if err != nil {
slog.Warn("Upstream token refresh failed",
"middleware", "upstreamswap", "error", err)
return nil
}

slog.Debug("Successfully refreshed upstream tokens",
"middleware", "upstreamswap")
return refreshed
}
Loading
Loading