Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions client/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,18 @@ var GenerateCodeChallenge = transport.GenerateCodeChallenge
// GenerateState generates a state parameter for OAuth
var GenerateState = transport.GenerateState

// AuthorizationRequiredError is returned when a 401 Unauthorized response is received
type AuthorizationRequiredError = transport.AuthorizationRequiredError

// OAuthAuthorizationRequiredError is returned when OAuth authorization is required
type OAuthAuthorizationRequiredError = transport.OAuthAuthorizationRequiredError

// IsAuthorizationRequiredError checks if an error is an AuthorizationRequiredError
func IsAuthorizationRequiredError(err error) bool {
var target *AuthorizationRequiredError
return errors.As(err, &target)
}
Comment on lines +60 to +70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Helper functions are correct but API is incomplete for RFC9728 workflow.

The error helpers correctly extract the discovered metadata URL, but there's no corresponding API to apply the discovered URL:

  • GetResourceMetadataURL(err) extracts the URL.
  • No UpdateOAuthMetadataURL(client, url) or similar to reconfigure the handler.

For a complete RFC9728 implementation, add a helper to update the handler's ProtectedResourceMetadataURL and reset cached metadata, or document that users must create a new client with the discovered URL in the config.

Based on PR comments discussion.

Also applies to: 87-105

🤖 Prompt for AI Agents
In client/oauth.go around lines 60-70 (and also apply same change for the
related section at 87-105), the package exposes helpers to extract a discovered
Protected Resource Metadata URL from errors but lacks an API to apply that URL
to an existing client for RFC9728 flows; add a helper function (e.g.,
UpdateOAuthMetadataURL) that accepts the client instance and the discovered URL,
sets the client's ProtectedResourceMetadataURL field, and clears or resets any
cached metadata/memoized discovery state so the handler will reload metadata
from the new URL; ensure the helper validates the URL, updates the configuration
on the existing client safely (concurrent-safe if the client is used
concurrently), and document that callers should invoke this helper after
extracting the URL from GetResourceMetadataURL(err) to complete the RFC9728
workflow.


// IsOAuthAuthorizationRequiredError checks if an error is an OAuthAuthorizationRequiredError
func IsOAuthAuthorizationRequiredError(err error) bool {
var target *OAuthAuthorizationRequiredError
Expand All @@ -74,3 +83,23 @@ func GetOAuthHandler(err error) *transport.OAuthHandler {
}
return nil
}

// GetResourceMetadataURL extracts the protected resource metadata URL from an authorization error.
// This URL is extracted from the WWW-Authenticate header per RFC9728 Section 5.1.
// Works with both AuthorizationRequiredError and OAuthAuthorizationRequiredError.
// Returns empty string if no metadata URL was discovered.
func GetResourceMetadataURL(err error) string {
// Try OAuthAuthorizationRequiredError first (contains AuthorizationRequiredError)
var oauthErr *OAuthAuthorizationRequiredError
if errors.As(err, &oauthErr) {
return oauthErr.ResourceMetadataURL
}

// Try base AuthorizationRequiredError
var authErr *AuthorizationRequiredError
if errors.As(err, &authErr) {
return authErr.ResourceMetadataURL
}

return ""
}
85 changes: 84 additions & 1 deletion client/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,93 @@ func TestIsOAuthAuthorizationRequiredError(t *testing.T) {
if IsOAuthAuthorizationRequiredError(err2) {
t.Errorf("Expected IsOAuthAuthorizationRequiredError to return false")
}

// Verify GetOAuthHandler returns nil
handler = GetOAuthHandler(err2)
if handler != nil {
t.Errorf("Expected GetOAuthHandler to return nil")
}
}

func TestGetResourceMetadataURL(t *testing.T) {
// Test with error containing metadata URL
metadataURL := "https://auth.example.com/.well-known/oauth-protected-resource"
err := &transport.OAuthAuthorizationRequiredError{
Handler: transport.NewOAuthHandler(transport.OAuthConfig{}),
AuthorizationRequiredError: transport.AuthorizationRequiredError{
ResourceMetadataURL: metadataURL,
},
}

// Verify GetResourceMetadataURL returns the correct URL
result := GetResourceMetadataURL(err)
if result != metadataURL {
t.Errorf("Expected GetResourceMetadataURL to return %q, got %q", metadataURL, result)
}

// Test with error containing no metadata URL
err2 := &transport.OAuthAuthorizationRequiredError{
Handler: transport.NewOAuthHandler(transport.OAuthConfig{}),
AuthorizationRequiredError: transport.AuthorizationRequiredError{
ResourceMetadataURL: "",
},
}

result2 := GetResourceMetadataURL(err2)
if result2 != "" {
t.Errorf("Expected GetResourceMetadataURL to return empty string, got %q", result2)
}

// Test with non-OAuth error
err3 := fmt.Errorf("some other error")

result3 := GetResourceMetadataURL(err3)
if result3 != "" {
t.Errorf("Expected GetResourceMetadataURL to return empty string for non-OAuth error, got %q", result3)
}
}

