Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pkg/api/socket_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io/fs"
"log/slog"
"net"
"net/url"
"os"
"path/filepath"
)
Expand Down Expand Up @@ -58,7 +59,10 @@ func cleanupUnixSocket(address string) {
}

// socketURL returns the URL form of a Unix-socket address for the discovery
// file. Non-Windows platforms only ever produce unix:// URLs.
// file. Non-Windows platforms only ever produce unix:// URLs. Built via
// (&url.URL{}).String() so the discovery dialer can round-trip the value back
// through net/url without surprises (matters mostly on Windows but kept here
// for symmetry).
func socketURL(address string) string {
return "unix://" + address
return (&url.URL{Scheme: "unix", Path: address}).String()
}
14 changes: 14 additions & 0 deletions pkg/api/socket_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,27 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/stacklok/toolhive/pkg/server/discovery"
)

func TestSocketURL_Unix(t *testing.T) {
t.Parallel()
assert.Equal(t, "unix:///tmp/test.sock", socketURL("/tmp/test.sock"))
}

// TestSocketURL_RoundTrip_Unix pins that socketURL output is always parseable
// by ParseUnixSocketPath, closing the producer/consumer loop. Without the
// net/url-based emit form this would silently break for any path that needs
// percent-encoding.
func TestSocketURL_RoundTrip_Unix(t *testing.T) {
t.Parallel()
addr := "/tmp/test.sock"
got, err := discovery.ParseUnixSocketPath(socketURL(addr))
require.NoError(t, err)
assert.Equal(t, addr, got)
}

