diff --git a/cmd/thv/app/config_registryauth.go b/cmd/thv/app/config_registryauth.go new file mode 100644 index 0000000000..7c4f233bfd --- /dev/null +++ b/cmd/thv/app/config_registryauth.go @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package app + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/stacklok/toolhive/pkg/config" + "github.com/stacklok/toolhive/pkg/registry" +) + +var ( + authIssuer string + authClientID string + authAudience string + authScopes []string +) + +var setRegistryAuthCmd = &cobra.Command{ + Use: "set-registry-auth", + Short: "Configure OAuth/OIDC authentication for the registry", + Long: `Configure OAuth/OIDC authentication for the remote MCP server registry. +PKCE (S256) is always enforced for security. + +The issuer URL is validated via OIDC discovery before saving. + +Examples: + thv config set-registry-auth --issuer https://auth.company.com --client-id toolhive-cli + thv config set-registry-auth \ + --issuer https://auth.company.com --client-id toolhive-cli \ + --audience api://my-registry --scopes openid,profile`, + RunE: setRegistryAuthCmdFunc, +} + +var unsetRegistryAuthCmd = &cobra.Command{ + Use: "unset-registry-auth", + Short: "Remove registry authentication configuration", + Long: "Remove the OAuth/OIDC authentication configuration for the registry.", + RunE: unsetRegistryAuthCmdFunc, +} + +func init() { + setRegistryAuthCmd.Flags().StringVar(&authIssuer, "issuer", "", "OIDC issuer URL (required)") + setRegistryAuthCmd.Flags().StringVar(&authClientID, "client-id", "", "OAuth client ID (required)") + setRegistryAuthCmd.Flags().StringVar(&authAudience, "audience", "", "OAuth audience parameter") + setRegistryAuthCmd.Flags().StringSliceVar( + &authScopes, "scopes", []string{"openid", "offline_access"}, "OAuth scopes", + ) + + _ = setRegistryAuthCmd.MarkFlagRequired("issuer") + _ = setRegistryAuthCmd.MarkFlagRequired("client-id") + + configCmd.AddCommand(setRegistryAuthCmd) + configCmd.AddCommand(unsetRegistryAuthCmd) +} + +func setRegistryAuthCmdFunc(_ *cobra.Command, _ []string) error { + authManager := registry.NewAuthManager(config.NewDefaultProvider()) + + if err := authManager.SetOAuthAuth(authIssuer, authClientID, authAudience, authScopes); err != nil { + return fmt.Errorf("failed to configure registry auth: %w", err) + } + + return nil +} + +func unsetRegistryAuthCmdFunc(_ *cobra.Command, _ []string) error { + authManager := registry.NewAuthManager(config.NewDefaultProvider()) + + if err := authManager.UnsetAuth(); err != nil { + return fmt.Errorf("failed to remove registry auth: %w", err) + } + + return nil +} diff --git a/docs/cli/thv_config.md b/docs/cli/thv_config.md index 00f3b1634d..e0da7df076 100644 --- a/docs/cli/thv_config.md +++ b/docs/cli/thv_config.md @@ -41,9 +41,11 @@ The config command provides subcommands to manage application configuration sett * [thv config set-build-env](thv_config_set-build-env.md) - Set a build environment variable for protocol builds * [thv config set-ca-cert](thv_config_set-ca-cert.md) - Set the default CA certificate for container builds * [thv config set-registry](thv_config_set-registry.md) - Set the MCP server registry +* [thv config set-registry-auth](thv_config_set-registry-auth.md) - Configure OAuth/OIDC authentication for the registry * [thv config unset-build-auth-file](thv_config_unset-build-auth-file.md) - Remove build auth file(s) * [thv config unset-build-env](thv_config_unset-build-env.md) - Remove build environment variable(s) * [thv config unset-ca-cert](thv_config_unset-ca-cert.md) - Remove the configured CA certificate * [thv config unset-registry](thv_config_unset-registry.md) - Remove the configured registry +* [thv config unset-registry-auth](thv_config_unset-registry-auth.md) - Remove registry authentication configuration * [thv config usage-metrics](thv_config_usage-metrics.md) - Enable or disable anonymous usage metrics diff --git a/docs/cli/thv_config_set-registry-auth.md b/docs/cli/thv_config_set-registry-auth.md new file mode 100644 index 0000000000..27235d85e7 --- /dev/null +++ b/docs/cli/thv_config_set-registry-auth.md @@ -0,0 +1,52 @@ +--- +title: thv config set-registry-auth +hide_title: true +description: Reference for ToolHive CLI command `thv config set-registry-auth` +last_update: + author: autogenerated +slug: thv_config_set-registry-auth +mdx: + format: md +--- + +## thv config set-registry-auth + +Configure OAuth/OIDC authentication for the registry + +### Synopsis + +Configure OAuth/OIDC authentication for the remote MCP server registry. +PKCE (S256) is always enforced for security. + +The issuer URL is validated via OIDC discovery before saving. + +Examples: + thv config set-registry-auth --issuer https://auth.company.com --client-id toolhive-cli + thv config set-registry-auth \ + --issuer https://auth.company.com --client-id toolhive-cli \ + --audience api://my-registry --scopes openid,profile + +``` +thv config set-registry-auth [flags] +``` + +### Options + +``` + --audience string OAuth audience parameter + --client-id string OAuth client ID (required) + -h, --help help for set-registry-auth + --issuer string OIDC issuer URL (required) + --scopes strings OAuth scopes (default [openid,offline_access]) +``` + +### Options inherited from parent commands + +``` + --debug Enable debug mode +``` + +### SEE ALSO + +* [thv config](thv_config.md) - Manage application configuration + diff --git a/docs/cli/thv_config_unset-registry-auth.md b/docs/cli/thv_config_unset-registry-auth.md new file mode 100644 index 0000000000..9fa02aa3de --- /dev/null +++ b/docs/cli/thv_config_unset-registry-auth.md @@ -0,0 +1,39 @@ +--- +title: thv config unset-registry-auth +hide_title: true +description: Reference for ToolHive CLI command `thv config unset-registry-auth` +last_update: + author: autogenerated +slug: thv_config_unset-registry-auth +mdx: + format: md +--- + +## thv config unset-registry-auth + +Remove registry authentication configuration + +### Synopsis + +Remove the OAuth/OIDC authentication configuration for the registry. + +``` +thv config unset-registry-auth [flags] +``` + +### Options + +``` + -h, --help help for unset-registry-auth +``` + +### Options inherited from parent commands + +``` + --debug Enable debug mode +``` + +### SEE ALSO + +* [thv config](thv_config.md) - Manage application configuration + diff --git a/pkg/config/config.go b/pkg/config/config.go index 304bb026df..8f742736e5 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -45,6 +45,33 @@ type Config struct { BuildEnvFromShell []string `yaml:"build_env_from_shell,omitempty"` BuildAuthFiles map[string]string `yaml:"build_auth_files,omitempty"` RuntimeConfigs map[string]*templates.RuntimeConfig `yaml:"runtime_configs,omitempty"` + RegistryAuth RegistryAuth `yaml:"registry_auth,omitempty"` +} + +// RegistryAuthTypeOAuth is the auth type for OAuth/OIDC authentication. +const RegistryAuthTypeOAuth = "oauth" + +// RegistryAuth holds authentication configuration for remote registries. +type RegistryAuth struct { + // Type is the authentication type: RegistryAuthTypeOAuth or "" (none). + Type string `yaml:"type,omitempty"` + + // OAuth holds OAuth/OIDC authentication configuration. + OAuth *RegistryOAuthConfig `yaml:"oauth,omitempty"` +} + +// RegistryOAuthConfig holds OAuth/OIDC configuration for registry authentication. +// PKCE (S256) is always enforced per OAuth 2.1 requirements for public clients. +type RegistryOAuthConfig struct { + Issuer string `yaml:"issuer"` + ClientID string `yaml:"client_id"` + Scopes []string `yaml:"scopes,omitempty"` + Audience string `yaml:"audience,omitempty"` + CallbackPort int `yaml:"callback_port,omitempty"` + + // Cached token references for session restoration across CLI invocations. + CachedRefreshTokenRef string `yaml:"cached_refresh_token_ref,omitempty"` + CachedTokenExpiry time.Time `yaml:"cached_token_expiry,omitempty"` } // Secrets contains the settings for secrets management. diff --git a/pkg/registry/api/client.go b/pkg/registry/api/client.go index a63cd05625..4caf9d63d4 100644 --- a/pkg/registry/api/client.go +++ b/pkg/registry/api/client.go @@ -17,6 +17,7 @@ import ( "gopkg.in/yaml.v3" "github.com/stacklok/toolhive/pkg/networking" + "github.com/stacklok/toolhive/pkg/registry/auth" "github.com/stacklok/toolhive/pkg/versions" ) @@ -56,8 +57,10 @@ type mcpRegistryClient struct { userAgent string } -// NewClient creates a new MCP Registry API client -func NewClient(baseURL string, allowPrivateIp bool) (Client, error) { +// NewClient creates a new MCP Registry API client. +// If tokenSource is non-nil, the HTTP client transport will be wrapped to inject +// Bearer tokens into all requests. +func NewClient(baseURL string, allowPrivateIp bool, tokenSource auth.TokenSource) (Client, error) { // Build HTTP client with security controls // If private IPs are allowed, also allow HTTP (for localhost testing) builder := networking.NewHttpClientBuilder().WithPrivateIPs(allowPrivateIp) @@ -69,6 +72,9 @@ func NewClient(baseURL string, allowPrivateIp bool) (Client, error) { return nil, fmt.Errorf("failed to build HTTP client: %w", err) } + // Wrap transport with auth if token source is provided + httpClient.Transport = auth.WrapTransport(httpClient.Transport, tokenSource) + // Ensure base URL doesn't have trailing slash if baseURL[len(baseURL)-1] == '/' { baseURL = baseURL[:len(baseURL)-1] diff --git a/pkg/registry/auth/auth.go b/pkg/registry/auth/auth.go new file mode 100644 index 0000000000..ff0d8c0c24 --- /dev/null +++ b/pkg/registry/auth/auth.go @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package auth provides authentication support for MCP server registries. +package auth + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + + "github.com/stacklok/toolhive/pkg/config" + "github.com/stacklok/toolhive/pkg/secrets" +) + +// ErrRegistryAuthRequired is returned when registry authentication is required +// but no cached tokens are available in a non-interactive context. +var ErrRegistryAuthRequired = errors.New("registry authentication required: run 'thv registry login' to authenticate") + +// TokenSource provides authentication tokens for registry HTTP requests. +type TokenSource interface { + // Token returns a valid access token string, or empty string if no auth. + // Implementations should handle token refresh transparently. + Token(ctx context.Context) (string, error) +} + +// NewTokenSource creates a TokenSource from registry OAuth configuration. +// Returns nil, nil if oauth config is nil (no auth required). +// The registryURL is used to derive a unique secret key for token storage. +// The secrets provider may be nil if secret storage is not available. +// The interactive flag controls whether browser-based OAuth flows are allowed. +func NewTokenSource( + cfg *config.RegistryOAuthConfig, + registryURL string, + secretsProvider secrets.Provider, + interactive bool, +) (TokenSource, error) { + if cfg == nil { + return nil, nil + } + + return &oauthTokenSource{ + oauthCfg: cfg, + registryURL: registryURL, + secretsProvider: secretsProvider, + interactive: interactive, + }, nil +} + +// DeriveSecretKey computes the secret key for storing a registry's refresh token. +// The key follows the formula: REGISTRY_OAUTH_<8 hex chars> +// where the hex is derived from sha256(registryURL + "\x00" + issuer)[:4]. +func DeriveSecretKey(registryURL, issuer string) string { + h := sha256.Sum256([]byte(registryURL + "\x00" + issuer)) + return "REGISTRY_OAUTH_" + hex.EncodeToString(h[:4]) +} diff --git a/pkg/registry/auth/auth_test.go b/pkg/registry/auth/auth_test.go new file mode 100644 index 0000000000..eee3833f5d --- /dev/null +++ b/pkg/registry/auth/auth_test.go @@ -0,0 +1,592 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/oauth2" + + "github.com/stacklok/toolhive/pkg/config" + "github.com/stacklok/toolhive/pkg/secrets" + secretsmocks "github.com/stacklok/toolhive/pkg/secrets/mocks" +) + +func TestDeriveSecretKey(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + registryURL string + issuer string + }{ + { + name: "typical registry and issuer", + registryURL: "https://registry.example.com", + issuer: "https://auth.example.com", + }, + { + name: "empty strings", + registryURL: "", + issuer: "", + }, + { + name: "empty issuer", + registryURL: "https://registry.example.com", + issuer: "", + }, + { + name: "empty registry URL", + registryURL: "", + issuer: "https://auth.example.com", + }, + { + name: "localhost registry", + registryURL: "http://localhost:5000", + issuer: "http://localhost:8080", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + key := DeriveSecretKey(tt.registryURL, tt.issuer) + + // Must start with the correct prefix + require.True(t, len(key) > len("REGISTRY_OAUTH_"), "key too short") + require.Equal(t, "REGISTRY_OAUTH_", key[:len("REGISTRY_OAUTH_")]) + + // The suffix must be exactly 8 hex characters (4 bytes of sha256) + suffix := key[len("REGISTRY_OAUTH_"):] + require.Len(t, suffix, 8, "hex suffix must be exactly 8 characters") + + // Verify each character is a valid hex character + for _, c := range suffix { + require.True(t, + (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'), + "suffix character %q is not a lowercase hex digit", c, + ) + } + + // Verify the derivation formula: sha256(registryURL + "\x00" + issuer)[:4] + h := sha256.Sum256([]byte(tt.registryURL + "\x00" + tt.issuer)) + expected := "REGISTRY_OAUTH_" + hex.EncodeToString(h[:4]) + require.Equal(t, expected, key) + }) + } +} + +func TestDeriveSecretKey_Deterministic(t *testing.T) { + t.Parallel() + + registryURL := "https://registry.example.com" + issuer := "https://auth.example.com" + + key1 := DeriveSecretKey(registryURL, issuer) + key2 := DeriveSecretKey(registryURL, issuer) + + require.Equal(t, key1, key2, "DeriveSecretKey must be deterministic") +} + +func TestDeriveSecretKey_UniquePerInputCombination(t *testing.T) { + t.Parallel() + + combinations := []struct { + registryURL string + issuer string + }{ + {"https://registry-a.example.com", "https://auth.example.com"}, + {"https://registry-b.example.com", "https://auth.example.com"}, + {"https://registry-a.example.com", "https://auth-other.example.com"}, + {"https://registry-b.example.com", "https://auth-other.example.com"}, + } + + keys := make(map[string]struct{}, len(combinations)) + for _, combo := range combinations { + key := DeriveSecretKey(combo.registryURL, combo.issuer) + _, alreadySeen := keys[key] + require.False(t, alreadySeen, + "DeriveSecretKey produced a duplicate key for registryURL=%q issuer=%q: %q", + combo.registryURL, combo.issuer, key, + ) + keys[key] = struct{}{} + } +} + +func TestDeriveSecretKey_NullByteIsolatesSegments(t *testing.T) { + t.Parallel() + + // Without the null-byte separator these two pairs would hash identically: + // ("ab", "c") and ("a", "bc") both concatenate to "abc". + // The separator prevents that collision. + key1 := DeriveSecretKey("ab", "c") + key2 := DeriveSecretKey("a", "bc") + + require.NotEqual(t, key1, key2, + "keys must differ when registry URL and issuer are split differently") +} + +func TestNewTokenSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.RegistryOAuthConfig + wantNil bool + wantErrNil bool + }{ + { + name: "nil config returns nil source and nil error", + cfg: nil, + wantNil: true, + wantErrNil: true, + }, + { + name: "non-nil config returns non-nil source", + cfg: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client-id", + }, + wantNil: false, + wantErrNil: true, + }, + { + name: "config with scopes and audience returns non-nil source", + cfg: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client-id", + Scopes: []string{"openid", "profile"}, + Audience: "api://my-api", + }, + wantNil: false, + wantErrNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + src, err := NewTokenSource(tt.cfg, "https://registry.example.com", nil, false) + + if tt.wantErrNil { + require.NoError(t, err) + } else { + require.Error(t, err) + } + + if tt.wantNil { + require.Nil(t, src) + } else { + require.NotNil(t, src) + } + }) + } +} + +func TestOAuthTokenSource_Token_NonInteractiveNoCache(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + buildProvider func(ctrl *gomock.Controller) *secretsmocks.MockProvider + }{ + { + name: "non-interactive with no secrets provider returns ErrRegistryAuthRequired", + buildProvider: nil, // nil secrets provider + }, + { + name: "non-interactive with secrets provider error returns ErrRegistryAuthRequired", + buildProvider: func(ctrl *gomock.Controller) *secretsmocks.MockProvider { + mock := secretsmocks.NewMockProvider(ctrl) + mock.EXPECT(). + GetSecret(gomock.Any(), gomock.Any()). + Return("", errors.New("connection refused")) + return mock + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + var provider secrets.Provider + if tt.buildProvider != nil { + provider = tt.buildProvider(ctrl) + } + + src := &oauthTokenSource{ + oauthCfg: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client", + }, + registryURL: "https://registry.example.com", + secretsProvider: provider, + interactive: false, + } + + _, err := src.Token(context.Background()) + + require.Error(t, err) + require.True(t, errors.Is(err, ErrRegistryAuthRequired), + "expected ErrRegistryAuthRequired, got: %v", err) + }) + } +} + +func TestOAuthTokenSource_RefreshTokenKey(t *testing.T) { + t.Parallel() + + const registryURL = "https://registry.example.com" + const issuer = "https://auth.example.com" + + tests := []struct { + name string + cachedRefreshTokenRef string + wantKey string + }{ + { + name: "returns CachedRefreshTokenRef when set", + cachedRefreshTokenRef: "my-cached-ref-key", + wantKey: "my-cached-ref-key", + }, + { + name: "falls back to DeriveSecretKey when CachedRefreshTokenRef is empty", + cachedRefreshTokenRef: "", + wantKey: DeriveSecretKey(registryURL, issuer), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + src := &oauthTokenSource{ + oauthCfg: &config.RegistryOAuthConfig{ + Issuer: issuer, + ClientID: "test-client", + CachedRefreshTokenRef: tt.cachedRefreshTokenRef, + }, + registryURL: registryURL, + } + + got := src.refreshTokenKey() + require.Equal(t, tt.wantKey, got) + }) + } +} + +// mockOAuth2TokenSource is a test double for oauth2.TokenSource (no-context variant). +type mockOAuth2TokenSource struct { + token *oauth2.Token + err error +} + +func (m *mockOAuth2TokenSource) Token() (*oauth2.Token, error) { + return m.token, m.err +} + +// newOIDCTestServer starts an httptest server that handles the two well-known +// OIDC discovery paths used by CreateOAuthConfigFromOIDC. It returns the server +// and shuts it down automatically when the test completes. +func newOIDCTestServer(t *testing.T) *httptest.Server { + t.Helper() + + var srv *httptest.Server + mux := http.NewServeMux() + + handler := func(w http.ResponseWriter, _ *http.Request) { + issuer := srv.URL + doc := map[string]string{ + "issuer": issuer, + "authorization_endpoint": issuer + "/authorize", + "token_endpoint": issuer + "/token", + "jwks_uri": issuer + "/jwks", + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(doc); err != nil { + http.Error(w, "encode error", http.StatusInternalServerError) + } + } + + // CreateOAuthConfigFromOIDC tries both OIDC and OAuth well-known paths. + mux.HandleFunc("/.well-known/openid-configuration", handler) + mux.HandleFunc("/.well-known/oauth-authorization-server", handler) + + srv = httptest.NewServer(mux) + t.Cleanup(srv.Close) + return srv +} + +// TestOAuthTokenSource_Token_InMemoryCacheHit verifies that when the in-memory +// token source holds a valid, non-expired token, Token() returns it immediately +// without consulting the secrets provider. +func TestOAuthTokenSource_Token_InMemoryCacheHit(t *testing.T) { + t.Parallel() + + validToken := &oauth2.Token{ + AccessToken: "cached-access-token", + Expiry: time.Now().Add(time.Hour), + TokenType: "Bearer", + } + + src := &oauthTokenSource{ + oauthCfg: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client", + }, + registryURL: "https://registry.example.com", + secretsProvider: nil, // should never be called + interactive: false, + tokenSource: &mockOAuth2TokenSource{token: validToken}, + } + + got, err := src.Token(context.Background()) + require.NoError(t, err) + require.Equal(t, "cached-access-token", got) +} + +// TestOAuthTokenSource_Token_InMemoryCacheExpiredFallsThrough verifies that when +// the in-memory token source returns an expired token (past Expiry), Token() clears +// the cache and falls through to return ErrRegistryAuthRequired in non-interactive mode +// without a secrets provider. +func TestOAuthTokenSource_Token_InMemoryCacheExpiredFallsThrough(t *testing.T) { + t.Parallel() + + expiredToken := &oauth2.Token{ + AccessToken: "expired-token", + Expiry: time.Now().Add(-time.Hour), // already expired + TokenType: "Bearer", + } + + src := &oauthTokenSource{ + oauthCfg: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client", + }, + registryURL: "https://registry.example.com", + secretsProvider: nil, + interactive: false, + tokenSource: &mockOAuth2TokenSource{token: expiredToken}, + } + + _, err := src.Token(context.Background()) + require.Error(t, err) + require.True(t, errors.Is(err, ErrRegistryAuthRequired), + "expected ErrRegistryAuthRequired, got: %v", err) + // In-memory cache should have been cleared. + require.Nil(t, src.tokenSource) +} + +// TestOAuthTokenSource_Token_InMemoryCacheErrorFallsThrough verifies that when +// the in-memory token source returns an error, Token() clears the cache and falls +// through to return ErrRegistryAuthRequired in non-interactive mode. +func TestOAuthTokenSource_Token_InMemoryCacheErrorFallsThrough(t *testing.T) { + t.Parallel() + + src := &oauthTokenSource{ + oauthCfg: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client", + }, + registryURL: "https://registry.example.com", + secretsProvider: nil, + interactive: false, + tokenSource: &mockOAuth2TokenSource{err: errors.New("token refresh failed")}, + } + + _, err := src.Token(context.Background()) + require.Error(t, err) + require.True(t, errors.Is(err, ErrRegistryAuthRequired), + "expected ErrRegistryAuthRequired, got: %v", err) + // In-memory cache should have been cleared. + require.Nil(t, src.tokenSource) +} + +// TestOAuthTokenSource_TryRestoreFromCache_NilProvider verifies that +// tryRestoreFromCache returns an error immediately when no secrets provider is set. +func TestOAuthTokenSource_TryRestoreFromCache_NilProvider(t *testing.T) { + t.Parallel() + + src := &oauthTokenSource{ + oauthCfg: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client", + }, + registryURL: "https://registry.example.com", + secretsProvider: nil, // genuine nil interface — triggers the nil guard in tryRestoreFromCache + } + + err := src.tryRestoreFromCache(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "no secrets provider available") +} + +// TestOAuthTokenSource_TryRestoreFromCache covers the error paths in tryRestoreFromCache +// that involve the secrets provider returning errors or empty values. +func TestOAuthTokenSource_TryRestoreFromCache(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + buildProvider func(ctrl *gomock.Controller) *secretsmocks.MockProvider + wantErrContains string + }{ + { + name: "GetSecret returns error", + buildProvider: func(ctrl *gomock.Controller) *secretsmocks.MockProvider { + mock := secretsmocks.NewMockProvider(ctrl) + mock.EXPECT(). + GetSecret(gomock.Any(), gomock.Any()). + Return("", errors.New("vault unavailable")) + return mock + }, + wantErrContains: "failed to get cached refresh token", + }, + { + name: "GetSecret returns empty string", + buildProvider: func(ctrl *gomock.Controller) *secretsmocks.MockProvider { + mock := secretsmocks.NewMockProvider(ctrl) + mock.EXPECT(). + GetSecret(gomock.Any(), gomock.Any()). + Return("", nil) + return mock + }, + wantErrContains: "no cached refresh token found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + provider := tt.buildProvider(ctrl) + + src := &oauthTokenSource{ + oauthCfg: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client", + }, + registryURL: "https://registry.example.com", + secretsProvider: provider, + } + + err := src.tryRestoreFromCache(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErrContains) + }) + } +} + +// TestOAuthTokenSource_TryRestoreFromCache_WithOIDCServer verifies that +// tryRestoreFromCache succeeds when a valid refresh token is found in the secrets +// provider and an OIDC discovery document is available. +func TestOAuthTokenSource_TryRestoreFromCache_WithOIDCServer(t *testing.T) { + t.Parallel() + + srv := newOIDCTestServer(t) + + ctrl := gomock.NewController(t) + mockProvider := secretsmocks.NewMockProvider(ctrl) + mockProvider.EXPECT(). + GetSecret(gomock.Any(), gomock.Any()). + Return("my-refresh-token", nil) + + src := &oauthTokenSource{ + oauthCfg: &config.RegistryOAuthConfig{ + Issuer: srv.URL, + ClientID: "test-client", + }, + registryURL: "https://registry.example.com", + secretsProvider: mockProvider, + } + + err := src.tryRestoreFromCache(context.Background()) + require.NoError(t, err) + // We only verify the tokenSource was set; actually exchanging the refresh + // token requires a real /token endpoint and is covered by integration tests. + require.NotNil(t, src.tokenSource, + "tokenSource must be set after successful cache restoration") +} + +// TestOAuthTokenSource_CreateTokenPersister covers the createTokenPersister helper. +func TestOAuthTokenSource_CreateTokenPersister(t *testing.T) { + t.Parallel() + + const refreshTokenKey = "REGISTRY_OAUTH_testkey" + const refreshTokenValue = "rt-abc123" + + tests := []struct { + name string + setupMock func(mock *secretsmocks.MockProvider) + wantErr bool + wantErrSubstr string + }{ + { + name: "SetSecret succeeds", + setupMock: func(mock *secretsmocks.MockProvider) { + mock.EXPECT(). + SetSecret(gomock.Any(), refreshTokenKey, refreshTokenValue). + Return(nil) + }, + wantErr: false, + }, + { + name: "SetSecret returns error", + setupMock: func(mock *secretsmocks.MockProvider) { + mock.EXPECT(). + SetSecret(gomock.Any(), refreshTokenKey, refreshTokenValue). + Return(fmt.Errorf("storage full")) + }, + wantErr: true, + wantErrSubstr: "failed to persist refresh token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockProvider := secretsmocks.NewMockProvider(ctrl) + tt.setupMock(mockProvider) + + src := &oauthTokenSource{ + oauthCfg: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client", + }, + registryURL: "https://registry.example.com", + secretsProvider: mockProvider, + } + + persister := src.createTokenPersister(refreshTokenKey) + require.NotNil(t, persister) + + // Call the persister function — expiry value does not affect SetSecret behaviour. + err := persister(refreshTokenValue, time.Now().Add(time.Hour)) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErrSubstr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/registry/auth/oauth_token_source.go b/pkg/registry/auth/oauth_token_source.go new file mode 100644 index 0000000000..de528022ce --- /dev/null +++ b/pkg/registry/auth/oauth_token_source.go @@ -0,0 +1,218 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "context" + "fmt" + "log/slog" + "sync" + "time" + + "golang.org/x/oauth2" + + "github.com/stacklok/toolhive/pkg/auth/oauth" + "github.com/stacklok/toolhive/pkg/auth/remote" + "github.com/stacklok/toolhive/pkg/config" + "github.com/stacklok/toolhive/pkg/secrets" +) + +// oauthTokenSource implements TokenSource using OAuth/OIDC browser-based flow. +type oauthTokenSource struct { + oauthCfg *config.RegistryOAuthConfig + registryURL string + secretsProvider secrets.Provider + interactive bool + mu sync.Mutex + tokenSource oauth2.TokenSource +} + +// Token returns a valid access token string, handling refresh and browser flow as needed. +func (o *oauthTokenSource) Token(ctx context.Context) (string, error) { + o.mu.Lock() + defer o.mu.Unlock() + + // Try cached token source first (auto-refreshes) + if o.tokenSource != nil { + token, err := o.tokenSource.Token() + if err == nil && token.Valid() { + return token.AccessToken, nil + } + // Token source failed or expired, try to restore or re-authenticate + o.tokenSource = nil + } + + // Try to restore from secrets manager + if err := o.tryRestoreFromCache(ctx); err == nil && o.tokenSource != nil { + token, err := o.tokenSource.Token() + if err == nil && token.Valid() { + return token.AccessToken, nil + } + o.tokenSource = nil + } + + // In non-interactive mode, return error instead of triggering browser flow + if !o.interactive { + return "", ErrRegistryAuthRequired + } + + // Trigger browser-based OAuth flow + if err := o.performOAuthFlow(ctx); err != nil { + return "", fmt.Errorf("oauth flow failed: %w", err) + } + + token, err := o.tokenSource.Token() + if err != nil { + return "", fmt.Errorf("failed to get token after oauth flow: %w", err) + } + + return token.AccessToken, nil +} + +// tryRestoreFromCache attempts to restore token source from cached refresh token. +func (o *oauthTokenSource) tryRestoreFromCache(ctx context.Context) error { + if o.secretsProvider == nil { + return fmt.Errorf("no secrets provider available") + } + + refreshTokenKey := o.refreshTokenKey() + + refreshToken, err := o.secretsProvider.GetSecret(ctx, refreshTokenKey) + if err != nil { + return fmt.Errorf("failed to get cached refresh token: %w", err) + } + if refreshToken == "" { + return fmt.Errorf("no cached refresh token found") + } + + oauth2Cfg, err := o.buildOAuth2Config(ctx) + if err != nil { + return fmt.Errorf("failed to create oauth2 config: %w", err) + } + + o.tokenSource = remote.CreateTokenSourceFromCached(oauth2Cfg, refreshToken, o.oauthCfg.CachedTokenExpiry) + return nil +} + +// performOAuthFlow executes the browser-based OAuth flow and persists the result. +func (o *oauthTokenSource) performOAuthFlow(ctx context.Context) error { + oauthCfg, err := o.buildOAuthFlowConfig(ctx) + if err != nil { + return fmt.Errorf("failed to create oauth config: %w", err) + } + + flow, err := oauth.NewFlow(oauthCfg) + if err != nil { + return fmt.Errorf("failed to create oauth flow: %w", err) + } + + tokenResult, err := flow.Start(ctx, false) + if err != nil { + return fmt.Errorf("oauth flow start failed: %w", err) + } + + baseTokenSource := flow.TokenSource() + + // Wrap with persisting token source if secrets provider available + if o.secretsProvider == nil { + slog.Debug("No secrets provider available, refresh token will not be persisted") + } else { + refreshTokenKey := o.refreshTokenKey() + baseTokenSource = remote.NewPersistingTokenSource( + baseTokenSource, + o.createTokenPersister(refreshTokenKey), + ) + + // Persist initial refresh token + if tokenResult.RefreshToken != "" { + if err := o.secretsProvider.SetSecret(ctx, refreshTokenKey, tokenResult.RefreshToken); err != nil { + slog.Warn("Failed to persist initial refresh token", "error", err) + } else { + slog.Debug("Persisted initial refresh token", "key", refreshTokenKey) + } + } else { + slog.Debug("OAuth provider did not return a refresh token, token will not be persisted") + } + + // Update config with token ref + o.updateConfigTokenRef(refreshTokenKey, tokenResult.Expiry) + } + + o.tokenSource = baseTokenSource + return nil +} + +// buildOAuthFlowConfig creates an oauth.Config for the browser-based flow via OIDC discovery. +// PKCE is always enabled (S256) per OAuth 2.1 requirements for public clients. +func (o *oauthTokenSource) buildOAuthFlowConfig(ctx context.Context) (*oauth.Config, error) { + callbackPort := o.oauthCfg.CallbackPort + if callbackPort == 0 { + callbackPort = remote.DefaultCallbackPort + } + + return oauth.CreateOAuthConfigFromOIDC( + ctx, + o.oauthCfg.Issuer, + o.oauthCfg.ClientID, + "", // Public client — no client secret (PKCE is used instead) + o.oauthCfg.Scopes, + true, // Always use PKCE (S256) + callbackPort, + o.oauthCfg.Audience, + ) +} + +// buildOAuth2Config creates an oauth2.Config for token refresh via OIDC discovery. +func (o *oauthTokenSource) buildOAuth2Config(ctx context.Context) (*oauth2.Config, error) { + oauthCfg, err := o.buildOAuthFlowConfig(ctx) + if err != nil { + return nil, err + } + + return &oauth2.Config{ + ClientID: oauthCfg.ClientID, + ClientSecret: oauthCfg.ClientSecret, + Scopes: oauthCfg.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: oauthCfg.AuthURL, + TokenURL: oauthCfg.TokenURL, + }, + }, nil +} + +// createTokenPersister returns a remote.TokenPersister function that stores +// refresh tokens in the secrets manager. +func (o *oauthTokenSource) createTokenPersister(refreshTokenKey string) remote.TokenPersister { + return func(refreshToken string, expiry time.Time) error { + // The TokenPersister signature does not accept a context, and this callback + // is invoked asynchronously during token refresh, so we use Background. + ctx := context.Background() + if err := o.secretsProvider.SetSecret(ctx, refreshTokenKey, refreshToken); err != nil { + return fmt.Errorf("failed to persist refresh token: %w", err) + } + o.updateConfigTokenRef(refreshTokenKey, expiry) + return nil + } +} + +// updateConfigTokenRef updates the config with the refresh token reference and expiry. +func (*oauthTokenSource) updateConfigTokenRef(refreshTokenKey string, expiry time.Time) { + if err := config.UpdateConfig(func(cfg *config.Config) { + if cfg.RegistryAuth.OAuth != nil { + cfg.RegistryAuth.OAuth.CachedRefreshTokenRef = refreshTokenKey + cfg.RegistryAuth.OAuth.CachedTokenExpiry = expiry + } + }); err != nil { + slog.Warn("Failed to update config with token reference", "error", err) + } +} + +// refreshTokenKey returns the key used to store the refresh token in the secrets manager. +// Uses the RFC-specified derivation formula if no cached reference exists. +func (o *oauthTokenSource) refreshTokenKey() string { + if o.oauthCfg.CachedRefreshTokenRef != "" { + return o.oauthCfg.CachedRefreshTokenRef + } + return DeriveSecretKey(o.registryURL, o.oauthCfg.Issuer) +} diff --git a/pkg/registry/auth/transport.go b/pkg/registry/auth/transport.go new file mode 100644 index 0000000000..44ac0d614a --- /dev/null +++ b/pkg/registry/auth/transport.go @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "fmt" + "net/http" +) + +// Transport wraps an http.RoundTripper to add OAuth authentication headers. +type Transport struct { + Base http.RoundTripper + Source TokenSource +} + +// RoundTrip executes a single HTTP transaction with authentication. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + if t.Source == nil { + return t.base().RoundTrip(req) + } + + // Get token from source + token, err := t.Source.Token(req.Context()) + if err != nil { + return nil, fmt.Errorf("failed to get auth token: %w", err) + } + + // If token is empty, pass through without auth + if token == "" { + return t.base().RoundTrip(req) + } + + // Clone request and add authorization header + clonedReq := req.Clone(req.Context()) + clonedReq.Header.Set("Authorization", "Bearer "+token) + + return t.base().RoundTrip(clonedReq) +} + +// base returns the base RoundTripper, defaulting to http.DefaultTransport. +func (t *Transport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +// WrapTransport wraps an http.RoundTripper with authentication support. +// If source is nil, returns the base transport unchanged. +func WrapTransport(base http.RoundTripper, source TokenSource) http.RoundTripper { + if source == nil { + return base + } + return &Transport{ + Base: base, + Source: source, + } +} diff --git a/pkg/registry/auth/transport_test.go b/pkg/registry/auth/transport_test.go new file mode 100644 index 0000000000..8906d6ae12 --- /dev/null +++ b/pkg/registry/auth/transport_test.go @@ -0,0 +1,183 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockTokenSource is a test double for the TokenSource interface. +type mockTokenSource struct { + token string + err error +} + +func (m *mockTokenSource) Token(_ context.Context) (string, error) { + return m.token, m.err +} + +func TestWrapTransport(t *testing.T) { + t.Parallel() + + base := http.DefaultTransport + + tests := []struct { + name string + source TokenSource + wantSameAsBase bool + }{ + { + name: "nil source returns base transport unchanged", + source: nil, + wantSameAsBase: true, + }, + { + name: "non-nil source returns wrapped transport", + source: &mockTokenSource{token: "tok"}, + wantSameAsBase: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := WrapTransport(base, tt.source) + + if tt.wantSameAsBase { + require.Equal(t, base, got, "expected base transport to be returned unchanged") + } else { + require.NotEqual(t, base, got, "expected a wrapped transport to be returned") + wrapped, ok := got.(*Transport) + require.True(t, ok, "wrapped transport should be *Transport") + require.Equal(t, base, wrapped.Base) + require.Equal(t, tt.source, wrapped.Source) + } + }) + } +} + +func TestTransport_RoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + source TokenSource + wantAuthHeader string + wantErr bool + wantErrContains string + }{ + { + name: "nil source passes through without auth header", + source: nil, + wantAuthHeader: "", + wantErr: false, + }, + { + name: "source returns token adds Bearer header", + source: &mockTokenSource{token: "my-access-token"}, + wantAuthHeader: "Bearer my-access-token", + wantErr: false, + }, + { + name: "source returns empty string passes through without auth header", + source: &mockTokenSource{token: ""}, + wantAuthHeader: "", + wantErr: false, + }, + { + name: "source returns error propagates error", + source: &mockTokenSource{err: errors.New("token fetch failed")}, + wantErr: true, + wantErrContains: "failed to get auth token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Record the Authorization header received by the server. + var receivedAuthHeader string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuthHeader = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + transport := &Transport{ + Base: srv.Client().Transport, + Source: tt.source, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + require.NoError(t, err) + + resp, err := transport.RoundTrip(req) + + if tt.wantErr { + require.Error(t, err) + if tt.wantErrContains != "" { + require.ErrorContains(t, err, tt.wantErrContains) + } + require.Nil(t, resp) + return + } + + require.NoError(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + + assert.Equal(t, tt.wantAuthHeader, receivedAuthHeader) + }) + } +} + +func TestTransport_RoundTrip_DoesNotMutateOriginalRequest(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + transport := &Transport{ + Base: srv.Client().Transport, + Source: &mockTokenSource{token: "secret-token"}, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + require.NoError(t, err) + + // Capture the original header state before the round-trip. + originalAuth := req.Header.Get("Authorization") + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + // The original request must not have been mutated. + assert.Equal(t, originalAuth, req.Header.Get("Authorization"), + "RoundTrip must not modify the original request's headers") +} + +func TestTransport_base_DefaultsToHTTPDefaultTransport(t *testing.T) { + t.Parallel() + + tr := &Transport{} + require.Equal(t, http.DefaultTransport, tr.base(), + "base() should return http.DefaultTransport when Base is nil") + + custom := &http.Transport{} + tr.Base = custom + require.Equal(t, custom, tr.base(), + "base() should return the configured Base transport when set") +} diff --git a/pkg/registry/auth_manager.go b/pkg/registry/auth_manager.go new file mode 100644 index 0000000000..5783b0659a --- /dev/null +++ b/pkg/registry/auth_manager.go @@ -0,0 +1,85 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import ( + "context" + "fmt" + "time" + + "github.com/stacklok/toolhive/pkg/auth/oauth" + "github.com/stacklok/toolhive/pkg/auth/remote" + "github.com/stacklok/toolhive/pkg/config" +) + +// AuthManager provides operations for managing registry authentication configuration. +type AuthManager interface { + // SetOAuthAuth configures OAuth/OIDC authentication for the registry. + // Validates the OIDC issuer before saving configuration. + SetOAuthAuth(issuer, clientID, audience string, scopes []string) error + + // UnsetAuth removes registry authentication configuration. + UnsetAuth() error + + // GetAuthInfo returns the current auth type and whether tokens are cached. + GetAuthInfo() (authType string, hasCachedTokens bool) +} + +// DefaultAuthManager is the default implementation of AuthManager. +type DefaultAuthManager struct { + provider config.Provider +} + +// NewAuthManager creates a new registry auth manager using the given config provider. +func NewAuthManager(provider config.Provider) AuthManager { + return &DefaultAuthManager{ + provider: provider, + } +} + +// SetOAuthAuth configures OAuth/OIDC authentication for the registry. +// PKCE (S256) is always enforced and not configurable. +func (c *DefaultAuthManager) SetOAuthAuth(issuer, clientID, audience string, scopes []string) error { + // Validate OIDC issuer by attempting discovery + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + _, err := oauth.DiscoverOIDCEndpoints(ctx, issuer) + if err != nil { + return fmt.Errorf("OIDC discovery failed for issuer %s: %w", issuer, err) + } + + return c.provider.UpdateConfig(func(cfg *config.Config) { + cfg.RegistryAuth = config.RegistryAuth{ + Type: config.RegistryAuthTypeOAuth, + OAuth: &config.RegistryOAuthConfig{ + Issuer: issuer, + ClientID: clientID, + Scopes: scopes, + Audience: audience, + CallbackPort: remote.DefaultCallbackPort, + }, + } + }) +} + +// UnsetAuth removes registry authentication configuration. +func (c *DefaultAuthManager) UnsetAuth() error { + return c.provider.UpdateConfig(func(cfg *config.Config) { + cfg.RegistryAuth = config.RegistryAuth{} + }) +} + +// GetAuthInfo returns the current auth type and whether tokens are cached. +func (c *DefaultAuthManager) GetAuthInfo() (string, bool) { + cfg := c.provider.GetConfig() + if cfg.RegistryAuth.Type == "" { + return "", false + } + + hasCachedTokens := cfg.RegistryAuth.OAuth != nil && + cfg.RegistryAuth.OAuth.CachedRefreshTokenRef != "" + + return cfg.RegistryAuth.Type, hasCachedTokens +} diff --git a/pkg/registry/auth_manager_test.go b/pkg/registry/auth_manager_test.go new file mode 100644 index 0000000000..c237c2597b --- /dev/null +++ b/pkg/registry/auth_manager_test.go @@ -0,0 +1,153 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/config" + configmocks "github.com/stacklok/toolhive/pkg/config/mocks" +) + +func TestDefaultAuthManager_UnsetAuth(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + updateErr error + wantErr bool + }{ + { + name: "clears registry auth config on success", + updateErr: nil, + wantErr: false, + }, + { + name: "propagates error from UpdateConfig", + updateErr: errUpdateFailed, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockProvider := configmocks.NewMockProvider(ctrl) + + // Capture the update function and verify it zeroes RegistryAuth. + mockProvider.EXPECT(). + UpdateConfig(gomock.Any()). + DoAndReturn(func(fn func(*config.Config)) error { + if tt.updateErr != nil { + return tt.updateErr + } + cfg := &config.Config{ + RegistryAuth: config.RegistryAuth{ + Type: config.RegistryAuthTypeOAuth, + OAuth: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + }, + }, + } + fn(cfg) + // After the update function runs, RegistryAuth must be zero. + require.Equal(t, config.RegistryAuth{}, cfg.RegistryAuth) + return nil + }) + + mgr := NewAuthManager(mockProvider) + err := mgr.UnsetAuth() + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestDefaultAuthManager_GetAuthInfo(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + registryAuth config.RegistryAuth + wantAuthType string + wantHasCachedToks bool + }{ + { + name: "returns empty when no auth configured", + registryAuth: config.RegistryAuth{}, + wantAuthType: "", + wantHasCachedToks: false, + }, + { + name: "returns oauth type without cached tokens when OAuth section has no ref", + registryAuth: config.RegistryAuth{ + Type: config.RegistryAuthTypeOAuth, + OAuth: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + }, + }, + wantAuthType: config.RegistryAuthTypeOAuth, + wantHasCachedToks: false, + }, + { + name: "returns oauth type with cached tokens when CachedRefreshTokenRef is set", + registryAuth: config.RegistryAuth{ + Type: config.RegistryAuthTypeOAuth, + OAuth: &config.RegistryOAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + CachedRefreshTokenRef: "REGISTRY_OAUTH_aabbccdd", + }, + }, + wantAuthType: config.RegistryAuthTypeOAuth, + wantHasCachedToks: true, + }, + { + name: "returns oauth type without cached tokens when OAuth section is nil", + registryAuth: config.RegistryAuth{ + Type: config.RegistryAuthTypeOAuth, + OAuth: nil, + }, + wantAuthType: config.RegistryAuthTypeOAuth, + wantHasCachedToks: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockProvider := configmocks.NewMockProvider(ctrl) + + mockProvider.EXPECT(). + GetConfig(). + Return(&config.Config{RegistryAuth: tt.registryAuth}) + + mgr := NewAuthManager(mockProvider) + authType, hasCachedToks := mgr.GetAuthInfo() + + require.Equal(t, tt.wantAuthType, authType) + require.Equal(t, tt.wantHasCachedToks, hasCachedToks) + }) + } +} + +// errUpdateFailed is a sentinel error for testing UpdateConfig failure paths. +var errUpdateFailed = errSentinel("UpdateConfig failed") + +type errSentinel string + +func (e errSentinel) Error() string { return string(e) } diff --git a/pkg/registry/factory.go b/pkg/registry/factory.go index d8a0bbc8c2..82358e6c50 100644 --- a/pkg/registry/factory.go +++ b/pkg/registry/factory.go @@ -8,9 +8,12 @@ package registry import ( "fmt" + "log/slog" "sync" "github.com/stacklok/toolhive/pkg/config" + "github.com/stacklok/toolhive/pkg/registry/auth" + "github.com/stacklok/toolhive/pkg/secrets" ) var ( @@ -24,17 +27,40 @@ var ( defaultProviderMu sync.Mutex ) +// ProviderOption configures optional behavior for NewRegistryProvider. +type ProviderOption func(*providerOptions) + +type providerOptions struct { + interactive bool +} + +// WithInteractive sets whether browser-based OAuth flows are allowed. +// Defaults to true (CLI mode). Pass false for headless/serve mode. +func WithInteractive(interactive bool) ProviderOption { + return func(o *providerOptions) { o.interactive = interactive } +} + // NewRegistryProvider creates a new registry provider based on the configuration. // Returns an error if a custom registry is configured but cannot be reached. -func NewRegistryProvider(cfg *config.Config) (Provider, error) { +func NewRegistryProvider(cfg *config.Config, opts ...ProviderOption) (Provider, error) { + options := &providerOptions{interactive: true} + for _, opt := range opts { + opt(options) + } + // Priority order: // 1. API URL (if configured) - for live MCP Registry API queries // 2. Remote URL (if configured) - for static JSON over HTTP // 3. Local file path (if configured) - for local JSON file // 4. Default - embedded registry data + // Create token source if registry auth is configured. + // Auth only applies to API registry providers; remote URL and local file + // providers do not support authentication. + tokenSource := resolveTokenSource(cfg, options.interactive) + if cfg != nil && len(cfg.RegistryApiUrl) > 0 { - provider, err := NewCachedAPIRegistryProvider(cfg.RegistryApiUrl, cfg.AllowPrivateRegistryIp, true) + provider, err := NewCachedAPIRegistryProvider(cfg.RegistryApiUrl, cfg.AllowPrivateRegistryIp, true, tokenSource) if err != nil { return nil, fmt.Errorf("custom registry API at %s is not reachable: %w", cfg.RegistryApiUrl, err) } @@ -88,3 +114,36 @@ func ResetDefaultProvider() { defaultProvider = nil defaultProviderErr = nil } + +// resolveTokenSource creates a TokenSource from the config if registry auth is configured. +// Returns nil if no auth is configured or if token source creation fails (logs warning). +func resolveTokenSource(cfg *config.Config, interactive bool) auth.TokenSource { + if cfg == nil || cfg.RegistryAuth.Type != config.RegistryAuthTypeOAuth || cfg.RegistryAuth.OAuth == nil { + return nil + } + + // Try to create secrets provider for token persistence + var secretsProvider secrets.Provider + providerType, err := cfg.Secrets.GetProviderType() + if err != nil { + slog.Debug("Secrets provider not available for registry auth token persistence", + "error", err) + } else { + secretsProvider, err = secrets.CreateSecretProvider(providerType) + if err != nil { + slog.Warn("Failed to create secrets provider for registry auth, tokens will not be persisted", + "error", err) + } else { + slog.Debug("Secrets provider created for registry auth token persistence", + "provider_type", providerType) + } + } + + tokenSource, err := auth.NewTokenSource(cfg.RegistryAuth.OAuth, cfg.RegistryApiUrl, secretsProvider, interactive) + if err != nil { + slog.Warn("Failed to create registry auth token source", "error", err) + return nil + } + + return tokenSource +} diff --git a/pkg/registry/provider_api.go b/pkg/registry/provider_api.go index 468d7ddf73..f195119b52 100644 --- a/pkg/registry/provider_api.go +++ b/pkg/registry/provider_api.go @@ -13,6 +13,7 @@ import ( "github.com/stacklok/toolhive-core/registry/converters" types "github.com/stacklok/toolhive-core/registry/types" "github.com/stacklok/toolhive/pkg/registry/api" + "github.com/stacklok/toolhive/pkg/registry/auth" ) // APIRegistryProvider provides registry data from an MCP Registry API endpoint @@ -24,10 +25,11 @@ type APIRegistryProvider struct { client api.Client } -// NewAPIRegistryProvider creates a new API registry provider -func NewAPIRegistryProvider(apiURL string, allowPrivateIp bool) (*APIRegistryProvider, error) { +// NewAPIRegistryProvider creates a new API registry provider. +// If tokenSource is non-nil, all API requests will include authentication. +func NewAPIRegistryProvider(apiURL string, allowPrivateIp bool, tokenSource auth.TokenSource) (*APIRegistryProvider, error) { // Create API client - client, err := api.NewClient(apiURL, allowPrivateIp) + client, err := api.NewClient(apiURL, allowPrivateIp, tokenSource) if err != nil { return nil, fmt.Errorf("failed to create API client: %w", err) } @@ -41,14 +43,18 @@ func NewAPIRegistryProvider(apiURL string, allowPrivateIp bool) (*APIRegistryPro // Initialize the base provider with the GetRegistry function p.BaseProvider = NewBaseProvider(p.GetRegistry) - // Validate the endpoint by actually trying to use it (not checking openapi.yaml) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // Skip validation probe when auth is configured. The OAuth browser flow + // requires user interaction which cannot complete within the validation timeout. + // The endpoint will be validated on first real use instead. + if tokenSource == nil { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - // Try to list servers with a small limit to verify API functionality - _, err = client.ListServers(ctx, &api.ListOptions{Limit: 1}) - if err != nil { - return nil, fmt.Errorf("API endpoint not functional: %w", err) + // Try to list servers with a small limit to verify API functionality + _, err = client.ListServers(ctx, &api.ListOptions{Limit: 1}) + if err != nil { + return nil, fmt.Errorf("API endpoint not functional: %w", err) + } } return p, nil diff --git a/pkg/registry/provider_cached.go b/pkg/registry/provider_cached.go index 51b8a28b03..cd2e8c592f 100644 --- a/pkg/registry/provider_cached.go +++ b/pkg/registry/provider_cached.go @@ -17,6 +17,7 @@ import ( v0 "github.com/modelcontextprotocol/registry/pkg/api/v0" types "github.com/stacklok/toolhive-core/registry/types" + "github.com/stacklok/toolhive/pkg/registry/auth" ) const ( @@ -48,8 +49,11 @@ type CachedAPIRegistryProvider struct { // NewCachedAPIRegistryProvider creates a new cached API registry provider. // If usePersistent is true, it will use a file cache in ~/.toolhive/cache/ // The validation happens in NewAPIRegistryProvider by actually trying to use the API. -func NewCachedAPIRegistryProvider(apiURL string, allowPrivateIp bool, usePersistent bool) (*CachedAPIRegistryProvider, error) { - base, err := NewAPIRegistryProvider(apiURL, allowPrivateIp) +// If tokenSource is non-nil, all API requests will include authentication. +func NewCachedAPIRegistryProvider( + apiURL string, allowPrivateIp bool, usePersistent bool, tokenSource auth.TokenSource, +) (*CachedAPIRegistryProvider, error) { + base, err := NewAPIRegistryProvider(apiURL, allowPrivateIp, tokenSource) if err != nil { return nil, err }