func TestIsAuthorizationRequiredError(t *testing.T) {
// Test with base AuthorizationRequiredError (401 without OAuth handler)
metadataURL := "https://auth.example.com/.well-known/oauth-protected-resource"
err := &transport.AuthorizationRequiredError{
ResourceMetadataURL: metadataURL,
}

// Verify IsAuthorizationRequiredError returns true
if !IsAuthorizationRequiredError(err) {
t.Errorf("Expected IsAuthorizationRequiredError to return true for AuthorizationRequiredError")
}

// Verify GetResourceMetadataURL returns the correct URL
result := GetResourceMetadataURL(err)
if result != metadataURL {
t.Errorf("Expected GetResourceMetadataURL to return %q, got %q", metadataURL, result)
}

// Test with OAuthAuthorizationRequiredError (different type)
oauthErr := &transport.OAuthAuthorizationRequiredError{
Handler: transport.NewOAuthHandler(transport.OAuthConfig{}),
AuthorizationRequiredError: transport.AuthorizationRequiredError{
ResourceMetadataURL: metadataURL,
},
}

// Verify IsOAuthAuthorizationRequiredError returns true
if !IsOAuthAuthorizationRequiredError(oauthErr) {
t.Errorf("Expected IsOAuthAuthorizationRequiredError to return true for OAuthAuthorizationRequiredError")
}

// Verify GetResourceMetadataURL works with OAuth error too
result2 := GetResourceMetadataURL(oauthErr)
if result2 != metadataURL {
t.Errorf("Expected GetResourceMetadataURL to return %q, got %q", metadataURL, result2)
}

// Test with non-authorization error
err3 := fmt.Errorf("some other error")

// Verify IsAuthorizationRequiredError returns false
if IsAuthorizationRequiredError(err3) {
t.Errorf("Expected IsAuthorizationRequiredError to return false for non-authorization error")
}
}
20 changes: 17 additions & 3 deletions client/transport/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ type OAuthConfig struct {
// AuthServerMetadataURL is the URL to the OAuth server metadata
// If empty, the client will attempt to discover it from the base URL
AuthServerMetadataURL string
// ProtectedResourceMetadataURL is the URL to the OAuth protected resource metadata
// per RFC9728. If set, this URL will be used to discover the authorization server.
// This is typically extracted from the WWW-Authenticate header's resource_metadata parameter.
ProtectedResourceMetadataURL string
Comment on lines +35 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Configuration field cannot be updated after discovery.

ProtectedResourceMetadataURL is part of the immutable OAuthConfig. When a metadata URL is discovered from a WWW-Authenticate header (per RFC9728), there is no mechanism to update this field in the handler's config. The discovered URL is extracted into the error but cannot trigger a re-fetch of server metadata because:

  1. OAuthHandler.config is immutable.
  2. getServerMetadata() is guarded by sync.Once (line 351), preventing re-execution.

Consider adding UpdateProtectedResourceMetadataURL(url string) and ResetMetadata() methods to allow runtime discovery, or document that this field must be set at configuration time and discovery requires creating a new handler.

Based on PR comments discussion.

🤖 Prompt for AI Agents
In client/transport/oauth.go around lines 35 to 38, ProtectedResourceMetadataURL
is part of an immutable OAuthConfig so a metadata URL discovered at runtime
cannot be applied and getServerMetadata() is prevented from re-running by the
sync.Once guard; add runtime update support by (1) adding an exported method
UpdateProtectedResourceMetadataURL(url string) on the OAuthHandler that updates
a mutable field (or a small internal struct) holding the metadata URL, (2)
adding ResetMetadata() to clear any cached server metadata and reset the
sync.Once (or replace sync.Once with a mutex/atomic guard that supports
resetting), and (3) ensure getServerMetadata() reads the mutable metadata URL
and can re-fetch metadata after ResetMetadata() is called; alternatively,
document clearly that ProtectedResourceMetadataURL must be set at config time
and runtime discovery requires constructing a new handler.

// PKCEEnabled enables PKCE for the OAuth flow (recommended for public clients)
PKCEEnabled bool
// HTTPClient is an optional HTTP client to use for requests.
Expand Down Expand Up @@ -351,16 +355,26 @@ func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetada
return
}

