diff --git a/options_test.go b/options_test.go index 23fca5a..2992b6a 100644 --- a/options_test.go +++ b/options_test.go @@ -49,7 +49,7 @@ func TestPasswordAuth(t *testing.T) { func TestPasswordAuthBadPass(t *testing.T) { t.Parallel() - l := newLocalListener() + l := newLocalTCPListener() srv := &Server{Handler: func(s Session) {}} srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { return false diff --git a/server.go b/server.go index 6e0eab4..e2290b8 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package ssh import ( "context" "errors" + "maps" "net" "sync" "time" @@ -47,6 +48,8 @@ type Server struct { ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil + LocalUnixForwardingCallback LocalUnixForwardingCallback // callback for allowing local unix forwarding (direct-streamlocal@openssh.com), denies all if nil + ReverseUnixForwardingCallback ReverseUnixForwardingCallback // callback for allowing reverse unix forwarding (streamlocal-forward@openssh.com), denies all if nil ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions @@ -98,21 +101,15 @@ func (srv *Server) ensureHandlers() { if srv.RequestHandlers == nil { srv.RequestHandlers = map[string]RequestHandler{} - for k, v := range DefaultRequestHandlers { - srv.RequestHandlers[k] = v - } + maps.Copy(srv.RequestHandlers, DefaultRequestHandlers) } if srv.ChannelHandlers == nil { srv.ChannelHandlers = map[string]ChannelHandler{} - for k, v := range DefaultChannelHandlers { - srv.ChannelHandlers[k] = v - } + maps.Copy(srv.ChannelHandlers, DefaultChannelHandlers) } if srv.SubsystemHandlers == nil { srv.SubsystemHandlers = map[string]SubsystemHandler{} - for k, v := range DefaultSubsystemHandlers { - srv.SubsystemHandlers[k] = v - } + maps.Copy(srv.SubsystemHandlers, DefaultSubsystemHandlers) } } diff --git a/server_test.go b/server_test.go index 11978e6..e60e703 100644 --- a/server_test.go +++ b/server_test.go @@ -30,7 +30,7 @@ func TestAddHostKey(t *testing.T) { } func TestServerShutdown(t *testing.T) { - l := newLocalListener() + l := newLocalTCPListener() testBytes := []byte("Hello world\n") s := &Server{ Handler: func(s Session) { @@ -82,7 +82,7 @@ func TestServerShutdown(t *testing.T) { } func TestServerClose(t *testing.T) { - l := newLocalListener() + l := newLocalTCPListener() s := &Server{ Handler: func(s Session) { time.Sleep(5 * time.Second) @@ -128,7 +128,7 @@ func TestServerClose(t *testing.T) { } func TestServerHandshakeTimeout(t *testing.T) { - l := newLocalListener() + l := newLocalTCPListener() s := &Server{ HandshakeTimeout: time.Millisecond, diff --git a/session_test.go b/session_test.go index 4f6b5ca..3291fb5 100644 --- a/session_test.go +++ b/session_test.go @@ -21,14 +21,25 @@ func (srv *Server) serveOnce(l net.Listener) error { return e } srv.ChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, - "direct-tcpip": DirectTCPIPHandler, + "session": DefaultSessionHandler, + "direct-tcpip": DirectTCPIPHandler, + "direct-streamlocal@openssh.com": DirectStreamLocalHandler, } + + forwardedTCPHandler := &ForwardedTCPHandler{} + forwardedUnixHandler := &ForwardedUnixHandler{} + srv.RequestHandlers = map[string]RequestHandler{ + "tcpip-forward": forwardedTCPHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardedTCPHandler.HandleSSHRequest, + "streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest, + "cancel-streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest, + } + srv.HandleConn(conn) return nil } -func newLocalListener() net.Listener { +func newLocalTCPListener() net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { @@ -65,7 +76,7 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g } func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - l := newLocalListener() + l := newLocalTCPListener() go srv.serveOnce(l) return newClientSession(t, l.Addr().String(), cfg) } diff --git a/ssh.go b/ssh.go index e2dd161..cc8a4a6 100644 --- a/ssh.go +++ b/ssh.go @@ -2,6 +2,7 @@ package ssh import ( "crypto/subtle" + "errors" "net" gossh "golang.org/x/crypto/ssh" @@ -29,6 +30,12 @@ const ( // DefaultHandler is the default Handler used by Serve. var DefaultHandler Handler +// ErrRejected may be returned by LocalUnixForwardingCallback or +// ReverseUnixForwardingCallback to reject a forwarding request. When +// returned, the server replies with "prohibited" rather than +// "connection failed." +var ErrRejected = errors.New("ssh: rejected") + // Option is a functional option handler for Server. type Option func(*Server) error @@ -66,6 +73,30 @@ type LocalPortForwardingCallback func(ctx Context, destinationHost string, desti // ReversePortForwardingCallback is a hook for allowing reverse port forwarding type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool +// LocalUnixForwardingCallback is a hook for allowing unix forwarding +// (direct-streamlocal@openssh.com). The callback receives the client-requested +// socket path and returns a connection to the target socket, or an error. +// +// Returning ErrRejected (or an error wrapping it) rejects the request with +// "administratively prohibited" and the error message is sent to the client. +// Any other error rejects with "connection failed." +// +// Use NewLocalUnixForwardingCallback to create a callback with built-in path +// validation and security controls. +type LocalUnixForwardingCallback func(ctx Context, socketPath string) (net.Conn, error) + +// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding +// (streamlocal-forward@openssh.com). The callback receives the client-requested +// socket path and returns a listener bound to that path, or an error. +// +// Returning ErrRejected (or an error wrapping it) rejects the request with +// "administratively prohibited" and the error message is sent to the client. +// Any other error rejects the request silently. +// +// Use NewReverseUnixForwardingCallback to create a callback with built-in path +// validation, permission controls, and security defaults matching OpenSSH. +type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error) + // ServerConfigCallback is a hook for creating custom default server configs type ServerConfigCallback func(ctx Context) *gossh.ServerConfig diff --git a/streamlocal.go b/streamlocal.go new file mode 100644 index 0000000..cf9208b --- /dev/null +++ b/streamlocal.go @@ -0,0 +1,470 @@ +package ssh + +import ( + "context" + "errors" + "fmt" + "io/fs" + "net" + "os" + "path/filepath" + "strings" + "sync" + "syscall" + + gossh "golang.org/x/crypto/ssh" +) + +// maxSunPathLen is the maximum length of a Unix domain socket path on the +// current platform, derived from the kernel's sockaddr_un.sun_path field. +// This is 108 on Linux and 104 on macOS/BSD. +var maxSunPathLen = len(syscall.RawSockaddrUnix{}.Path) + +const ( + forwardedUnixChannelType = "forwarded-streamlocal@openssh.com" +) + +// directStreamLocalChannelData data struct as specified in OpenSSH's protocol +// extensions document, Section 2.4. +// https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL?annotate=HEAD +type directStreamLocalChannelData struct { + SocketPath string + + Reserved1 string + Reserved2 uint32 +} + +// DirectStreamLocalHandler provides Unix forwarding from client -> server. It +// can be enabled by adding it to the server's ChannelHandlers under +// `direct-streamlocal@openssh.com`. +// +// Unix socket support on Windows is not widely available, so this handler may +// not work on all Windows installations and is not tested on Windows. +func DirectStreamLocalHandler(srv *Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + var d directStreamLocalChannelData + err := gossh.Unmarshal(newChan.ExtraData(), &d) + if err != nil { + _ = newChan.Reject(gossh.ConnectionFailed, "error parsing direct-streamlocal data: "+err.Error()) + return + } + + if srv.LocalUnixForwardingCallback == nil { + _ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled") + return + } + dconn, err := srv.LocalUnixForwardingCallback(ctx, d.SocketPath) + if err != nil { + if errors.Is(err, ErrRejected) { + _ = newChan.Reject(gossh.Prohibited, rejectedMessage(err)) + return + } + _ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %v", d.SocketPath, err)) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + _ = dconn.Close() + return + } + go gossh.DiscardRequests(reqs) + + bicopy(ctx, ch, dconn) +} + +// remoteUnixForwardRequest describes the extra data sent in a +// streamlocal-forward@openssh.com containing the socket path to bind to. +type remoteUnixForwardRequest struct { + SocketPath string +} + +// remoteUnixForwardChannelData describes the data sent as the payload in the new +// channel request when a Unix connection is accepted by the listener. +// +// See OpenSSH PROTOCOL, Section 2.4 "forwarded-streamlocal@openssh.com": +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL +// +// See also the client-side struct in x/crypto/ssh (forwardedStreamLocalPayload): +// https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go +type remoteUnixForwardChannelData struct { + SocketPath string + Reserved string +} + +// forwardKey identifies a forwarded Unix socket scoped to a specific +// SSH session, preventing cross-session collisions and ensuring that +// one session cannot cancel another session's forward. +type forwardKey struct { + sessionID string + addr string +} + +// ForwardedUnixHandler can be enabled by creating a ForwardedUnixHandler and +// adding the HandleSSHRequest callback to the server's RequestHandlers under +// `streamlocal-forward@openssh.com` and +// `cancel-streamlocal-forward@openssh.com` +// +// Unix socket support on Windows is not widely available, so this handler may +// not work on all Windows installations and is not tested on Windows. +type ForwardedUnixHandler struct { + sync.Mutex + forwards map[forwardKey]net.Listener +} + +func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { + h.Lock() + if h.forwards == nil { + h.forwards = make(map[forwardKey]net.Listener) + } + h.Unlock() + conn, ok := ctx.Value(ContextKeyConn).(*gossh.ServerConn) + if !ok { + // TODO: log cast failure + return false, nil + } + + switch req.Type { + case "streamlocal-forward@openssh.com": + var reqPayload remoteUnixForwardRequest + err := gossh.Unmarshal(req.Payload, &reqPayload) + if err != nil { + // TODO: log parse failure + return false, nil + } + + if srv.ReverseUnixForwardingCallback == nil { + return false, []byte("unix forwarding is disabled") + } + + addr := reqPayload.SocketPath + key := forwardKey{ + sessionID: ctx.SessionID(), + addr: addr, + } + + // Use a nil sentinel to claim the key while the callback runs, + // preventing a concurrent request from racing past the check. + h.Lock() + _, ok := h.forwards[key] + if ok { + h.Unlock() + // In cases where ExitOnForwardFailure=yes is set, returning + // false here will cause the connection to be closed. To avoid + // this, and to match OpenSSH behavior, we silently ignore + // the second forward request. + // TODO: log duplicate forward + return true, nil + } + h.forwards[key] = nil // placeholder; claimed + h.Unlock() + + ln, err := srv.ReverseUnixForwardingCallback(ctx, addr) + if err != nil { + h.Lock() + delete(h.forwards, key) + h.Unlock() + if errors.Is(err, ErrRejected) { + return false, []byte(rejectedMessage(err)) + } + // TODO: log unix listen failure + return false, nil + } + + h.Lock() + h.forwards[key] = ln + h.Unlock() + + // Use the connection-scoped context for bicopy so active data + // transfers survive listener shutdown. The derived context is + // only used for the accept loop lifecycle. + connCtx := ctx + ctx, cancel := context.WithCancel(ctx) + go func() { + <-ctx.Done() + _ = ln.Close() + }() + go func() { + defer cancel() + + for { + c, err := ln.Accept() + if err != nil { + // closed below + break + } + payload := gossh.Marshal(&remoteUnixForwardChannelData{ + SocketPath: addr, + }) + + go func() { + ch, reqs, err := conn.OpenChannel(forwardedUnixChannelType, payload) + if err != nil { + _ = c.Close() + return + } + go gossh.DiscardRequests(reqs) + bicopy(connCtx, ch, c) + }() + } + + h.Lock() + ln2, ok := h.forwards[key] + if ok && ln2 == ln { + delete(h.forwards, key) + } + h.Unlock() + _ = ln.Close() + }() + + return true, nil + + case "cancel-streamlocal-forward@openssh.com": + var reqPayload remoteUnixForwardRequest + err := gossh.Unmarshal(req.Payload, &reqPayload) + if err != nil { + // TODO: log parse failure + return false, nil + } + key := forwardKey{ + sessionID: ctx.SessionID(), + addr: reqPayload.SocketPath, + } + h.Lock() + ln, ok := h.forwards[key] + if ok { + delete(h.forwards, key) + } + h.Unlock() + if ok { + _ = ln.Close() + } + return true, nil + + default: + return false, nil + } +} + +// unlink removes files and unlike os.Remove, directories are kept. +func unlink(path string) error { + // Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go + // for more details. + for { + err := syscall.Unlink(path) + if !errors.Is(err, syscall.EINTR) { + return err + } + } +} + +// rejectedMessage returns a user-facing rejection message. If err is a bare +// ErrRejected (no wrapping context), it returns the generic "unix forwarding +// is disabled" for backward compatibility. Wrapped errors (e.g. rejectionError) +// return their descriptive message. +func rejectedMessage(err error) string { + if err == ErrRejected { //nolint:errorlint // intentional identity check + return "unix forwarding is disabled" + } + return err.Error() +} + +// rejectionError wraps ErrRejected with a descriptive reason for the SSH +// client. It satisfies errors.Is(err, ErrRejected) so that handlers send +// the rejection as "administratively prohibited" with the descriptive message. +type rejectionError struct { + reason string +} + +func (e *rejectionError) Error() string { return e.reason } +func (e *rejectionError) Unwrap() error { return ErrRejected } + +// UnixForwardingOptions configures the behavior of +// NewLocalUnixForwardingCallback and NewReverseUnixForwardingCallback. +type UnixForwardingOptions struct { + // AllowAll, if true, permits any absolute socket path without directory + // restrictions. AllowedDirectories and DeniedPrefixes are ignored when + // set. Basic sanitization (absolute path, length, filepath.Clean) is + // still applied. + AllowAll bool + + // AllowedDirectories is the list of directory prefixes under which + // socket paths are permitted. Paths are cleaned with filepath.Clean + // before prefix matching. Ignored when AllowAll is true. + // When AllowAll is false and AllowedDirectories is empty, all + // requests are denied. + AllowedDirectories []string + + // DeniedPrefixes is an optional denylist applied after the allowlist. + // Useful for excluding sensitive sub-paths within allowed directories + // (e.g. /run/user/1000/systemd/ within /run/user/1000/). + // Ignored when AllowAll is true. + DeniedPrefixes []string + + // BindUnlink controls whether an existing socket file is removed + // before binding (reverse forwarding only). Only socket-type files + // are removed; regular files are left in place and the listen will + // fail with EADDRINUSE. Default: false. + // Matches OpenSSH's StreamLocalBindUnlink (default: no). + BindUnlink bool + + // BindMask is the umask applied when creating listening sockets + // (reverse forwarding only). The resulting socket permission is + // 0666 &^ BindMask. If nil, defaults to 0177 (socket permission + // 0600, owner read/write only). + // Matches OpenSSH's StreamLocalBindMask. + BindMask *os.FileMode + + // PathValidator is an optional additional validation function called + // after built-in checks pass. Return an error wrapping ErrRejected + // (or a *rejectionError) for "administratively prohibited" semantics, + // or any other error for "connection failed." + PathValidator func(ctx Context, socketPath string) error +} + +// validateSocketPath checks that socketPath is safe according to opts. +// It returns the cleaned path on success. Returned errors wrap ErrRejected +// so that handlers report them as "administratively prohibited" with a +// descriptive message. +func validateSocketPath(socketPath string, opts UnixForwardingOptions) (string, error) { + if !filepath.IsAbs(socketPath) { + return "", &rejectionError{reason: "socket path must be absolute"} + } + + cleaned := filepath.Clean(socketPath) + + if strings.ContainsRune(cleaned, 0) { + return "", &rejectionError{reason: "socket path contains NUL byte"} + } + + if len(cleaned) >= maxSunPathLen { + return "", &rejectionError{ + reason: fmt.Sprintf("socket path too long (%d >= %d)", len(cleaned), maxSunPathLen), + } + } + + if !opts.AllowAll { + if len(opts.AllowedDirectories) == 0 { + return "", &rejectionError{ + reason: fmt.Sprintf("socket path %q is not in an allowed directory", cleaned), + } + } + + allowed := false + for _, dir := range opts.AllowedDirectories { + prefix := filepath.Clean(dir) + if !strings.HasSuffix(prefix, string(filepath.Separator)) { + prefix += string(filepath.Separator) + } + if strings.HasPrefix(cleaned, prefix) { + allowed = true + break + } + } + if !allowed { + return "", &rejectionError{ + reason: fmt.Sprintf("socket path %q is not in an allowed directory", cleaned), + } + } + + for _, denied := range opts.DeniedPrefixes { + prefix := filepath.Clean(denied) + if cleaned == prefix || strings.HasPrefix(cleaned, prefix+string(filepath.Separator)) { + return "", &rejectionError{ + reason: fmt.Sprintf("socket path %q is denied", cleaned), + } + } + } + } + + return cleaned, nil +} + +// NewLocalUnixForwardingCallback returns a LocalUnixForwardingCallback that +// validates socket paths against the provided options before dialing. +// Path validation errors are reported to the SSH client as +// "administratively prohibited" rejections with descriptive messages. +func NewLocalUnixForwardingCallback(opts UnixForwardingOptions) LocalUnixForwardingCallback { + return func(ctx Context, socketPath string) (net.Conn, error) { + cleaned, err := validateSocketPath(socketPath, opts) + if err != nil { + return nil, err + } + if opts.PathValidator != nil { + if err := opts.PathValidator(ctx, cleaned); err != nil { + return nil, err + } + } + + var d net.Dialer + return d.DialContext(ctx, "unix", cleaned) + } +} + +// NewReverseUnixForwardingCallback returns a ReverseUnixForwardingCallback +// that validates socket paths against the provided options before listening. +// +// Unlike a bare net.Listen, this callback: +// - Validates the socket path against allow/deny lists +// - Does not create parent directories +// - Applies a restrictive permission mask (default 0177 / mode 0600) +// - Only unlinks existing socket files when BindUnlink is true (not +// regular files or directories) +func NewReverseUnixForwardingCallback(opts UnixForwardingOptions) ReverseUnixForwardingCallback { + return func(ctx Context, socketPath string) (net.Listener, error) { + cleaned, err := validateSocketPath(socketPath, opts) + if err != nil { + return nil, err + } + if opts.PathValidator != nil { + if err := opts.PathValidator(ctx, cleaned); err != nil { + return nil, err + } + } + + if opts.BindUnlink { + // Only unlink if the existing file is a socket or does + // not exist. Regular files and directories are left in + // place so that net.Listen fails with EADDRINUSE rather + // than silently deleting user data. + if info, serr := os.Lstat(cleaned); serr == nil { + if info.Mode().Type() == os.ModeSocket { + if uerr := unlink(cleaned); uerr != nil && !errors.Is(uerr, fs.ErrNotExist) { + return nil, fmt.Errorf("failed to unlink existing socket %q: %w", cleaned, uerr) + } + } + } + } + + lc := &net.ListenConfig{} + ln, err := lc.Listen(ctx, "unix", cleaned) + if err != nil { + return nil, fmt.Errorf("failed to listen on unix socket %q: %w", cleaned, err) + } + + // Apply socket permission mask. Default 0177 (mode 0600), + // matching OpenSSH's StreamLocalBindMask. + mask := os.FileMode(0177) + if opts.BindMask != nil { + mask = *opts.BindMask + } + mode := os.FileMode(0666) &^ mask + if err := os.Chmod(cleaned, mode); err != nil { + _ = ln.Close() + return nil, fmt.Errorf("failed to set permissions on socket %q: %w", cleaned, err) + } + + return ln, nil + } +} + +// UserSocketDirectories returns common socket directory prefixes for a user, +// suitable for use as UnixForwardingOptions.AllowedDirectories. The returned +// list includes the user's home directory, /tmp, and the XDG runtime +// directory (/run/user/). +func UserSocketDirectories(homeDir string, uid string) []string { + return []string{ + homeDir, + "/tmp", + filepath.Join("/run/user", uid), + } +} diff --git a/streamlocal_test.go b/streamlocal_test.go new file mode 100644 index 0000000..b04ba51 --- /dev/null +++ b/streamlocal_test.go @@ -0,0 +1,788 @@ +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "strings" + "sync/atomic" + "testing" + + gossh "golang.org/x/crypto/ssh" +) + +// tempDirUnixSocket returns a temporary directory that can safely hold unix +// sockets. +// +// On all platforms other than darwin this just returns t.TempDir(). On darwin +// we manually make a temporary directory in /tmp because t.TempDir() returns a +// very long directory name, and the path length limit for Unix sockets on +// darwin is 104 characters. +func tempDirUnixSocket(t *testing.T) string { + t.Helper() + if runtime.GOOS == "darwin" { + testName := strings.ReplaceAll(t.Name(), "/", "_") + dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("gliderlabs-ssh-test-%s-", testName)) + if err != nil { + t.Fatalf("create temp dir for test: %v", err) + } + + t.Cleanup(func() { + err := os.RemoveAll(dir) + if err != nil { + t.Errorf("remove temp dir %s: %v", dir, err) + } + }) + return dir + } + + return t.TempDir() +} + +func newLocalUnixListener(t *testing.T) net.Listener { + path := filepath.Join(tempDirUnixSocket(t), "socket.sock") + l, err := net.Listen("unix", path) + if err != nil { + t.Fatalf("failed to listen on a unix socket %q: %v", path, err) + } + return l +} + +func sampleUnixSocketServer(t *testing.T) net.Listener { + l := newLocalUnixListener(t) + + go func() { + conn, err := l.Accept() + if err != nil { + return + } + _, _ = conn.Write(sampleServerResponse) + _ = conn.Close() + }() + + return l +} + +func newTestSessionWithUnixForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { + l := sampleUnixSocketServer(t) + + allowAllCb := NewLocalUnixForwardingCallback(UnixForwardingOptions{AllowAll: true}) + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + LocalUnixForwardingCallback: func(ctx Context, socketPath string) (net.Conn, error) { + if socketPath != l.Addr().String() { + panic("unexpected socket path: " + socketPath) + } + if !forwardingEnabled { + return nil, ErrRejected + } + return allowAllCb(ctx, socketPath) + }, + }, nil) + + return l, client, func() { + cleanup() + _ = l.Close() + } +} + +func TestLocalUnixForwardingWorks(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithUnixForwarding(t, true) + defer cleanup() + + conn, err := client.Dial("unix", l.Addr().String()) + if err != nil { + t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } +} + +func TestLocalUnixForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithUnixForwarding(t, false) + defer cleanup() + + _, err := client.Dial("unix", l.Addr().String()) + if err == nil { + t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String()) + } + if !strings.Contains(err.Error(), "unix forwarding is disabled") { + t.Fatalf("Expected permission error but got %#v", err) + } +} + +func TestReverseUnixForwardingWorks(t *testing.T) { + t.Parallel() + + remoteSocketPath := filepath.Join(tempDirUnixSocket(t), "remote.sock") + + allowAllCb := NewReverseUnixForwardingCallback(UnixForwardingOptions{ + AllowAll: true, + BindUnlink: true, + }) + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReverseUnixForwardingCallback: func(ctx Context, socketPath string) (net.Listener, error) { + if socketPath != remoteSocketPath { + panic("unexpected socket path: " + socketPath) + } + return allowAllCb(ctx, socketPath) + }, + }, nil) + defer cleanup() + + l, err := client.ListenUnix(remoteSocketPath) + if err != nil { + t.Fatalf("failed to listen on a unix socket over SSH %q: %v", remoteSocketPath, err) + } + defer l.Close() //nolint:errcheck + go func() { + conn, err := l.Accept() + if err != nil { + return + } + _, _ = conn.Write(sampleServerResponse) + _ = conn.Close() + }() + + // Dial the listener that should've been created by the server. + conn, err := net.Dial("unix", remoteSocketPath) + if err != nil { + t.Fatalf("Error connecting to %v: %v", remoteSocketPath, err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } + + // Close the listener and make sure that the Unix socket is gone. + err = l.Close() + if err != nil { + t.Fatalf("failed to close remote listener: %v", err) + } + _, err = os.Stat(remoteSocketPath) + if err == nil { + t.Fatal("expected remote socket to be removed after close") + } +} + +func TestValidateSocketPath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + opts UnixForwardingOptions + wantErr bool + wantClean string // expected cleaned path on success + errSubstr string // substring expected in error message + wantType error // expected error type (ErrRejected) + }{ + // Basic validation (applies to all modes). + { + name: "absolute path accepted with AllowAll", + path: "/tmp/test.sock", + opts: UnixForwardingOptions{AllowAll: true}, + wantClean: "/tmp/test.sock", + }, + { + name: "relative path rejected", + path: "relative/path.sock", + opts: UnixForwardingOptions{AllowAll: true}, + wantErr: true, + errSubstr: "must be absolute", + wantType: ErrRejected, + }, + { + name: "dot-relative path rejected", + path: "./local.sock", + opts: UnixForwardingOptions{AllowAll: true}, + wantErr: true, + errSubstr: "must be absolute", + wantType: ErrRejected, + }, + { + name: "empty path rejected", + path: "", + opts: UnixForwardingOptions{AllowAll: true}, + wantErr: true, + errSubstr: "must be absolute", + wantType: ErrRejected, + }, + { + name: "path with dot-dot cleaned and accepted", + path: "/tmp/foo/../bar/test.sock", + opts: UnixForwardingOptions{AllowAll: true}, + wantClean: "/tmp/bar/test.sock", + }, + { + name: "path with double slashes cleaned", + path: "/tmp//foo//test.sock", + opts: UnixForwardingOptions{AllowAll: true}, + wantClean: "/tmp/foo/test.sock", + }, + { + name: "path with trailing slash cleaned", + path: "/tmp/test.sock/", + opts: UnixForwardingOptions{AllowAll: true}, + wantClean: "/tmp/test.sock", + }, + { + name: "path at sun_path limit rejected", + path: "/" + strings.Repeat("a", maxSunPathLen-1), + opts: UnixForwardingOptions{AllowAll: true}, + wantErr: true, + errSubstr: "too long", + wantType: ErrRejected, + }, + { + name: "path just under sun_path limit accepted", + path: "/" + strings.Repeat("a", maxSunPathLen-3), + opts: UnixForwardingOptions{AllowAll: true}, + wantClean: "/" + strings.Repeat("a", maxSunPathLen-3), + }, + + // AllowedDirectories tests. + { + name: "path in allowed directory accepted", + path: "/tmp/ssh/agent.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantClean: "/tmp/ssh/agent.sock", + }, + { + name: "path in second allowed directory accepted", + path: "/home/user/.ssh/agent.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp", "/home/user"}}, + wantClean: "/home/user/.ssh/agent.sock", + }, + { + name: "path outside allowed directories rejected", + path: "/var/run/docker.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp", "/home/user"}}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + { + name: "empty allowed directories rejects all", + path: "/tmp/test.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{}}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + { + name: "nil allowed directories rejects all", + path: "/tmp/test.sock", + opts: UnixForwardingOptions{}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + { + name: "allowed directory itself is not a valid socket path", + path: "/tmp", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + { + name: "allowed directory with trailing slash works", + path: "/tmp/test.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp/"}}, + wantClean: "/tmp/test.sock", + }, + { + name: "dot-dot traversal out of allowed directory rejected", + path: "/tmp/../var/run/docker.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + { + name: "dot-dot traversal staying in allowed directory accepted", + path: "/tmp/foo/../bar/test.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantClean: "/tmp/bar/test.sock", + }, + { + name: "allowed directory prefix attack rejected", + path: "/tmpevil/test.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + + // DeniedPrefixes tests. + { + name: "path in denied prefix rejected", + path: "/run/user/1000/systemd/private.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/run/user/1000"}, DeniedPrefixes: []string{"/run/user/1000/systemd"}}, + wantErr: true, + errSubstr: "is denied", + wantType: ErrRejected, + }, + { + name: "exact denied path rejected", + path: "/var/run/docker.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/var/run"}, DeniedPrefixes: []string{"/var/run/docker.sock"}}, + wantErr: true, + errSubstr: "is denied", + wantType: ErrRejected, + }, + { + name: "path not matching denied prefix accepted", + path: "/run/user/1000/podman/podman.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/run/user/1000"}, DeniedPrefixes: []string{"/run/user/1000/systemd"}}, + wantClean: "/run/user/1000/podman/podman.sock", + }, + { + name: "denied prefix does not match partial directory names", + path: "/run/user/1000/systemd-resolved/test.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/run/user/1000"}, DeniedPrefixes: []string{"/run/user/1000/systemd"}}, + wantClean: "/run/user/1000/systemd-resolved/test.sock", + }, + + // AllowAll overrides AllowedDirectories/DeniedPrefixes. + { + name: "AllowAll ignores AllowedDirectories", + path: "/var/run/docker.sock", + opts: UnixForwardingOptions{AllowAll: true, AllowedDirectories: []string{"/tmp"}}, + wantClean: "/var/run/docker.sock", + }, + { + name: "AllowAll ignores DeniedPrefixes", + path: "/run/user/1000/systemd/private.sock", + opts: UnixForwardingOptions{AllowAll: true, DeniedPrefixes: []string{"/run/user/1000/systemd"}}, + wantClean: "/run/user/1000/systemd/private.sock", + }, + + // Real-world socket paths. + { + name: "podman socket path", + path: "/run/user/1000/podman/podman.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/run/user/1000"}}, + wantClean: "/run/user/1000/podman/podman.sock", + }, + { + name: "gpg agent socket", + path: "/home/user/.gnupg/S.gpg-agent", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/home/user"}}, + wantClean: "/home/user/.gnupg/S.gpg-agent", + }, + { + name: "gpg agent socket systemd path", + path: "/run/user/1000/gnupg/S.gpg-agent", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/run/user/1000"}}, + wantClean: "/run/user/1000/gnupg/S.gpg-agent", + }, + { + name: "vscode remote socket", + path: "/tmp/code-d0fd2e91-ed82-46dd-8394-87ac5cde31c3.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantClean: "/tmp/code-d0fd2e91-ed82-46dd-8394-87ac5cde31c3.sock", + }, + { + name: "ssh agent socket", + path: "/tmp/ssh-XXXXXXXXXX/agent.12345", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantClean: "/tmp/ssh-XXXXXXXXXX/agent.12345", + }, + { + name: "docker socket denied even when /var/run allowed", + path: "/var/run/docker.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/var/run"}, DeniedPrefixes: []string{"/var/run/docker.sock"}}, + wantErr: true, + errSubstr: "is denied", + wantType: ErrRejected, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + cleaned, err := validateSocketPath(tt.path, tt.opts) + if tt.wantErr { + if err == nil { + t.Fatalf("validateSocketPath(%q) = %q, nil; want error containing %q", tt.path, cleaned, tt.errSubstr) + } + if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) { + t.Fatalf("validateSocketPath(%q) error = %q; want substring %q", tt.path, err.Error(), tt.errSubstr) + } + if tt.wantType != nil && !errors.Is(err, tt.wantType) { + t.Fatalf("validateSocketPath(%q) error type = %T; want errors.Is(%v)", tt.path, err, tt.wantType) + } + return + } + if err != nil { + t.Fatalf("validateSocketPath(%q) = error %q; want %q", tt.path, err, tt.wantClean) + } + if cleaned != tt.wantClean { + t.Fatalf("validateSocketPath(%q) = %q; want %q", tt.path, cleaned, tt.wantClean) + } + }) + } +} + +func TestRejectedMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want string + }{ + { + name: "bare ErrRejected gives generic message", + err: ErrRejected, + want: "unix forwarding is disabled", + }, + { + name: "rejectionError gives descriptive message", + err: &rejectionError{reason: "socket path must be absolute"}, + want: "socket path must be absolute", + }, + { + name: "wrapped ErrRejected gives wrapper message", + err: fmt.Errorf("custom reason: %w", ErrRejected), + want: "custom reason: ssh: rejected", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := rejectedMessage(tt.err) + if got != tt.want { + t.Fatalf("rejectedMessage(%v) = %q; want %q", tt.err, got, tt.want) + } + }) + } +} + +func TestUserSocketDirectories(t *testing.T) { + t.Parallel() + + dirs := UserSocketDirectories("/home/testuser", "1000") + want := []string{"/home/testuser", "/tmp", "/run/user/1000"} + + if len(dirs) != len(want) { + t.Fatalf("UserSocketDirectories returned %d dirs; want %d", len(dirs), len(want)) + } + for i, d := range dirs { + if d != want[i] { + t.Fatalf("UserSocketDirectories()[%d] = %q; want %q", i, d, want[i]) + } + } +} + +func TestNewLocalUnixForwardingCallbackValidation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts UnixForwardingOptions + path string + wantErr bool + errSubstr string + }{ + { + name: "AllowAll accepts any absolute path", + opts: UnixForwardingOptions{AllowAll: true}, + path: "/var/run/docker.sock", + }, + { + name: "AllowAll rejects relative path", + opts: UnixForwardingOptions{AllowAll: true}, + path: "relative.sock", + wantErr: true, + errSubstr: "must be absolute", + }, + { + name: "restricted rejects path outside allowed dirs", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + path: "/var/run/docker.sock", + wantErr: true, + errSubstr: "not in an allowed directory", + }, + { + name: "restricted accepts path in allowed dir", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + path: "/tmp/test.sock", + }, + { + name: "PathValidator is called and can reject", + opts: UnixForwardingOptions{ + AllowAll: true, + PathValidator: func(_ Context, _ string) error { + return &rejectionError{reason: "custom validator rejected"} + }, + }, + path: "/tmp/test.sock", + wantErr: true, + errSubstr: "custom validator rejected", + }, + { + name: "PathValidator receives cleaned path", + opts: UnixForwardingOptions{ + AllowAll: true, + PathValidator: func(_ Context, path string) error { + if path != "/tmp/test.sock" { + return fmt.Errorf("expected /tmp/test.sock, got %s", path) + } + return nil + }, + }, + path: "/tmp/foo/../test.sock", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := newContext(nil) + defer cancel() + + cb := NewLocalUnixForwardingCallback(tt.opts) + // The callback tries to dial the socket, which will fail for + // non-existent paths. We only care about the validation errors + // (which wrap ErrRejected), not dial errors. + _, err := cb(ctx, tt.path) + if tt.wantErr { + if err == nil { + t.Fatal("expected error but got nil") + } + if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) { + t.Fatalf("error = %q; want substring %q", err.Error(), tt.errSubstr) + } + if !errors.Is(err, ErrRejected) { + t.Fatalf("expected ErrRejected, got %T: %v", err, err) + } + return + } + // For valid paths, the error should either be nil (socket + // exists) or a dial error (socket doesn't exist), but NOT + // a validation/rejection error. + if err != nil && errors.Is(err, ErrRejected) { + t.Fatalf("unexpected rejection error: %v", err) + } + }) + } +} + +func TestNewReverseUnixForwardingCallbackValidation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts UnixForwardingOptions + path string + wantErr bool + errSubstr string + }{ + { + name: "rejects relative path", + opts: UnixForwardingOptions{AllowAll: true}, + path: "relative.sock", + wantErr: true, + errSubstr: "must be absolute", + }, + { + name: "rejects path outside allowed dirs", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + path: "/var/run/test.sock", + wantErr: true, + errSubstr: "not in an allowed directory", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := newContext(nil) + defer cancel() + + cb := NewReverseUnixForwardingCallback(tt.opts) + _, err := cb(ctx, tt.path) + if tt.wantErr { + if err == nil { + t.Fatal("expected error but got nil") + } + if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) { + t.Fatalf("error = %q; want substring %q", err.Error(), tt.errSubstr) + } + if !errors.Is(err, ErrRejected) { + t.Fatalf("expected ErrRejected, got %T: %v", err, err) + } + return + } + if err != nil && errors.Is(err, ErrRejected) { + t.Fatalf("unexpected rejection error: %v", err) + } + }) + } +} + +func TestNewReverseUnixForwardingCallbackBindUnlink(t *testing.T) { + t.Parallel() + + dir := tempDirUnixSocket(t) + sockPath := filepath.Join(dir, "test.sock") + + // Create an existing socket. Keep the listener open so the socket + // file persists (Go's UnixListener.Close removes the file). + oldLn, err := net.Listen("unix", sockPath) + if err != nil { + t.Fatalf("failed to create socket: %v", err) + } + defer oldLn.Close() //nolint:errcheck + + // Without BindUnlink, listening should fail because socket exists. + cbNoUnlink := NewReverseUnixForwardingCallback(UnixForwardingOptions{ + AllowAll: true, + }) + _, err = cbNoUnlink(nil, sockPath) + if err == nil { + t.Fatal("expected listen to fail on existing socket without BindUnlink") + } + + // With BindUnlink, the old socket is removed and we can listen. + cbUnlink := NewReverseUnixForwardingCallback(UnixForwardingOptions{ + AllowAll: true, + BindUnlink: true, + }) + newLn, err := cbUnlink(nil, sockPath) + if err != nil { + t.Fatalf("expected listen to succeed with BindUnlink, got: %v", err) + } + _ = newLn.Close() +} + +func TestNewReverseUnixForwardingCallbackBindUnlinkSkipsNonSocket(t *testing.T) { + t.Parallel() + + dir := tempDirUnixSocket(t) + filePath := filepath.Join(dir, "regular.file") + + // Create a regular file at the path. + if err := os.WriteFile(filePath, []byte("data"), 0600); err != nil { + t.Fatalf("failed to create regular file: %v", err) + } + + // BindUnlink should NOT remove regular files. Listen should fail. + cb := NewReverseUnixForwardingCallback(UnixForwardingOptions{ + AllowAll: true, + BindUnlink: true, + }) + _, err := cb(nil, filePath) + if err == nil { + t.Fatal("expected listen to fail on regular file even with BindUnlink") + } + + // Regular file should still exist. + if _, err := os.Stat(filePath); err != nil { + t.Fatalf("regular file should not have been deleted: %v", err) + } +} + +func TestNewReverseUnixForwardingCallbackSocketPermissions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mask *os.FileMode + wantPerm os.FileMode + }{ + {name: "default mask 0177 gives mode 0600", mask: nil, wantPerm: 0600}, + {name: "custom mask 0117 gives mode 0660", mask: fileMode(0117), wantPerm: 0660}, + {name: "zero mask gives mode 0666", mask: fileMode(0), wantPerm: 0666}, + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Use a short /tmp path to stay under sun_path limits + // even when the test framework creates long temp paths. + dir, err := os.MkdirTemp("/tmp", "ssh-perm-") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + sockPath := filepath.Join(dir, fmt.Sprintf("p%d.s", i)) + + cb := NewReverseUnixForwardingCallback(UnixForwardingOptions{ + AllowAll: true, + BindMask: tt.mask, + }) + ln, err := cb(nil, sockPath) + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + defer ln.Close() //nolint:errcheck + + info, err := os.Stat(sockPath) + if err != nil { + t.Fatalf("failed to stat socket: %v", err) + } + perm := info.Mode().Perm() + if perm != tt.wantPerm { + t.Fatalf("socket permissions = %04o; want %04o", perm, tt.wantPerm) + } + }) + } +} + +func fileMode(m os.FileMode) *os.FileMode { return &m } + +func TestReverseUnixForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + remoteSocketPath := filepath.Join(tempDirUnixSocket(t), "remote.sock") + + var called int64 + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReverseUnixForwardingCallback: func(ctx Context, socketPath string) (net.Listener, error) { + atomic.AddInt64(&called, 1) + if socketPath != remoteSocketPath { + panic("unexpected socket path: " + socketPath) + } + return nil, ErrRejected + }, + }, nil) + defer cleanup() + + _, err := client.ListenUnix(remoteSocketPath) + if err == nil { + t.Fatalf("Expected error listening on %q but it succeeded", remoteSocketPath) + } + + if atomic.LoadInt64(&called) != 1 { + t.Fatalf("Expected callback to be called once but it was called %d times", called) + } +} diff --git a/tcpip.go b/tcpip.go index 335fda6..843704a 100644 --- a/tcpip.go +++ b/tcpip.go @@ -1,6 +1,7 @@ package ssh import ( + "context" "io" "log" "net" @@ -53,16 +54,7 @@ func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewCh } go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(ch, dconn) - }() - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(dconn, ch) - }() + bicopy(ctx, ch, dconn) } type remoteForwardRequest struct { @@ -117,8 +109,14 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go // TODO: log listen failure return false, []byte{} } + + // If the bind port was port 0, we need to use the actual port in the + // listener map. _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) destPort, _ := strconv.Atoi(destPortStr) + if reqPayload.BindPort == 0 { + addr = net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(destPort)) + } h.Lock() h.forwards[addr] = ln h.Unlock() @@ -155,16 +153,7 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go return } go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer c.Close() - io.Copy(ch, c) - }() - go func() { - defer ch.Close() - defer c.Close() - io.Copy(c, ch) - }() + bicopy(ctx, ch, c) }() } h.Lock() @@ -191,3 +180,43 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go return false, nil } } + +// bicopy copies all of the data between the two connections and will close them +// after one or both of them are done writing. If the context is canceled, both +// of the connections will be closed. +func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + defer func() { + _ = c1.Close() + _ = c2.Close() + }() + + var wg sync.WaitGroup + copyFunc := func(dst io.WriteCloser, src io.Reader) { + defer func() { + wg.Done() + // If one side of the copy fails, ensure the other one exits as + // well. + cancel() + }() + _, _ = io.Copy(dst, src) + } + + wg.Add(2) + go copyFunc(c1, c2) + go copyFunc(c2, c1) + + // Convert waitgroup to a channel so we can also wait on the context. + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-ctx.Done(): + case <-done: + } +} diff --git a/tcpip_test.go b/tcpip_test.go index 4ddf40e..f6a9fa9 100644 --- a/tcpip_test.go +++ b/tcpip_test.go @@ -2,19 +2,22 @@ package ssh import ( "bytes" + "context" "io" "net" "strconv" "strings" + "sync/atomic" "testing" + "time" gossh "golang.org/x/crypto/ssh" ) var sampleServerResponse = []byte("Hello world") -func sampleSocketServer() net.Listener { - l := newLocalListener() +func sampleTCPSocketServer() net.Listener { + l := newLocalTCPListener() go func() { conn, err := l.Accept() @@ -29,7 +32,7 @@ func sampleSocketServer() net.Listener { } func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { - l := sampleSocketServer() + l := sampleTCPSocketServer() _, client, cleanup := newTestSession(t, &Server{ Handler: func(s Session) {}, @@ -81,3 +84,92 @@ func TestLocalPortForwardingRespectsCallback(t *testing.T) { t.Fatalf("Expected permission error but got %#v", err) } } + +func TestReverseTCPForwardingWorks(t *testing.T) { + t.Parallel() + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReversePortForwardingCallback: func(ctx Context, bindHost string, bindPort uint32) bool { + if bindHost != "127.0.0.1" { + panic("unexpected bindHost: " + bindHost) + } + if bindPort != 0 { + panic("unexpected bindPort: " + strconv.Itoa(int(bindPort))) + } + return true + }, + }, nil) + defer cleanup() + + l, err := client.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on a random TCP port over SSH: %v", err) + } + defer l.Close() //nolint:errcheck + go func() { + conn, err := l.Accept() + if err != nil { + return + } + _, _ = conn.Write(sampleServerResponse) + _ = conn.Close() + }() + + // Dial the listener that should've been created by the server. + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } + + // Close the listener and make sure that the port is no longer in use. + err = l.Close() + if err != nil { + t.Fatalf("failed to close remote listener: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + var d net.Dialer + _, err = d.DialContext(ctx, "tcp", l.Addr().String()) + if err == nil { + t.Fatalf("expected error connecting to %v but it succeeded", l.Addr().String()) + } +} + +func TestReverseTCPForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + var called int64 + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReversePortForwardingCallback: func(ctx Context, bindHost string, bindPort uint32) bool { + atomic.AddInt64(&called, 1) + if bindHost != "127.0.0.1" { + panic("unexpected bindHost: " + bindHost) + } + if bindPort != 0 { + panic("unexpected bindPort: " + strconv.Itoa(int(bindPort))) + } + return false + }, + }, nil) + defer cleanup() + + _, err := client.Listen("tcp", "127.0.0.1:0") + if err == nil { + t.Fatalf("Expected error listening on random port but it succeeded") + } + + if atomic.LoadInt64(&called) != 1 { + t.Fatalf("Expected callback to be called once but it was called %d times", called) + } +}