diff --git a/pkg/api/socket_unix.go b/pkg/api/socket_unix.go index 3bd6403be4..39cf23a3f4 100644 --- a/pkg/api/socket_unix.go +++ b/pkg/api/socket_unix.go @@ -11,6 +11,7 @@ import ( "io/fs" "log/slog" "net" + "net/url" "os" "path/filepath" ) @@ -58,7 +59,10 @@ func cleanupUnixSocket(address string) { } // socketURL returns the URL form of a Unix-socket address for the discovery -// file. Non-Windows platforms only ever produce unix:// URLs. +// file. Non-Windows platforms only ever produce unix:// URLs. Built via +// (&url.URL{}).String() so the discovery dialer can round-trip the value back +// through net/url without surprises (matters mostly on Windows but kept here +// for symmetry). func socketURL(address string) string { - return "unix://" + address + return (&url.URL{Scheme: "unix", Path: address}).String() } diff --git a/pkg/api/socket_unix_test.go b/pkg/api/socket_unix_test.go index 27ca7dc691..f2f8705743 100644 --- a/pkg/api/socket_unix_test.go +++ b/pkg/api/socket_unix_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/server/discovery" ) func TestSocketURL_Unix(t *testing.T) { @@ -17,6 +19,18 @@ func TestSocketURL_Unix(t *testing.T) { assert.Equal(t, "unix:///tmp/test.sock", socketURL("/tmp/test.sock")) } +// TestSocketURL_RoundTrip_Unix pins that socketURL output is always parseable +// by ParseUnixSocketPath, closing the producer/consumer loop. Without the +// net/url-based emit form this would silently break for any path that needs +// percent-encoding. +func TestSocketURL_RoundTrip_Unix(t *testing.T) { + t.Parallel() + addr := "/tmp/test.sock" + got, err := discovery.ParseUnixSocketPath(socketURL(addr)) + require.NoError(t, err) + assert.Equal(t, addr, got) +} + func TestIsNamedPipeAddress(t *testing.T) { t.Parallel() tests := []struct { diff --git a/pkg/api/socket_windows.go b/pkg/api/socket_windows.go index 494a82a550..62f2e63ad6 100644 --- a/pkg/api/socket_windows.go +++ b/pkg/api/socket_windows.go @@ -11,6 +11,7 @@ import ( "io/fs" "log/slog" "net" + "net/url" "os" "path/filepath" @@ -82,10 +83,14 @@ func cleanupUnixSocket(address string) { // socketURL returns the URL form of a Unix-socket or named-pipe address for // the discovery file. Named pipes are emitted as npipe:// where -// is everything after the \\.\pipe\ prefix. +// is everything after the \\.\pipe\ prefix. AF_UNIX paths are emitted via +// (&url.URL{}).String() so a Windows drive-letter path round-trips through +// net/url cleanly (the previous concatenation form produced unix://C:\... , +// which url.Parse rejects with "invalid port"). func socketURL(address string) string { if isNamedPipeAddress(address) { - return "npipe://" + address[len(namedPipePrefix):] + name := address[len(namedPipePrefix):] + return (&url.URL{Scheme: "npipe", Host: name}).String() } - return "unix://" + address + return (&url.URL{Scheme: "unix", Path: address}).String() } diff --git a/pkg/api/socket_windows_test.go b/pkg/api/socket_windows_test.go index f6caa5b4f4..0f650c053d 100644 --- a/pkg/api/socket_windows_test.go +++ b/pkg/api/socket_windows_test.go @@ -15,6 +15,8 @@ import ( "github.com/Microsoft/go-winio" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/server/discovery" ) // pipeNameSeq disambiguates concurrent test pipes so parallel runs don't @@ -33,7 +35,11 @@ func TestSocketURL_Windows(t *testing.T) { want string }{ {"named pipe", `\\.\pipe\thv-api`, "npipe://thv-api"}, - {"af_unix windows path", `C:\path\thv.sock`, `unix://C:\path\thv.sock`}, + // AF_UNIX Windows paths are now percent-encoded so the resulting URL + // round-trips through net/url.Parse cleanly. The previous form + // (unix://C:\path\thv.sock) was rejected by url.Parse with + // "invalid port :\\path\\thv.sock". + {"af_unix windows path", `C:\path\thv.sock`, `unix:///C:%5Cpath%5Cthv.sock`}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -43,6 +49,16 @@ func TestSocketURL_Windows(t *testing.T) { } } +// TestSocketURL_RoundTrip_NamedPipe pins that socketURL output is always +// parseable by ParseNamedPipeURL, closing the producer/consumer loop. +func TestSocketURL_RoundTrip_NamedPipe(t *testing.T) { + t.Parallel() + addr := `\\.\pipe\thv-api` + got, err := discovery.ParseNamedPipeURL(socketURL(addr)) + require.NoError(t, err) + assert.Equal(t, addr, got) +} + func TestSetupUnixSocket_NamedPipe(t *testing.T) { t.Parallel() pipePath := uniqueTestPipe() diff --git a/pkg/server/discovery/health.go b/pkg/server/discovery/health.go index 6f7d24c185..91c3cf1730 100644 --- a/pkg/server/discovery/health.go +++ b/pkg/server/discovery/health.go @@ -11,6 +11,7 @@ import ( "net/http" "net/url" "path/filepath" + "regexp" "strings" "time" ) @@ -29,6 +30,28 @@ const ( // sides cannot drift. const NamedPipePrefix = `\\.\pipe\` +// maxNamedPipeNameLen is the Windows limit on the part of a pipe path after +// \\.\pipe\: the underlying CreateNamedPipeW lpName cannot exceed 256 chars +// total, leaving 247 for the name once the prefix is accounted for. +const maxNamedPipeNameLen = 247 + +// namedPipeNamePattern is the positive charset accepted for pipe names. The +// Windows pipe namespace technically allows more, but restricting to this +// subset keeps round-tripping through net/url predictable and rules out +// shell-meaningful characters that should never appear in our addresses. +var namedPipeNamePattern = regexp.MustCompile(`^[A-Za-z0-9._-]+$`) + +// reservedNamedPipeNames is the set of legacy Windows device names that +// CreateNamedPipeW will not accept as a pipe name. The check is on the +// already-lowercased hostname. +var reservedNamedPipeNames = map[string]struct{}{ + "con": {}, "nul": {}, "prn": {}, "aux": {}, + "com1": {}, "com2": {}, "com3": {}, "com4": {}, "com5": {}, + "com6": {}, "com7": {}, "com8": {}, "com9": {}, + "lpt1": {}, "lpt2": {}, "lpt3": {}, "lpt4": {}, "lpt5": {}, + "lpt6": {}, "lpt7": {}, "lpt8": {}, "lpt9": {}, +} + // CheckHealth verifies that a server at the given URL is healthy and optionally // matches the expected nonce. It supports http://, unix://, and npipe:// URL // schemes (npipe:// only resolves on Windows). @@ -87,9 +110,16 @@ func buildHealthClient(serverURL string) (*http.Client, string, error) { // dialNamedPipe helper. For http:// URLs it validates the host is a loopback // address and returns a default client. The returned client has no timeout // set; callers should apply their own timeout via context or client.Timeout. +// +// Scheme matching is case-insensitive because url.Parse lowercases the scheme +// during parsing, so npipe://x and NPIPE://x route through the same arm. func HTTPClientForURL(serverURL string) (*http.Client, string, error) { - switch { - case strings.HasPrefix(serverURL, "unix://"): + u, err := url.Parse(serverURL) + if err != nil { + return nil, "", fmt.Errorf("invalid server URL: %w", err) + } + switch u.Scheme { + case "unix": socketPath, err := ParseUnixSocketPath(serverURL) if err != nil { return nil, "", err @@ -103,7 +133,7 @@ func HTTPClientForURL(serverURL string) (*http.Client, string, error) { } return client, "http://localhost", nil - case strings.HasPrefix(serverURL, "npipe://"): + case "npipe": pipePath, err := ParseNamedPipeURL(serverURL) if err != nil { return nil, "", err @@ -117,14 +147,14 @@ func HTTPClientForURL(serverURL string) (*http.Client, string, error) { } return client, "http://localhost", nil - case strings.HasPrefix(serverURL, "http://"): + case "http": if err := ValidateLoopbackURL(serverURL); err != nil { return nil, "", err } return &http.Client{}, serverURL, nil default: - return nil, "", fmt.Errorf("unsupported URL scheme: %s", serverURL) + return nil, "", fmt.Errorf("unsupported URL scheme: %s", u.Scheme) } } @@ -147,13 +177,36 @@ func ValidateLoopbackURL(rawURL string) error { } // ParseUnixSocketPath extracts and validates the socket path from a unix:// URL. +// The expected form is unix:///; URLs with a non-empty +// authority (unix://host/path) are rejected because the path component would +// be ambiguous. On Windows, unix:///C:%5Cpath%5Cthv.sock round-trips back to +// C:\path\thv.sock by stripping the synthetic leading slash that net/url +// inserts in front of the drive letter. func ParseUnixSocketPath(rawURL string) (string, error) { - path := strings.TrimPrefix(rawURL, "unix://") + u, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("invalid unix socket URL: %w", err) + } + if u.Scheme != "unix" { + return "", fmt.Errorf("unix socket URL must start with unix://: %s", rawURL) + } + if u.Host != "" || u.RawQuery != "" || u.Fragment != "" || u.User != nil { + return "", fmt.Errorf("unix socket path must be absolute (use unix:///): %s", rawURL) + } + path := u.Path if path == "" { return "", fmt.Errorf("empty unix socket path") } - // Check for traversal before Clean resolves it away + // On Windows AF_UNIX paths, the listener emits unix:///C:%5Cpath%5C..., which + // url.Parse returns as Path="/C:\path\..." with a synthetic leading slash. + // Strip it before any further validation so filepath.IsAbs sees the actual + // drive-letter form. + if len(path) >= 3 && path[0] == '/' && isASCIILetter(path[1]) && path[2] == ':' { + path = path[1:] + } + + // Check for traversal before Clean resolves it away. if strings.Contains(path, "..") { return "", fmt.Errorf("unix socket path must not contain '..': %s", path) } @@ -168,23 +221,43 @@ func ParseUnixSocketPath(rawURL string) (string, error) { } // ParseNamedPipeURL extracts and validates the pipe name from an npipe:// URL -// and returns the full Windows pipe path (e.g. \\.\pipe\thv-api). The name -// portion must be a single segment with no path separators or traversal -// components, since the toolhive listener only ever publishes local pipes -// under the \\.\pipe\ namespace. +// and returns the full Windows pipe path (e.g. \\.\pipe\thv-api). The URL +// shape is strict: npipe:// with no path, query, fragment, userinfo, or +// port. The name itself must match a conservative charset, fit within the +// 247-char limit imposed by CreateNamedPipeW after the prefix, and must not +// be one of the legacy reserved Windows device names (CON, NUL, COM1, etc.). +// The returned path always uses the lowercase form of the name because the +// pipe namespace is case-insensitive at the kernel layer. func ParseNamedPipeURL(rawURL string) (string, error) { - if !strings.HasPrefix(rawURL, "npipe://") { + u, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("invalid named pipe URL: %w", err) + } + if u.Scheme != "npipe" { return "", fmt.Errorf("named pipe URL must start with npipe://: %s", rawURL) } - name := strings.TrimPrefix(rawURL, "npipe://") + if u.Path != "" || u.Opaque != "" || u.RawQuery != "" || + u.Fragment != "" || u.User != nil || u.Port() != "" { + return "", fmt.Errorf("named pipe URL must be of the form npipe://: %s", rawURL) + } + name := strings.ToLower(u.Hostname()) if name == "" { return "", fmt.Errorf("empty named pipe name") } - if strings.ContainsAny(name, `/\`) { - return "", fmt.Errorf("named pipe name must not contain path separators: %s", name) + if len(name) > maxNamedPipeNameLen { + return "", fmt.Errorf("named pipe name exceeds %d characters: %d", maxNamedPipeNameLen, len(name)) + } + if !namedPipeNamePattern.MatchString(name) { + return "", fmt.Errorf("named pipe name has invalid characters (allowed: A-Z, a-z, 0-9, ., _, -): %s", name) } - if strings.Contains(name, "..") { - return "", fmt.Errorf("named pipe name must not contain '..': %s", name) + if _, reserved := reservedNamedPipeNames[name]; reserved { + return "", fmt.Errorf("named pipe name is a reserved Windows device name: %s", name) } return NamedPipePrefix + name, nil } + +// isASCIILetter reports whether b is an ASCII letter [A-Za-z]. Used by +// ParseUnixSocketPath to detect Windows drive-letter prefixes after url.Parse. +func isASCIILetter(b byte) bool { + return (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') +} diff --git a/pkg/server/discovery/health_test.go b/pkg/server/discovery/health_test.go index d45c64cb3a..9dbc07b965 100644 --- a/pkg/server/discovery/health_test.go +++ b/pkg/server/discovery/health_test.go @@ -5,12 +5,14 @@ package discovery import ( "context" + "fmt" "net" "net/http" "net/http/httptest" "os" "path/filepath" "runtime" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -56,12 +58,27 @@ func TestParseNamedPipeURL(t *testing.T) { }{ {name: "valid", raw: "npipe://thv-api", expect: `\\.\pipe\thv-api`}, {name: "valid with hyphen and digits", raw: "npipe://thv-api-1", expect: `\\.\pipe\thv-api-1`}, + {name: "valid with dot dot in name", raw: "npipe://my..api", expect: `\\.\pipe\my..api`}, + {name: "mixed-case scheme normalized", raw: "NPIPE://thv-api", expect: `\\.\pipe\thv-api`}, + {name: "mixed-case name normalized", raw: "npipe://Thv-API", expect: `\\.\pipe\thv-api`}, {name: "missing scheme", raw: "thv-api", wantErr: true, errSubstr: "must start with npipe://"}, {name: "wrong scheme", raw: "unix://thv-api", wantErr: true, errSubstr: "must start with npipe://"}, {name: "empty name", raw: "npipe://", wantErr: true, errSubstr: "empty"}, - {name: "forward slash", raw: "npipe://thv/api", wantErr: true, errSubstr: "path separators"}, - {name: "backslash", raw: `npipe://thv\api`, wantErr: true, errSubstr: "path separators"}, - {name: "dot dot", raw: "npipe://..thv", wantErr: true, errSubstr: "'..'"}, + {name: "forward slash rejected", raw: "npipe://thv/api", wantErr: true, errSubstr: "form npipe://"}, + {name: "backslash rejected by url.Parse", raw: `npipe://thv\api`, wantErr: true, errSubstr: "invalid"}, + {name: "with port rejected", raw: "npipe://thv-api:1234", wantErr: true, errSubstr: "form npipe://"}, + {name: "with userinfo rejected", raw: "npipe://user:pass@thv-api", wantErr: true, errSubstr: "form npipe://"}, + {name: "with query rejected", raw: "npipe://thv-api?x=1", wantErr: true, errSubstr: "form npipe://"}, + {name: "with fragment rejected", raw: "npipe://thv-api#x", wantErr: true, errSubstr: "form npipe://"}, + {name: "invalid charset rejected", raw: "npipe://thv$api", wantErr: true, errSubstr: "invalid characters"}, + {name: "reserved name CON rejected", raw: "npipe://CON", wantErr: true, errSubstr: "reserved Windows device name"}, + {name: "reserved name com1 rejected", raw: "npipe://com1", wantErr: true, errSubstr: "reserved Windows device name"}, + { + name: "name exceeds length cap", + raw: "npipe://" + strings.Repeat("a", maxNamedPipeNameLen+1), + wantErr: true, + errSubstr: fmt.Sprintf("exceeds %d characters", maxNamedPipeNameLen), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -78,6 +95,17 @@ func TestParseNamedPipeURL(t *testing.T) { } } +// TestHTTPClientForURL_SchemeDispatchCaseInsensitive pins that the dispatcher +// in HTTPClientForURL routes UNIX:// the same as unix:// because url.Parse +// lowercases the scheme. Without the migration this case fell through to the +// default "unsupported URL scheme" arm. +func TestHTTPClientForURL_SchemeDispatchCaseInsensitive(t *testing.T) { + t.Parallel() + _, baseURL, err := HTTPClientForURL("UNIX:///tmp/thv.sock") + require.NoError(t, err) + assert.Equal(t, "http://localhost", baseURL) +} + func TestCheckHealth_NamedPipe_Unsupported_OnNonWindows(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("non-Windows guard test")