// Try to discover the authorization server via OAuth Protected Resource
// as per RFC 9728 (https://datatracker.ietf.org/doc/html/rfc9728)
// Always extract base URL for fallback scenarios
baseURL, err := h.extractBaseURL()
if err != nil {
h.metadataFetchErr = fmt.Errorf("failed to extract base URL: %w", err)
return
}

// Determine the protected resource metadata URL with priority:
// 1. Explicit config (ProtectedResourceMetadataURL from RFC9728 WWW-Authenticate header)
// 2. Constructed from base URL
var protectedResourceURL string
if h.config.ProtectedResourceMetadataURL != "" {
// Use explicitly configured protected resource metadata URL
protectedResourceURL = h.config.ProtectedResourceMetadataURL
} else {
// Fall back to constructing the URL from base URL
protectedResourceURL = baseURL + "/.well-known/oauth-protected-resource"
}

// Try to fetch the OAuth Protected Resource metadata
protectedResourceURL := baseURL + "/.well-known/oauth-protected-resource"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, protectedResourceURL, nil)
if err != nil {
h.metadataFetchErr = fmt.Errorf("failed to create protected resource request: %w", err)
Expand Down
78 changes: 64 additions & 14 deletions client/transport/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ func (c *SSE) Start(ctx context.Context) error {
if err.Error() == "no valid token available, authorization required" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Use errors.Is instead of string comparison.

Lines 148 and 394 check err.Error() == "no valid token available, authorization required" which is fragile:

  1. String matching breaks if error message changes.
  2. Doesn't respect error wrapping.
  3. Inconsistent with line 553 which correctly uses errors.Is(err, ErrOAuthAuthorizationRequired).

Replace with:

-if err.Error() == "no valid token available, authorization required" {
+if errors.Is(err, ErrOAuthAuthorizationRequired) {

As per coding guidelines: "Error handling: return sentinel errors, wrap with fmt.Errorf, and check with errors.Is/As."

Also applies to: 394-394

🤖 Prompt for AI Agents
In client/transport/sse.go around lines 148 and 394, replace fragile string
equality checks of err.Error() == "no valid token available, authorization
required" with errors.Is(err, ErrOAuthAuthorizationRequired) so wrapped errors
are handled correctly; ensure the package imports the standard errors package if
not already imported and remove the string literal comparison, using the
sentinel ErrOAuthAuthorizationRequired (already used elsewhere) for consistent,
robust error checks.

return &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
AuthorizationRequiredError: AuthorizationRequiredError{
ResourceMetadataURL: "", // No response available in this code path
},
}
}
return fmt.Errorf("failed to get authorization header: %w", err)
Expand All @@ -162,10 +165,24 @@ func (c *SSE) Start(ctx context.Context) error {

if resp.StatusCode != http.StatusOK {
resp.Body.Close()
// Handle OAuth unauthorized error
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
return &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
// Handle unauthorized error
if resp.StatusCode == http.StatusUnauthorized {
// Extract discovered metadata URL per RFC9728
metadataURL := extractResourceMetadataURL(resp.Header.Get("WWW-Authenticate"))

// If OAuth handler exists, return OAuth-specific error
if c.oauthHandler != nil {
return &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
AuthorizationRequiredError: AuthorizationRequiredError{
ResourceMetadataURL: metadataURL,
},
}
}

// No OAuth handler, return base authorization error
return &AuthorizationRequiredError{
ResourceMetadataURL: metadataURL,
}
}
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
Expand Down Expand Up @@ -377,6 +394,9 @@ func (c *SSE) SendRequest(
if err.Error() == "no valid token available, authorization required" {
return nil, &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
AuthorizationRequiredError: AuthorizationRequiredError{
ResourceMetadataURL: "", // No response available in this code path
},
}
}
return nil, fmt.Errorf("failed to get authorization header: %w", err)
Expand Down Expand Up @@ -419,17 +439,29 @@ func (c *SSE) SendRequest(
return nil, fmt.Errorf("failed to read response body: %w", err)
}

// Check if we got an error response
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
deleteResponseChan()
// Handle unauthorized error
if resp.StatusCode == http.StatusUnauthorized {
// Extract discovered metadata URL per RFC9728
metadataURL := extractResourceMetadataURL(resp.Header.Get("WWW-Authenticate"))

// If OAuth handler exists, return OAuth-specific error
if c.oauthHandler != nil {
return nil, &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
AuthorizationRequiredError: AuthorizationRequiredError{
ResourceMetadataURL: metadataURL,
},
}
}

// Handle OAuth unauthorized error
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
return nil, &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
// No OAuth handler, return base authorization error
return nil, &AuthorizationRequiredError{
ResourceMetadataURL: metadataURL,
}
}

// Read error body
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
}

Expand Down Expand Up @@ -521,6 +553,9 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
if errors.Is(err, ErrOAuthAuthorizationRequired) {
return &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
AuthorizationRequiredError: AuthorizationRequiredError{
ResourceMetadataURL: "", // No response available in this code path
},
}
}
return fmt.Errorf("failed to get authorization header: %w", err)
Expand All @@ -541,13 +576,28 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
// Handle OAuth unauthorized error
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
return &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
// Handle unauthorized error
if resp.StatusCode == http.StatusUnauthorized {
// Extract discovered metadata URL per RFC9728
metadataURL := extractResourceMetadataURL(resp.Header.Get("WWW-Authenticate"))

// If OAuth handler exists, return OAuth-specific error
if c.oauthHandler != nil {
return &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
AuthorizationRequiredError: AuthorizationRequiredError{
ResourceMetadataURL: metadataURL,
},
}
}

