diff --git a/pkg/networking/http_client.go b/pkg/networking/http_client.go index 1116d2216e..297052b016 100644 --- a/pkg/networking/http_client.go +++ b/pkg/networking/http_client.go @@ -35,8 +35,11 @@ const HttpsScheme = "https" // HttpScheme is the HTTP scheme const HttpScheme = "http" -// Dialer control function for validating addresses prior to connection -func protectedDialerControl(_, address string, _ syscall.RawConn) error { +// ProtectedDialerControl is a Dialer control function for validating addresses +// prior to connection. It returns an error if the resolved address points at a +// private, loopback, or link-local IP, providing an SSRF guard at dial time. +// Pass it to (&net.Dialer{Control: ...}).DialContext on outbound HTTP transports. +func ProtectedDialerControl(_, address string, _ syscall.RawConn) error { err := AddressReferencesPrivateIp(address) if err != nil { return err @@ -154,7 +157,7 @@ func (b *HttpClientBuilder) Build() (*http.Client, error) { if !b.allowPrivate { transport.DialContext = (&net.Dialer{ - Control: protectedDialerControl, + Control: ProtectedDialerControl, }).DialContext } diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index cec045bd97..5c89acf252 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -17,11 +17,33 @@ import ( "os" "path/filepath" "strconv" + "sync/atomic" + "syscall" "time" "github.com/stacklok/toolhive/pkg/networking" ) +// dialerControlFunc is the type of the Control hook installed on the underlying +// *net.Dialer used by the webhook HTTP transport. +type dialerControlFunc func(network, address string, c syscall.RawConn) error + +// dialerControl holds the active Control hook. By default it wraps +// networking.ProtectedDialerControl, which rejects dials to private, loopback, +// and link-local addresses (SSRF guard). +// +// It is wrapped in atomic.Pointer so cross-goroutine writes from tests +// (SetDialerControlForTesting / SetDialerControlForTestMain) and +// cross-goroutine reads from the production dial path are race-free, even if +// a future test introduces t.Parallel(). Production callers must not +// reassign this variable. +var dialerControl atomic.Pointer[dialerControlFunc] + +func init() { + fn := dialerControlFunc(networking.ProtectedDialerControl) + dialerControl.Store(&fn) +} + // Client is an HTTP client for calling webhook endpoints. type Client struct { httpClient *http.Client @@ -126,7 +148,10 @@ func (c *Client) doHTTPCall(ctx context.Context, body []byte) ([]byte, error) { httpReq.Header.Set(TimestampHeader, strconv.FormatInt(timestamp, 10)) } - // #nosec G704 -- URL is validated in Config.Validate and we use ValidatingTransport for SSRF protection. + // #nosec G704 -- URL is validated in Config.Validate; the inner transport's + // dialer rejects private/loopback/link-local addresses (SSRF guard), and + // ValidatingTransport additionally enforces HTTPS unless InsecureAllowHTTP + // is set for the configured TLS profile. resp, err := c.httpClient.Do(httpReq) if err != nil { return nil, classifyError(c.config.Name, err) @@ -182,8 +207,11 @@ func (c *Client) hmacSecretForRequest(ctx context.Context) ([]byte, error) { return secret, nil } -// buildTransport creates an http.RoundTripper with the specified TLS configuration, -// always wrapped in ValidatingTransport for SSRF protection. +// buildTransport creates an http.RoundTripper with the specified TLS configuration. +// The inner *http.Transport installs a dialer Control hook (the package-level +// dialerControl) that rejects connections to private, loopback, and link-local +// addresses, providing an SSRF guard regardless of TLS or HTTP mode. The outer +// ValidatingTransport additionally enforces HTTPS unless InsecureAllowHTTP is set. func buildTransport(tlsCfg *TLSConfig) (http.RoundTripper, error) { transport := &http.Transport{ TLSHandshakeTimeout: 10 * time.Second, @@ -191,6 +219,11 @@ func buildTransport(tlsCfg *TLSConfig) (http.RoundTripper, error) { MaxIdleConns: 100, MaxIdleConnsPerHost: 10, IdleConnTimeout: 90 * time.Second, + DialContext: (&net.Dialer{ + Control: func(network, address string, c syscall.RawConn) error { + return (*dialerControl.Load())(network, address, c) + }, + }).DialContext, } // allowHTTP is true when InsecureSkipVerify is set, which also covers in-cluster diff --git a/pkg/webhook/client_test.go b/pkg/webhook/client_test.go index 377fc6f0a8..39e3e5eb05 100644 --- a/pkg/webhook/client_test.go +++ b/pkg/webhook/client_test.go @@ -323,8 +323,9 @@ func TestClientHMACSigningHeaders(t *testing.T) { assert.NotEmpty(t, capturedHeaders.Get(TimestampHeader), "expected %s header", TimestampHeader) } +//nolint:paralleltest // mutates package-level dialerControl hook func TestClientRereadsMountedHMACSecret(t *testing.T) { - t.Parallel() + SetDialerControlForTesting(t, AllowAnyDialerControl) type signedRequest struct { body []byte @@ -643,6 +644,75 @@ func TestBuildTransport(t *testing.T) { } } +// TestClientSSRFGuardBlocksPrivateAddress verifies that the production webhook +// client refuses to dial private, loopback, and link-local addresses (cloud +// metadata IP). The dial-time SSRF guard is the load-bearing protection against +// a tenant pointing MCPWebhookConfig.url at internal services. +// +//nolint:paralleltest // exercises the production package-level dialerControl hook +func TestClientSSRFGuardBlocksPrivateAddress(t *testing.T) { + tests := []struct { + name string + host string + }{ + {name: "loopback IPv4", host: "127.0.0.1"}, + {name: "RFC1918 private", host: "10.0.0.1"}, + {name: "cloud metadata link-local", host: "169.254.169.254"}, + {name: "loopback IPv6", host: "[::1]"}, + {name: "link-local IPv6", host: "[fe80::1]"}, + {name: "ULA IPv6", host: "[fc00::1]"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Use HTTPS scheme so ValidatingTransport does not short-circuit the + // request before the dial. Use a high-numbered port so we never + // actually reach a real listener; the dialer guard fires first. + cfg := Config{ + Name: "ssrf-test", + URL: "https://" + tt.host + ":59999/webhook", + Timeout: 2 * time.Second, + FailurePolicy: FailurePolicyFail, + } + client, err := NewClient(cfg, TypeValidating, nil) + require.NoError(t, err) + + req := &Request{ + Version: APIVersion, + UID: "ssrf-uid", + Timestamp: time.Now(), + } + _, callErr := client.Call(t.Context(), req) + require.Error(t, callErr) + + var netErr *NetworkError + require.True(t, errors.As(callErr, &netErr), + "expected *NetworkError, got %T: %v", callErr, callErr) + assert.Contains(t, callErr.Error(), networking.ErrPrivateIpAddress, + "error should reflect dialer rejection of private/loopback/link-local address") + }) + } +} + +// TestBuildTransportInstallsDialerGuard is a narrow unit check that buildTransport +// installs a non-nil DialContext on the inner *http.Transport. The integration +// coverage in TestClientSSRFGuardBlocksPrivateAddress is the load-bearing test; +// this assertion just ensures the wiring does not silently regress to a bare +// transport with no Control hook. +func TestBuildTransportInstallsDialerGuard(t *testing.T) { + t.Parallel() + + rt, err := buildTransport(nil) + require.NoError(t, err) + require.NotNil(t, rt) + + vt, ok := rt.(*networking.ValidatingTransport) + require.True(t, ok, "expected *networking.ValidatingTransport, got %T", rt) + inner, ok := vt.Transport.(*http.Transport) + require.True(t, ok, "expected inner *http.Transport, got %T", vt.Transport) + assert.NotNil(t, inner.DialContext, "buildTransport must install a DialContext that runs the SSRF dialer guard") +} + func TestClassifyError(t *testing.T) { t.Parallel() diff --git a/pkg/webhook/dialer_testing.go b/pkg/webhook/dialer_testing.go new file mode 100644 index 0000000000..269433a57f --- /dev/null +++ b/pkg/webhook/dialer_testing.go @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package webhook + +import ( + "syscall" + "testing" +) + +// The functions in this file are test-only helpers that intentionally live in a +// non-_test.go file so that sub-package tests (e.g. pkg/webhook/validating, +// pkg/webhook/mutating) can call into them via TestMain. There is no clean +// alternative for cross-package test-time injection of the package-level +// dialerControl hook, so these helpers are exported. The testing.TB argument +// (or the explicit "ForTestMain" suffix) is the signal that the call is +// test-scoped; production code MUST NOT call any of them. + +// SetDialerControlForTesting overrides the package-level dialerControl hook +// for the duration of tb. It is the sanctioned way for tests to bypass the +// production SSRF dial-time guard so they can talk to httptest servers, +// which always bind 127.0.0.1. The previous value is restored via t.Cleanup. +// +// Production code MUST NOT call this function. The testing.TB argument is the +// signal that the call is test-scoped. +func SetDialerControlForTesting(tb testing.TB, control func(network, address string, c syscall.RawConn) error) { + tb.Helper() + prev := dialerControl.Load() + fn := dialerControlFunc(control) + dialerControl.Store(&fn) + tb.Cleanup(func() { dialerControl.Store(prev) }) +} + +// SetDialerControlForTestMain installs control as the dialer guard for the +// rest of the test binary's lifetime. Use this in TestMain in sub-packages +// whose entire test suite legitimately dials httptest servers bound to +// 127.0.0.1. There is no restore — the binary exits anyway. +// +// Production code MUST NOT call this function. +func SetDialerControlForTestMain(control func(network, address string, c syscall.RawConn) error) { + fn := dialerControlFunc(control) + dialerControl.Store(&fn) +} + +// AllowAnyDialerControl is a permissive Control function for tests that +// need to dial httptest servers on 127.0.0.1. It performs no validation +// and always returns nil. +// +// Production code MUST NOT use this function. +func AllowAnyDialerControl(_, _ string, _ syscall.RawConn) error { return nil } diff --git a/pkg/webhook/mutating/main_test.go b/pkg/webhook/mutating/main_test.go new file mode 100644 index 0000000000..6a91610c5c --- /dev/null +++ b/pkg/webhook/mutating/main_test.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package mutating + +import ( + "os" + "testing" + + "github.com/stacklok/toolhive/pkg/webhook" +) + +// TestMain installs a permissive dialer control hook for the entire test +// binary so that webhook clients can dial httptest servers bound to 127.0.0.1. +// The production hook (networking.ProtectedDialerControl) would otherwise reject +// loopback addresses as part of the SSRF guard. +func TestMain(m *testing.M) { + webhook.SetDialerControlForTestMain(webhook.AllowAnyDialerControl) + os.Exit(m.Run()) +} diff --git a/pkg/webhook/validating/main_test.go b/pkg/webhook/validating/main_test.go new file mode 100644 index 0000000000..d52e1fc11b --- /dev/null +++ b/pkg/webhook/validating/main_test.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package validating + +import ( + "os" + "testing" + + "github.com/stacklok/toolhive/pkg/webhook" +) + +// TestMain installs a permissive dialer control hook for the entire test +// binary so that webhook clients can dial httptest servers bound to 127.0.0.1. +// The production hook (networking.ProtectedDialerControl) would otherwise reject +// loopback addresses as part of the SSRF guard. +func TestMain(m *testing.M) { + webhook.SetDialerControlForTestMain(webhook.AllowAnyDialerControl) + os.Exit(m.Run()) +}