Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 24 additions & 55 deletions pkg/auth/tokenexchange/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,13 @@ import (
"github.com/stacklok/toolhive/pkg/oauthproto"
)

const (
// defaultHTTPTimeout is the timeout for HTTP requests
defaultHTTPTimeout = 30 * time.Second

// maxResponseBodySize is the maximum size for reading response bodies (1 MB)
maxResponseBodySize = 1 << 20

// redactedPlaceholder is used to redact sensitive values in string representations
redactedPlaceholder = "[REDACTED]"

// emptyPlaceholder is used to indicate empty/missing values in string representations
emptyPlaceholder = "<empty>"
)
// maxResponseBodySize bounds io.LimitReader in executeTokenExchangeRequest so
// a pathological server cannot exhaust memory. The shared pkg/oauthproto
// package has an identical unexported constant, but we cannot import it yet —
// the shared one is consumed by oauthproto.DoTokenRequest, which will replace
// executeTokenExchangeRequest in a follow-up commit.
// TODO: drop when executeTokenExchangeRequest is replaced by oauthproto.DoTokenRequest.
const maxResponseBodySize = 1 << 20

// NormalizeTokenType converts a short token type name to its full URN.
// Accepts both short forms ("access_token", "id_token", "jwt") and full URNs.
Expand Down Expand Up @@ -66,11 +60,6 @@ func NormalizeTokenType(tokenType string) (string, error) {
}
}

// defaultHTTPClient is the default HTTP client used for token exchange requests.
var defaultHTTPClient = &http.Client{
Timeout: defaultHTTPTimeout,
}

// actingParty represents the acting party in a token exchange delegation scenario.
// When present, it indicates that the actor token holder is acting on behalf of the subject token holder.
type actingParty struct {
Expand All @@ -96,64 +85,41 @@ type exchangeRequest struct {

// String implements fmt.Stringer for exchangeRequest, redacting sensitive tokens.
func (r exchangeRequest) String() string {
subjectToken := redactedPlaceholder
if r.SubjectToken == "" {
subjectToken = emptyPlaceholder
}

actorToken := "<none>"
if r.ActingParty != nil {
actorToken = redactedPlaceholder
if r.ActingParty.ActorToken == "" {
actorToken = emptyPlaceholder
}
actorToken = oauthproto.Redact(r.ActingParty.ActorToken)
}

return fmt.Sprintf("exchangeRequest{GrantType: %s, Audience: %s, Resource: %s, Scope: %v, SubjectToken: %s, ActorToken: %s}",
r.GrantType, r.Audience, r.Resource, r.Scope, subjectToken, actorToken)
r.GrantType, r.Audience, r.Resource, r.Scope, oauthproto.Redact(r.SubjectToken), actorToken)
}

// response is used to decode the remote server response during an OAuth 2.0 token exchange.
type response struct {
AccessToken string `json:"access_token"` //nolint:gosec // G117: field legitimately holds sensitive data
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"` //nolint:gosec // G117: field legitimately holds sensitive data
RefreshToken string `json:"refresh_token"`
}

// String implements fmt.Stringer for response, redacting sensitive tokens.
func (r response) String() string {
accessToken := redactedPlaceholder
if r.AccessToken == "" {
accessToken = emptyPlaceholder
}

refreshToken := redactedPlaceholder
if r.RefreshToken == "" {
refreshToken = emptyPlaceholder
}

return fmt.Sprintf("response{AccessToken: %s, TokenType: %s, ExpiresIn: %d, RefreshToken: %s}",
accessToken, r.TokenType, r.ExpiresIn, refreshToken)
oauthproto.Redact(r.AccessToken), r.TokenType, r.ExpiresIn, oauthproto.Redact(r.RefreshToken))
}

// clientAuthentication represents OAuth client credentials for token exchange.
type clientAuthentication struct {
ClientID string
ClientSecret string //nolint:gosec // G117
ClientSecret string
}

// String implements fmt.Stringer for clientAuthentication, redacting the client secret.
func (c clientAuthentication) String() string {
clientSecret := redactedPlaceholder
if c.ClientSecret == "" {
clientSecret = emptyPlaceholder
}

return fmt.Sprintf("clientAuthentication{ClientID: %s, ClientSecret: %s}",
c.ClientID, clientSecret)
c.ClientID, oauthproto.Redact(c.ClientSecret))
}