// No OAuth handler, return base authorization error
return &AuthorizationRequiredError{
ResourceMetadataURL: metadataURL,
}
}

// Handle other error responses
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf(
"notification failed with status %d: %s",
Expand Down
63 changes: 63 additions & 0 deletions client/transport/sse_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,66 @@ func TestSSE_IsOAuthEnabled(t *testing.T) {
t.Errorf("Expected IsOAuthEnabled() to return true")
}
}

func TestSSE_OAuthMetadataDiscovery(t *testing.T) {
// Test that we correctly extract resource_metadata URL from WWW-Authenticate header per RFC9728
const expectedMetadataURL = "https://auth.example.com/.well-known/oauth-protected-resource"

// Create a test server that returns 401 with WWW-Authenticate header
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return 401 with WWW-Authenticate header containing resource_metadata
w.Header().Set("WWW-Authenticate", `Bearer resource_metadata="`+expectedMetadataURL+`"`)
w.WriteHeader(http.StatusUnauthorized)
}))
defer server.Close()

// Create a token store with a valid token so the request reaches the server
// The server will still return 401 to simulate token rejection
tokenStore := NewMemoryTokenStore()
validToken := &Token{
AccessToken: "test-token",
TokenType: "Bearer",
RefreshToken: "refresh-token",
ExpiresIn: 3600,
ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour
}
if err := tokenStore.SaveToken(context.Background(), validToken); err != nil {
t.Fatalf("Failed to save token: %v", err)
}

// Create OAuth config
oauthConfig := OAuthConfig{
ClientID: "test-client",
RedirectURI: "http://localhost:8085/callback",
Scopes: []string{"mcp.read", "mcp.write"},
TokenStore: tokenStore,
PKCEEnabled: true,
}

// Create SSE with OAuth
transport, err := NewSSE(server.URL, WithOAuth(oauthConfig))
if err != nil {
t.Fatalf("Failed to create SSE: %v", err)
}

// Start SSE which will trigger 401
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = transport.Start(ctx)

// Verify the error is an OAuthAuthorizationRequiredError
if err == nil {
t.Fatalf("Expected error, got nil")
}

var oauthErr *OAuthAuthorizationRequiredError
if !errors.As(err, &oauthErr) {
t.Fatalf("Expected OAuthAuthorizationRequiredError, got %T: %v", err, err)
}

// Verify the discovered metadata URL was extracted from WWW-Authenticate header
if oauthErr.ResourceMetadataURL != expectedMetadataURL {
t.Errorf("Expected ResourceMetadataURL to be %q, got %q",
expectedMetadataURL, oauthErr.ResourceMetadataURL)
}
}
Loading
Loading