Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions pkg/networking/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ const HttpsScheme = "https"
// HttpScheme is the HTTP scheme
const HttpScheme = "http"

// Dialer control function for validating addresses prior to connection
func protectedDialerControl(_, address string, _ syscall.RawConn) error {
// ProtectedDialerControl is a Dialer control function for validating addresses
// prior to connection. It returns an error if the resolved address points at a
// private, loopback, or link-local IP, providing an SSRF guard at dial time.
// Pass it to (&net.Dialer{Control: ...}).DialContext on outbound HTTP transports.
func ProtectedDialerControl(_, address string, _ syscall.RawConn) error {
err := AddressReferencesPrivateIp(address)
if err != nil {
return err
Expand Down Expand Up @@ -154,7 +157,7 @@ func (b *HttpClientBuilder) Build() (*http.Client, error) {

if !b.allowPrivate {
transport.DialContext = (&net.Dialer{
Control: protectedDialerControl,
Control: ProtectedDialerControl,
}).DialContext
}

Expand Down
39 changes: 36 additions & 3 deletions pkg/webhook/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,33 @@ import (
"os"
"path/filepath"
"strconv"
"sync/atomic"
"syscall"
"time"

"github.com/stacklok/toolhive/pkg/networking"
)

// dialerControlFunc is the type of the Control hook installed on the underlying
// *net.Dialer used by the webhook HTTP transport.
type dialerControlFunc func(network, address string, c syscall.RawConn) error

// dialerControl holds the active Control hook. By default it wraps
// networking.ProtectedDialerControl, which rejects dials to private, loopback,
// and link-local addresses (SSRF guard).
//
// It is wrapped in atomic.Pointer so cross-goroutine writes from tests
// (SetDialerControlForTesting / SetDialerControlForTestMain) and
// cross-goroutine reads from the production dial path are race-free, even if
// a future test introduces t.Parallel(). Production callers must not
// reassign this variable.
var dialerControl atomic.Pointer[dialerControlFunc]

func init() {
fn := dialerControlFunc(networking.ProtectedDialerControl)
dialerControl.Store(&fn)
}

// Client is an HTTP client for calling webhook endpoints.
type Client struct {
httpClient *http.Client
Expand Down Expand Up @@ -126,7 +148,10 @@ func (c *Client) doHTTPCall(ctx context.Context, body []byte) ([]byte, error) {
httpReq.Header.Set(TimestampHeader, strconv.FormatInt(timestamp, 10))
}

// #nosec G704 -- URL is validated in Config.Validate and we use ValidatingTransport for SSRF protection.
// #nosec G704 -- URL is validated in Config.Validate; the inner transport's
// dialer rejects private/loopback/link-local addresses (SSRF guard), and
// ValidatingTransport additionally enforces HTTPS unless InsecureAllowHTTP
// is set for the configured TLS profile.
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, classifyError(c.config.Name, err)
Expand Down Expand Up @@ -182,15 +207,23 @@ func (c *Client) hmacSecretForRequest(ctx context.Context) ([]byte, error) {
return secret, nil
}

// buildTransport creates an http.RoundTripper with the specified TLS configuration,
// always wrapped in ValidatingTransport for SSRF protection.
// buildTransport creates an http.RoundTripper with the specified TLS configuration.
// The inner *http.Transport installs a dialer Control hook (the package-level
// dialerControl) that rejects connections to private, loopback, and link-local
// addresses, providing an SSRF guard regardless of TLS or HTTP mode. The outer
// ValidatingTransport additionally enforces HTTPS unless InsecureAllowHTTP is set.
func buildTransport(tlsCfg *TLSConfig) (http.RoundTripper, error) {
transport := &http.Transport{
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
DialContext: (&net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
return (*dialerControl.Load())(network, address, c)
},
}).DialContext,
}

// allowHTTP is true when InsecureSkipVerify is set, which also covers in-cluster
Expand Down
72 changes: 71 additions & 1 deletion pkg/webhook/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,9 @@ func TestClientHMACSigningHeaders(t *testing.T) {
assert.NotEmpty(t, capturedHeaders.Get(TimestampHeader), "expected %s header", TimestampHeader)
}

//nolint:paralleltest // mutates package-level dialerControl hook
func TestClientRereadsMountedHMACSecret(t *testing.T) {
t.Parallel()
SetDialerControlForTesting(t, AllowAnyDialerControl)

type signedRequest struct {
body []byte
Expand Down Expand Up @@ -643,6 +644,75 @@ func TestBuildTransport(t *testing.T) {
}
}

// TestClientSSRFGuardBlocksPrivateAddress verifies that the production webhook
// client refuses to dial private, loopback, and link-local addresses (cloud
// metadata IP). The dial-time SSRF guard is the load-bearing protection against
// a tenant pointing MCPWebhookConfig.url at internal services.
//
//nolint:paralleltest // exercises the production package-level dialerControl hook
func TestClientSSRFGuardBlocksPrivateAddress(t *testing.T) {
tests := []struct {
name string
host string
}{
{name: "loopback IPv4", host: "127.0.0.1"},
{name: "RFC1918 private", host: "10.0.0.1"},
{name: "cloud metadata link-local", host: "169.254.169.254"},
{name: "loopback IPv6", host: "[::1]"},
{name: "link-local IPv6", host: "[fe80::1]"},
{name: "ULA IPv6", host: "[fc00::1]"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Use HTTPS scheme so ValidatingTransport does not short-circuit the
// request before the dial. Use a high-numbered port so we never
// actually reach a real listener; the dialer guard fires first.
cfg := Config{
Name: "ssrf-test",
URL: "https://" + tt.host + ":59999/webhook",
Timeout: 2 * time.Second,
FailurePolicy: FailurePolicyFail,
}
client, err := NewClient(cfg, TypeValidating, nil)
require.NoError(t, err)

req := &Request{
Version: APIVersion,
UID: "ssrf-uid",
Timestamp: time.Now(),
}
_, callErr := client.Call(t.Context(), req)
require.Error(t, callErr)

var netErr *NetworkError
require.True(t, errors.As(callErr, &netErr),
"expected *NetworkError, got %T: %v", callErr, callErr)
assert.Contains(t, callErr.Error(), networking.ErrPrivateIpAddress,
"error should reflect dialer rejection of private/loopback/link-local address")
})
}
}

// TestBuildTransportInstallsDialerGuard is a narrow unit check that buildTransport
// installs a non-nil DialContext on the inner *http.Transport. The integration
// coverage in TestClientSSRFGuardBlocksPrivateAddress is the load-bearing test;
// this assertion just ensures the wiring does not silently regress to a bare
// transport with no Control hook.
func TestBuildTransportInstallsDialerGuard(t *testing.T) {
t.Parallel()

rt, err := buildTransport(nil)
require.NoError(t, err)
require.NotNil(t, rt)

vt, ok := rt.(*networking.ValidatingTransport)
require.True(t, ok, "expected *networking.ValidatingTransport, got %T", rt)
inner, ok := vt.Transport.(*http.Transport)
require.True(t, ok, "expected inner *http.Transport, got %T", vt.Transport)
assert.NotNil(t, inner.DialContext, "buildTransport must install a DialContext that runs the SSRF dialer guard")
}

func TestClassifyError(t *testing.T) {
t.Parallel()

Expand Down
50 changes: 50 additions & 0 deletions pkg/webhook/dialer_testing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package webhook

import (
"syscall"
"testing"
)

// The functions in this file are test-only helpers that intentionally live in a
// non-_test.go file so that sub-package tests (e.g. pkg/webhook/validating,
// pkg/webhook/mutating) can call into them via TestMain. There is no clean
// alternative for cross-package test-time injection of the package-level
// dialerControl hook, so these helpers are exported. The testing.TB argument
// (or the explicit "ForTestMain" suffix) is the signal that the call is
// test-scoped; production code MUST NOT call any of them.

// SetDialerControlForTesting overrides the package-level dialerControl hook
// for the duration of tb. It is the sanctioned way for tests to bypass the
// production SSRF dial-time guard so they can talk to httptest servers,
// which always bind 127.0.0.1. The previous value is restored via t.Cleanup.
//
// Production code MUST NOT call this function. The testing.TB argument is the
// signal that the call is test-scoped.
func SetDialerControlForTesting(tb testing.TB, control func(network, address string, c syscall.RawConn) error) {
tb.Helper()
prev := dialerControl.Load()
fn := dialerControlFunc(control)
dialerControl.Store(&fn)
tb.Cleanup(func() { dialerControl.Store(prev) })
}

// SetDialerControlForTestMain installs control as the dialer guard for the
// rest of the test binary's lifetime. Use this in TestMain in sub-packages
// whose entire test suite legitimately dials httptest servers bound to
// 127.0.0.1. There is no restore — the binary exits anyway.
//
// Production code MUST NOT call this function.
func SetDialerControlForTestMain(control func(network, address string, c syscall.RawConn) error) {
fn := dialerControlFunc(control)
dialerControl.Store(&fn)
}

// AllowAnyDialerControl is a permissive Control function for tests that
// need to dial httptest servers on 127.0.0.1. It performs no validation
// and always returns nil.
//
// Production code MUST NOT use this function.
func AllowAnyDialerControl(_, _ string, _ syscall.RawConn) error { return nil }
20 changes: 20 additions & 0 deletions pkg/webhook/mutating/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package mutating

import (
"os"
"testing"

"github.com/stacklok/toolhive/pkg/webhook"
)

// TestMain installs a permissive dialer control hook for the entire test
// binary so that webhook clients can dial httptest servers bound to 127.0.0.1.
// The production hook (networking.ProtectedDialerControl) would otherwise reject
// loopback addresses as part of the SSRF guard.
func TestMain(m *testing.M) {
webhook.SetDialerControlForTestMain(webhook.AllowAnyDialerControl)
os.Exit(m.Run())
}
20 changes: 20 additions & 0 deletions pkg/webhook/validating/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package validating

import (
"os"
"testing"

"github.com/stacklok/toolhive/pkg/webhook"
)

// TestMain installs a permissive dialer control hook for the entire test
// binary so that webhook clients can dial httptest servers bound to 127.0.0.1.
// The production hook (networking.ProtectedDialerControl) would otherwise reject
// loopback addresses as part of the SSRF guard.
func TestMain(m *testing.M) {
webhook.SetDialerControlForTestMain(webhook.AllowAnyDialerControl)
os.Exit(m.Run())
}
Loading