func TestIsNamedPipeAddress(t *testing.T) {
t.Parallel()
tests := []struct {
Expand Down
11 changes: 8 additions & 3 deletions pkg/api/socket_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io/fs"
"log/slog"
"net"
"net/url"
"os"
"path/filepath"

Expand Down Expand Up @@ -82,10 +83,14 @@ func cleanupUnixSocket(address string) {

// socketURL returns the URL form of a Unix-socket or named-pipe address for
// the discovery file. Named pipes are emitted as npipe://<name> where <name>
// is everything after the \\.\pipe\ prefix.
// is everything after the \\.\pipe\ prefix. AF_UNIX paths are emitted via
// (&url.URL{}).String() so a Windows drive-letter path round-trips through
// net/url cleanly (the previous concatenation form produced unix://C:\... ,
// which url.Parse rejects with "invalid port").
func socketURL(address string) string {
if isNamedPipeAddress(address) {
return "npipe://" + address[len(namedPipePrefix):]
name := address[len(namedPipePrefix):]
return (&url.URL{Scheme: "npipe", Host: name}).String()
}
return "unix://" + address
return (&url.URL{Scheme: "unix", Path: address}).String()
}
18 changes: 17 additions & 1 deletion pkg/api/socket_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"github.com/Microsoft/go-winio"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/stacklok/toolhive/pkg/server/discovery"
)

// pipeNameSeq disambiguates concurrent test pipes so parallel runs don't
Expand All @@ -33,7 +35,11 @@ func TestSocketURL_Windows(t *testing.T) {
want string
}{
{"named pipe", `\\.\pipe\thv-api`, "npipe://thv-api"},
{"af_unix windows path", `C:\path\thv.sock`, `unix://C:\path\thv.sock`},
// AF_UNIX Windows paths are now percent-encoded so the resulting URL
// round-trips through net/url.Parse cleanly. The previous form
// (unix://C:\path\thv.sock) was rejected by url.Parse with
// "invalid port :\\path\\thv.sock".
{"af_unix windows path", `C:\path\thv.sock`, `unix:///C:%5Cpath%5Cthv.sock`},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -43,6 +49,16 @@ func TestSocketURL_Windows(t *testing.T) {
}
}

// TestSocketURL_RoundTrip_NamedPipe pins that socketURL output is always
// parseable by ParseNamedPipeURL, closing the producer/consumer loop.
func TestSocketURL_RoundTrip_NamedPipe(t *testing.T) {
t.Parallel()
addr := `\\.\pipe\thv-api`
got, err := discovery.ParseNamedPipeURL(socketURL(addr))
require.NoError(t, err)
assert.Equal(t, addr, got)
}

func TestSetupUnixSocket_NamedPipe(t *testing.T) {
t.Parallel()
pipePath := uniqueTestPipe()
Expand Down
107 changes: 90 additions & 17 deletions pkg/server/discovery/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net/http"
"net/url"
"path/filepath"
"regexp"
"strings"
"time"
)
Expand All @@ -29,6 +30,28 @@ const (
// sides cannot drift.
const NamedPipePrefix = `\\.\pipe\`

// maxNamedPipeNameLen is the Windows limit on the part of a pipe path after
// \\.\pipe\: the underlying CreateNamedPipeW lpName cannot exceed 256 chars
// total, leaving 247 for the name once the prefix is accounted for.
const maxNamedPipeNameLen = 247

// namedPipeNamePattern is the positive charset accepted for pipe names. The
// Windows pipe namespace technically allows more, but restricting to this
// subset keeps round-tripping through net/url predictable and rules out
// shell-meaningful characters that should never appear in our addresses.
var namedPipeNamePattern = regexp.MustCompile(`^[A-Za-z0-9._-]+$`)

// reservedNamedPipeNames is the set of legacy Windows device names that
// CreateNamedPipeW will not accept as a pipe name. The check is on the
// already-lowercased hostname.
var reservedNamedPipeNames = map[string]struct{}{
"con": {}, "nul": {}, "prn": {}, "aux": {},
"com1": {}, "com2": {}, "com3": {}, "com4": {}, "com5": {},
"com6": {}, "com7": {}, "com8": {}, "com9": {},
"lpt1": {}, "lpt2": {}, "lpt3": {}, "lpt4": {}, "lpt5": {},
"lpt6": {}, "lpt7": {}, "lpt8": {}, "lpt9": {},
}

// CheckHealth verifies that a server at the given URL is healthy and optionally
// matches the expected nonce. It supports http://, unix://, and npipe:// URL
// schemes (npipe:// only resolves on Windows).
Expand Down Expand Up @@ -87,9 +110,16 @@ func buildHealthClient(serverURL string) (*http.Client, string, error) {
// dialNamedPipe helper. For http:// URLs it validates the host is a loopback
// address and returns a default client. The returned client has no timeout
// set; callers should apply their own timeout via context or client.Timeout.
//
// Scheme matching is case-insensitive because url.Parse lowercases the scheme
// during parsing, so npipe://x and NPIPE://x route through the same arm.
func HTTPClientForURL(serverURL string) (*http.Client, string, error) {
switch {
case strings.HasPrefix(serverURL, "unix://"):
u, err := url.Parse(serverURL)
if err != nil {
return nil, "", fmt.Errorf("invalid server URL: %w", err)
}
switch u.Scheme {
case "unix":
socketPath, err := ParseUnixSocketPath(serverURL)
if err != nil {
return nil, "", err
Expand All @@ -103,7 +133,7 @@ func HTTPClientForURL(serverURL string) (*http.Client, string, error) {
}
return client, "http://localhost", nil

case strings.HasPrefix(serverURL, "npipe://"):
case "npipe":
pipePath, err := ParseNamedPipeURL(serverURL)
if err != nil {
return nil, "", err
Expand All @@ -117,14 +147,14 @@ func HTTPClientForURL(serverURL string) (*http.Client, string, error) {
}
return client, "http://localhost", nil

case strings.HasPrefix(serverURL, "http://"):
case "http":
if err := ValidateLoopbackURL(serverURL); err != nil {
return nil, "", err
}
return &http.Client{}, serverURL, nil

default:
return nil, "", fmt.Errorf("unsupported URL scheme: %s", serverURL)
return nil, "", fmt.Errorf("unsupported URL scheme: %s", u.Scheme)
}
}

Expand All @@ -147,13 +177,36 @@ func ValidateLoopbackURL(rawURL string) error {
}

// ParseUnixSocketPath extracts and validates the socket path from a unix:// URL.
// The expected form is unix:///<absolute-path>; URLs with a non-empty
// authority (unix://host/path) are rejected because the path component would
// be ambiguous. On Windows, unix:///C:%5Cpath%5Cthv.sock round-trips back to
// C:\path\thv.sock by stripping the synthetic leading slash that net/url
// inserts in front of the drive letter.
func ParseUnixSocketPath(rawURL string) (string, error) {
path := strings.TrimPrefix(rawURL, "unix://")
u, err := url.Parse(rawURL)
if err != nil {
return "", fmt.Errorf("invalid unix socket URL: %w", err)
}
if u.Scheme != "unix" {
return "", fmt.Errorf("unix socket URL must start with unix://: %s", rawURL)
}
if u.Host != "" || u.RawQuery != "" || u.Fragment != "" || u.User != nil {
return "", fmt.Errorf("unix socket path must be absolute (use unix:///<path>): %s", rawURL)
}
path := u.Path
if path == "" {
return "", fmt.Errorf("empty unix socket path")
}

// Check for traversal before Clean resolves it away
// On Windows AF_UNIX paths, the listener emits unix:///C:%5Cpath%5C..., which
// url.Parse returns as Path="/C:\path\..." with a synthetic leading slash.
// Strip it before any further validation so filepath.IsAbs sees the actual
// drive-letter form.
if len(path) >= 3 && path[0] == '/' && isASCIILetter(path[1]) && path[2] == ':' {
path = path[1:]
}

// Check for traversal before Clean resolves it away.
if strings.Contains(path, "..") {
return "", fmt.Errorf("unix socket path must not contain '..': %s", path)
}
Expand All @@ -168,23 +221,43 @@ func ParseUnixSocketPath(rawURL string) (string, error) {
}

// ParseNamedPipeURL extracts and validates the pipe name from an npipe:// URL
// and returns the full Windows pipe path (e.g. \\.\pipe\thv-api). The name
// portion must be a single segment with no path separators or traversal
// components, since the toolhive listener only ever publishes local pipes
// under the \\.\pipe\ namespace.
// and returns the full Windows pipe path (e.g. \\.\pipe\thv-api). The URL
// shape is strict: npipe://<name> with no path, query, fragment, userinfo, or
// port. The name itself must match a conservative charset, fit within the
// 247-char limit imposed by CreateNamedPipeW after the prefix, and must not
// be one of the legacy reserved Windows device names (CON, NUL, COM1, etc.).
// The returned path always uses the lowercase form of the name because the
// pipe namespace is case-insensitive at the kernel layer.
func ParseNamedPipeURL(rawURL string) (string, error) {
if !strings.HasPrefix(rawURL, "npipe://") {
u, err := url.Parse(rawURL)
if err != nil {
return "", fmt.Errorf("invalid named pipe URL: %w", err)
}
if u.Scheme != "npipe" {
return "", fmt.Errorf("named pipe URL must start with npipe://: %s", rawURL)
}
name := strings.TrimPrefix(rawURL, "npipe://")
if u.Path != "" || u.Opaque != "" || u.RawQuery != "" ||
u.Fragment != "" || u.User != nil || u.Port() != "" {
return "", fmt.Errorf("named pipe URL must be of the form npipe://<name>: %s", rawURL)
}
name := strings.ToLower(u.Hostname())
if name == "" {
return "", fmt.Errorf("empty named pipe name")
}
if strings.ContainsAny(name, `/\`) {
return "", fmt.Errorf("named pipe name must not contain path separators: %s", name)
if len(name) > maxNamedPipeNameLen {
return "", fmt.Errorf("named pipe name exceeds %d characters: %d", maxNamedPipeNameLen, len(name))
}
if !namedPipeNamePattern.MatchString(name) {
return "", fmt.Errorf("named pipe name has invalid characters (allowed: A-Z, a-z, 0-9, ., _, -): %s", name)
}
if strings.Contains(name, "..") {
return "", fmt.Errorf("named pipe name must not contain '..': %s", name)
if _, reserved := reservedNamedPipeNames[name]; reserved {
return "", fmt.Errorf("named pipe name is a reserved Windows device name: %s", name)
}
return NamedPipePrefix + name, nil
}

// isASCIILetter reports whether b is an ASCII letter [A-Za-z]. Used by
// ParseUnixSocketPath to detect Windows drive-letter prefixes after url.Parse.
func isASCIILetter(b byte) bool {
return (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z')
}
34 changes: 31 additions & 3 deletions pkg/server/discovery/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ package discovery

import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -56,12 +58,27 @@ func TestParseNamedPipeURL(t *testing.T) {
}{
{name: "valid", raw: "npipe://thv-api", expect: `\\.\pipe\thv-api`},
{name: "valid with hyphen and digits", raw: "npipe://thv-api-1", expect: `\\.\pipe\thv-api-1`},
{name: "valid with dot dot in name", raw: "npipe://my..api", expect: `\\.\pipe\my..api`},
{name: "mixed-case scheme normalized", raw: "NPIPE://thv-api", expect: `\\.\pipe\thv-api`},
{name: "mixed-case name normalized", raw: "npipe://Thv-API", expect: `\\.\pipe\thv-api`},
{name: "missing scheme", raw: "thv-api", wantErr: true, errSubstr: "must start with npipe://"},
{name: "wrong scheme", raw: "unix://thv-api", wantErr: true, errSubstr: "must start with npipe://"},
{name: "empty name", raw: "npipe://", wantErr: true, errSubstr: "empty"},
{name: "forward slash", raw: "npipe://thv/api", wantErr: true, errSubstr: "path separators"},
{name: "backslash", raw: `npipe://thv\api`, wantErr: true, errSubstr: "path separators"},
{name: "dot dot", raw: "npipe://..thv", wantErr: true, errSubstr: "'..'"},
{name: "forward slash rejected", raw: "npipe://thv/api", wantErr: true, errSubstr: "form npipe://<name>"},
{name: "backslash rejected by url.Parse", raw: `npipe://thv\api`, wantErr: true, errSubstr: "invalid"},
{name: "with port rejected", raw: "npipe://thv-api:1234", wantErr: true, errSubstr: "form npipe://<name>"},
{name: "with userinfo rejected", raw: "npipe://user:pass@thv-api", wantErr: true, errSubstr: "form npipe://<name>"},
{name: "with query rejected", raw: "npipe://thv-api?x=1", wantErr: true, errSubstr: "form npipe://<name>"},
{name: "with fragment rejected", raw: "npipe://thv-api#x", wantErr: true, errSubstr: "form npipe://<name>"},
{name: "invalid charset rejected", raw: "npipe://thv$api", wantErr: true, errSubstr: "invalid characters"},
{name: "reserved name CON rejected", raw: "npipe://CON", wantErr: true, errSubstr: "reserved Windows device name"},
{name: "reserved name com1 rejected", raw: "npipe://com1", wantErr: true, errSubstr: "reserved Windows device name"},
{
name: "name exceeds length cap",
raw: "npipe://" + strings.Repeat("a", maxNamedPipeNameLen+1),
wantErr: true,
errSubstr: fmt.Sprintf("exceeds %d characters", maxNamedPipeNameLen),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -78,6 +95,17 @@ func TestParseNamedPipeURL(t *testing.T) {
}
}

// TestHTTPClientForURL_SchemeDispatchCaseInsensitive pins that the dispatcher
// in HTTPClientForURL routes UNIX:// the same as unix:// because url.Parse
// lowercases the scheme. Without the migration this case fell through to the
// default "unsupported URL scheme" arm.
func TestHTTPClientForURL_SchemeDispatchCaseInsensitive(t *testing.T) {
t.Parallel()
_, baseURL, err := HTTPClientForURL("UNIX:///tmp/thv.sock")
require.NoError(t, err)
assert.Equal(t, "http://localhost", baseURL)
}

func TestCheckHealth_NamedPipe_Unsupported_OnNonWindows(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("non-Windows guard test")
Expand Down
Loading