// ExchangeConfig holds the configuration for token exchange.
Expand All @@ -165,7 +131,7 @@ type ExchangeConfig struct {
ClientID string

// ClientSecret is the OAuth 2.0 client secret
ClientSecret string //nolint:gosec // G117
ClientSecret string

// Audience is the target audience for the exchanged token (optional per RFC 8693)
Audience string
Expand Down Expand Up @@ -195,7 +161,7 @@ type ExchangeConfig struct {
SubjectTokenProvider func() (string, error)

// HTTPClient is the HTTP client to use for token exchange requests.
// If nil, defaultHTTPClient will be used.
// If nil, oauthproto.DefaultHTTPClient() will be used.
HTTPClient *http.Client
}

Expand Down Expand Up @@ -358,7 +324,7 @@ func exchangeToken(
}

if client == nil {
client = defaultHTTPClient
client = oauthproto.DefaultHTTPClient()
}

body, err := executeTokenExchangeRequest(client, req)
Expand Down Expand Up @@ -457,14 +423,17 @@ func createTokenExchangeRequest(

// executeTokenExchangeRequest sends the HTTP request and returns the response body.
func executeTokenExchangeRequest(client *http.Client, req *http.Request) ([]byte, error) {
resp, err := client.Do(req) // #nosec G704 -- URL is the configured token exchange endpoint
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("token exchange request failed: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
// Non-fatal: response body cleanup failure
slog.Debug("Failed to close response body", "error", err)
// Close without draining — matches oauthproto.DoTokenRequest. The
// LimitReader below caps how much we read; draining the remainder
// would be unbounded on oversized or never-terminating bodies and
// defeat the 1 MiB memory cap. Connection reuse is the tradeoff.
if closeErr := resp.Body.Close(); closeErr != nil {
slog.Debug("token exchange: close response body", "error", closeErr)
}
}()

Expand Down
44 changes: 22 additions & 22 deletions pkg/auth/tokenexchange/exchange_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func TestTokenSource_Token_Success(t *testing.T) {
err = json.NewEncoder(w).Encode(resp)
require.NoError(t, err)
}))
defer server.Close()
t.Cleanup(server.Close)

// Create config with test server
config := &ExchangeConfig{
Expand Down Expand Up @@ -258,7 +258,7 @@ func TestTokenSource_Token_WithRefreshToken(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

config := &ExchangeConfig{
TokenURL: server.URL,
Expand Down Expand Up @@ -290,7 +290,7 @@ func TestTokenSource_Token_NoExpiry(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

config := &ExchangeConfig{
TokenURL: server.URL,
Expand Down Expand Up @@ -343,7 +343,7 @@ func TestTokenSource_Token_ContextCancellation(t *testing.T) {
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
t.Cleanup(server.Close)

config := &ExchangeConfig{
TokenURL: server.URL,
Expand Down Expand Up @@ -418,7 +418,7 @@ func TestExchangeToken_HTTPErrorResponses(t *testing.T) {
w.WriteHeader(tt.statusCode)
_, _ = w.Write([]byte(tt.responseBody))
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
Expand Down Expand Up @@ -483,7 +483,7 @@ func TestExchangeToken_MalformedJSON(t *testing.T) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(tt.responseBody))
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand All @@ -507,7 +507,7 @@ func TestExchangeToken_MissingRequiredFields(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
t.Fatal("should not reach server")
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
// Missing SubjectToken
Expand Down Expand Up @@ -542,7 +542,7 @@ func TestExchangeToken_DefaultValues(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand Down Expand Up @@ -577,7 +577,7 @@ func TestExchangeToken_OptionalFields(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand Down Expand Up @@ -618,7 +618,7 @@ func TestExchangeToken_ActorTokenWithoutType(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand Down Expand Up @@ -686,7 +686,7 @@ func TestExchangeToken_ResponseSizeLimit(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand Down Expand Up @@ -742,7 +742,7 @@ func TestExchangeToken_NoCredentialLeakage(t *testing.T) {
t.Parallel()

server := tt.setupServer()
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: tt.subjectToken,
Expand Down Expand Up @@ -784,7 +784,7 @@ func TestExchangeToken_FormEncoding(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: specialChars,
Expand Down Expand Up @@ -817,7 +817,7 @@ func TestExchangeToken_ContentLength(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand Down Expand Up @@ -901,7 +901,7 @@ func TestSubjectTokenProvider_Variants(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

config := &ExchangeConfig{
TokenURL: server.URL,
Expand Down Expand Up @@ -950,7 +950,7 @@ func TestExchangeToken_EmptyClientCredentials(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand Down Expand Up @@ -988,7 +988,7 @@ func TestExchangeToken_OnlyClientID(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand Down Expand Up @@ -1020,7 +1020,7 @@ func TestExchangeToken_ResponseFields(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand Down Expand Up @@ -1052,7 +1052,7 @@ func TestExchangeToken_MinimalResponse(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand Down Expand Up @@ -1121,7 +1121,7 @@ func TestExchangeToken_ScopeArray(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand Down Expand Up @@ -1328,7 +1328,7 @@ func TestExchangeToken_URLValues(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
Expand Down Expand Up @@ -1393,7 +1393,7 @@ func TestExchangeToken_BasicAuthURLEncoding(t *testing.T) {
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
t.Cleanup(server.Close)

request := &exchangeRequest{
SubjectToken: "test-token",
Expand Down
5 changes: 2 additions & 3 deletions pkg/auth/tokenexchange/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const (
const (
// EnvClientSecret is the environment variable name for the OAuth client secret
// This corresponds to the "client_secret" field in the token exchange configuration
//nolint:gosec // G101: This is an environment variable name, not a credential
//nolint:gosec // G101: this is an environment variable name, not a credential value
EnvClientSecret = "TOOLHIVE_TOKEN_EXCHANGE_CLIENT_SECRET"
)

Expand All @@ -63,7 +63,7 @@ type Config struct {
ClientID string `json:"client_id"`

// ClientSecret is the OAuth 2.0 client secret
ClientSecret string `json:"client_secret"` //nolint:gosec // G117: field legitimately holds sensitive data
ClientSecret string `json:"client_secret"`

// Audience is the target audience for the exchanged token
Audience string `json:"audience"`
Expand Down Expand Up @@ -333,7 +333,6 @@ func createTokenExchangeMiddleware(

// Log some claim information for debugging
if sub, exists := claims["sub"]; exists {
//nolint:gosec // G706: subject claim is from validated JWT
slog.Debug("Performing token exchange for subject", "subject", sub)
}

Expand Down
Loading