From cc05bc5764459a95734a8091cfaa46fc3c99715b Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 23 Oct 2024 20:22:19 +0900 Subject: [PATCH 1/8] feat: add Unix forwarding server implementations Adds optional (disabled by default) implementations of local->remote and remote->local Unix forwarding through OpenSSH's protocol extensions: - streamlocal-forward@openssh.com - cancel-streamlocal-forward@openssh.com - forwarded-streamlocal@openssh.com - direct-streamlocal@openssh.com Adds tests for Unix forwarding, reverse Unix forwarding and reverse TCP forwarding. Co-authored-by: Samuel Corsi-House --- options_test.go | 2 +- server.go | 2 + server_test.go | 4 +- session_test.go | 19 +++- ssh.go | 20 ++++ streamlocal.go | 252 ++++++++++++++++++++++++++++++++++++++++++++ streamlocal_test.go | 206 ++++++++++++++++++++++++++++++++++++ tcpip.go | 69 ++++++++---- tcpip_test.go | 98 ++++++++++++++++- 9 files changed, 642 insertions(+), 30 deletions(-) create mode 100644 streamlocal.go create mode 100644 streamlocal_test.go diff --git a/options_test.go b/options_test.go index 23fca5ab..2992b6a0 100644 --- a/options_test.go +++ b/options_test.go @@ -49,7 +49,7 @@ func TestPasswordAuth(t *testing.T) { func TestPasswordAuthBadPass(t *testing.T) { t.Parallel() - l := newLocalListener() + l := newLocalTCPListener() srv := &Server{Handler: func(s Session) {}} srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { return false diff --git a/server.go b/server.go index 6e0eab4b..8824cdfa 100644 --- a/server.go +++ b/server.go @@ -47,6 +47,8 @@ type Server struct { ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil + LocalUnixForwardingCallback LocalUnixForwardingCallback // callback for allowing local unix forwarding (direct-streamlocal@openssh.com), denies all if nil + ReverseUnixForwardingCallback ReverseUnixForwardingCallback // callback for allowing reverse unix forwarding (streamlocal-forward@openssh.com), denies all if nil ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions diff --git a/server_test.go b/server_test.go index 11978e62..0dfd73c6 100644 --- a/server_test.go +++ b/server_test.go @@ -30,7 +30,7 @@ func TestAddHostKey(t *testing.T) { } func TestServerShutdown(t *testing.T) { - l := newLocalListener() + l := newLocalTCPListener() testBytes := []byte("Hello world\n") s := &Server{ Handler: func(s Session) { @@ -82,7 +82,7 @@ func TestServerShutdown(t *testing.T) { } func TestServerClose(t *testing.T) { - l := newLocalListener() + l := newLocalTCPListener() s := &Server{ Handler: func(s Session) { time.Sleep(5 * time.Second) diff --git a/session_test.go b/session_test.go index 4f6b5cad..3291fb5f 100644 --- a/session_test.go +++ b/session_test.go @@ -21,14 +21,25 @@ func (srv *Server) serveOnce(l net.Listener) error { return e } srv.ChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, - "direct-tcpip": DirectTCPIPHandler, + "session": DefaultSessionHandler, + "direct-tcpip": DirectTCPIPHandler, + "direct-streamlocal@openssh.com": DirectStreamLocalHandler, } + + forwardedTCPHandler := &ForwardedTCPHandler{} + forwardedUnixHandler := &ForwardedUnixHandler{} + srv.RequestHandlers = map[string]RequestHandler{ + "tcpip-forward": forwardedTCPHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardedTCPHandler.HandleSSHRequest, + "streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest, + "cancel-streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest, + } + srv.HandleConn(conn) return nil } -func newLocalListener() net.Listener { +func newLocalTCPListener() net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { @@ -65,7 +76,7 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g } func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - l := newLocalListener() + l := newLocalTCPListener() go srv.serveOnce(l) return newClientSession(t, l.Addr().String(), cfg) } diff --git a/ssh.go b/ssh.go index e2dd1610..38c29d6b 100644 --- a/ssh.go +++ b/ssh.go @@ -2,6 +2,7 @@ package ssh import ( "crypto/subtle" + "errors" "net" gossh "golang.org/x/crypto/ssh" @@ -29,6 +30,9 @@ const ( // DefaultHandler is the default Handler used by Serve. var DefaultHandler Handler +// ErrReject is returned by some callbacks to reject a request. +var ErrRejected = errors.New("ssh: rejected") + // Option is a functional option handler for Server. type Option func(*Server) error @@ -66,6 +70,22 @@ type LocalPortForwardingCallback func(ctx Context, destinationHost string, desti // ReversePortForwardingCallback is a hook for allowing reverse port forwarding type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool +// LocalUnixForwardingCallback is a hook for allowing unix forwarding +// (direct-streamlocal@openssh.com). Returning ErrRejected will reject the +// request. The returned net.Conn will be closed by the server when no longer +// needed. +// +// Use SimpleUnixLocalForwardingCallback for a basic implementation. +type LocalUnixForwardingCallback func(ctx Context, socketPath string) (net.Conn, error) + +// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding +// (streamlocal-forward@openssh.com). Returning ErrRejected will reject the +// request. The returned net.Listener will be closed by the server when no +// longer needed. +// +// Use SimpleUnixReverseForwardingCallback for a basic implementation. +type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error) + // ServerConfigCallback is a hook for creating custom default server configs type ServerConfigCallback func(ctx Context) *gossh.ServerConfig diff --git a/streamlocal.go b/streamlocal.go new file mode 100644 index 00000000..2daa1a21 --- /dev/null +++ b/streamlocal.go @@ -0,0 +1,252 @@ +package ssh + +import ( + "context" + "errors" + "fmt" + "io/fs" + "net" + "os" + "path/filepath" + "sync" + "syscall" + + gossh "golang.org/x/crypto/ssh" +) + +const ( + forwardedUnixChannelType = "forwarded-streamlocal@openssh.com" +) + +// directStreamLocalChannelData data struct as specified in OpenSSH's protocol +// extensions document, Section 2.4. +// https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL?annotate=HEAD +type directStreamLocalChannelData struct { + SocketPath string + + Reserved1 string + Reserved2 uint32 +} + +// DirectStreamLocalHandler provides Unix forwarding from client -> server. It +// can be enabled by adding it to the server's ChannelHandlers under +// `direct-streamlocal@openssh.com`. +// +// Unix socket support on Windows is not widely available, so this handler may +// not work on all Windows installations and is not tested on Windows. +func DirectStreamLocalHandler(srv *Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + var d directStreamLocalChannelData + err := gossh.Unmarshal(newChan.ExtraData(), &d) + if err != nil { + _ = newChan.Reject(gossh.ConnectionFailed, "error parsing direct-streamlocal data: "+err.Error()) + return + } + + if srv.LocalUnixForwardingCallback == nil { + _ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled") + return + } + dconn, err := srv.LocalUnixForwardingCallback(ctx, d.SocketPath) + if err != nil { + if errors.Is(err, ErrRejected) { + _ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled") + return + } + _ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %+v", d.SocketPath, err.Error())) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + _ = dconn.Close() + return + } + go gossh.DiscardRequests(reqs) + + bicopy(ctx, ch, dconn) +} + +// remoteUnixForwardRequest describes the extra data sent in a +// streamlocal-forward@openssh.com containing the socket path to bind to. +type remoteUnixForwardRequest struct { + SocketPath string +} + +// remoteUnixForwardChannelData describes the data sent as the payload in the new +// channel request when a Unix connection is accepted by the listener. +type remoteUnixForwardChannelData struct { + SocketPath string + Reserved uint32 +} + +// ForwardedUnixHandler can be enabled by creating a ForwardedUnixHandler and +// adding the HandleSSHRequest callback to the server's RequestHandlers under +// `streamlocal-forward@openssh.com` and +// `cancel-streamlocal-forward@openssh.com` +// +// Unix socket support on Windows is not widely available, so this handler may +// not work on all Windows installations and is not tested on Windows. +type ForwardedUnixHandler struct { + sync.Mutex + forwards map[string]net.Listener +} + +func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { + h.Lock() + if h.forwards == nil { + h.forwards = make(map[string]net.Listener) + } + h.Unlock() + conn, ok := ctx.Value(ContextKeyConn).(*gossh.ServerConn) + if !ok { + // TODO: log cast failure + return false, nil + } + + switch req.Type { + case "streamlocal-forward@openssh.com": + var reqPayload remoteUnixForwardRequest + err := gossh.Unmarshal(req.Payload, &reqPayload) + if err != nil { + // TODO: log parse failure + return false, nil + } + + if srv.ReverseUnixForwardingCallback == nil { + return false, []byte("unix forwarding is disabled") + } + + addr := reqPayload.SocketPath + h.Lock() + _, ok := h.forwards[addr] + h.Unlock() + if ok { + // TODO: log failure + return false, nil + } + + ln, err := srv.ReverseUnixForwardingCallback(ctx, addr) + if err != nil { + if errors.Is(err, ErrRejected) { + return false, []byte("unix forwarding is disabled") + } + // TODO: log unix listen failure + return false, nil + } + + // The listener needs to successfully start before it can be added to + // the map, so we don't have to worry about checking for an existing + // listener as you can't listen on the same socket twice. + // + // This is also what the TCP version of this code does. + h.Lock() + h.forwards[addr] = ln + h.Unlock() + + ctx, cancel := context.WithCancel(ctx) + go func() { + <-ctx.Done() + _ = ln.Close() + }() + go func() { + defer cancel() + + for { + c, err := ln.Accept() + if err != nil { + // closed below + break + } + payload := gossh.Marshal(&remoteUnixForwardChannelData{ + SocketPath: addr, + }) + + go func() { + ch, reqs, err := conn.OpenChannel(forwardedUnixChannelType, payload) + if err != nil { + _ = c.Close() + return + } + go gossh.DiscardRequests(reqs) + bicopy(ctx, ch, c) + }() + } + + h.Lock() + ln2, ok := h.forwards[addr] + if ok && ln2 == ln { + delete(h.forwards, addr) + } + h.Unlock() + _ = ln.Close() + }() + + return true, nil + + case "cancel-streamlocal-forward@openssh.com": + var reqPayload remoteUnixForwardRequest + err := gossh.Unmarshal(req.Payload, &reqPayload) + if err != nil { + // TODO: log parse failure + return false, nil + } + h.Lock() + ln, ok := h.forwards[reqPayload.SocketPath] + h.Unlock() + if ok { + _ = ln.Close() + } + return true, nil + + default: + return false, nil + } +} + +// unlink removes files and unlike os.Remove, directories are kept. +func unlink(path string) error { + // Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go + // for more details. + for { + err := syscall.Unlink(path) + if !errors.Is(err, syscall.EINTR) { + return err + } + } +} + +// SimpleUnixLocalForwardingCallback provides a basic implementation for +// LocalUnixForwardingCallback. It will simply dial the requested socket using +// a context-aware dialer. +func SimpleUnixLocalForwardingCallback(ctx Context, socketPath string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", socketPath) +} + +// SimpleUnixReverseForwardingCallback provides a basic implementation for +// ReverseUnixForwardingCallback. The parent directory will be created (with +// os.MkdirAll), and existing files with the same name will be removed. +func SimpleUnixReverseForwardingCallback(_ Context, socketPath string) (net.Listener, error) { + // Create socket parent dir if not exists. + parentDir := filepath.Dir(socketPath) + err := os.MkdirAll(parentDir, 0700) + if err != nil { + return nil, fmt.Errorf("failed to create parent directory %q for socket %q: %w", parentDir, socketPath, err) + } + + // Remove existing socket if it exists. We do not use os.Remove() here + // so that directories are kept. Note that it's possible that we will + // overwrite a regular file here. Both of these behaviors match OpenSSH, + // however, which is why we unlink. + err = unlink(socketPath) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return nil, fmt.Errorf("failed to remove existing file in socket path %q: %w", socketPath, err) + } + + ln, err := net.Listen("unix", socketPath) + if err != nil { + return nil, fmt.Errorf("failed to listen on unix socket %q: %w", socketPath, err) + } + + return ln, err +} diff --git a/streamlocal_test.go b/streamlocal_test.go new file mode 100644 index 00000000..41ae3c92 --- /dev/null +++ b/streamlocal_test.go @@ -0,0 +1,206 @@ +package ssh + +import ( + "bytes" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "strings" + "sync/atomic" + "testing" + + gossh "golang.org/x/crypto/ssh" +) + +// tempDirUnixSocket returns a temporary directory that can safely hold unix +// sockets. +// +// On all platforms other than darwin this just returns t.TempDir(). On darwin +// we manually make a temporary directory in /tmp because t.TempDir() returns a +// very long directory name, and the path length limit for Unix sockets on +// darwin is 104 characters. +func tempDirUnixSocket(t *testing.T) string { + t.Helper() + if runtime.GOOS == "darwin" { + testName := strings.ReplaceAll(t.Name(), "/", "_") + dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("gliderlabs-ssh-test-%s-", testName)) + if err != nil { + t.Fatalf("create temp dir for test: %v", err) + } + + t.Cleanup(func() { + err := os.RemoveAll(dir) + if err != nil { + t.Errorf("remove temp dir %s: %v", dir, err) + } + }) + return dir + } + + return t.TempDir() +} + +func newLocalUnixListener(t *testing.T) net.Listener { + path := filepath.Join(tempDirUnixSocket(t), "socket.sock") + l, err := net.Listen("unix", path) + if err != nil { + t.Fatalf("failed to listen on a unix socket %q: %v", path, err) + } + return l +} + +func sampleUnixSocketServer(t *testing.T) net.Listener { + l := newLocalUnixListener(t) + + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + return l +} + +func newTestSessionWithUnixForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { + l := sampleUnixSocketServer(t) + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + LocalUnixForwardingCallback: func(ctx Context, socketPath string) (net.Conn, error) { + if socketPath != l.Addr().String() { + panic("unexpected socket path: " + socketPath) + } + if !forwardingEnabled { + return nil, ErrRejected + } + return SimpleUnixLocalForwardingCallback(ctx, socketPath) + }, + }, nil) + + return l, client, func() { + cleanup() + l.Close() + } +} + +func TestLocalUnixForwardingWorks(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithUnixForwarding(t, true) + defer cleanup() + + conn, err := client.Dial("unix", l.Addr().String()) + if err != nil { + t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } +} + +func TestLocalUnixForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithUnixForwarding(t, false) + defer cleanup() + + _, err := client.Dial("unix", l.Addr().String()) + if err == nil { + t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String()) + } + if !strings.Contains(err.Error(), "unix forwarding is disabled") { + t.Fatalf("Expected permission error but got %#v", err) + } +} + +func TestReverseUnixForwardingWorks(t *testing.T) { + t.Parallel() + + remoteSocketPath := filepath.Join(tempDirUnixSocket(t), "remote.sock") + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReverseUnixForwardingCallback: func(ctx Context, socketPath string) (net.Listener, error) { + if socketPath != remoteSocketPath { + panic("unexpected socket path: " + socketPath) + } + return SimpleUnixReverseForwardingCallback(ctx, socketPath) + }, + }, nil) + defer cleanup() + + l, err := client.ListenUnix(remoteSocketPath) + if err != nil { + t.Fatalf("failed to listen on a unix socket over SSH %q: %v", remoteSocketPath, err) + } + defer l.Close() + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + // Dial the listener that should've been created by the server. + conn, err := net.Dial("unix", remoteSocketPath) + if err != nil { + t.Fatalf("Error connecting to %v: %v", remoteSocketPath, err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } + + // Close the listener and make sure that the Unix socket is gone. + err = l.Close() + if err != nil { + t.Fatalf("failed to close remote listener: %v", err) + } + _, err = os.Stat(remoteSocketPath) + if err == nil && !os.IsNotExist(err) { + t.Fatalf("expected remote socket to be gone but it still exists: %v", err) + } +} + +func TestReverseUnixForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + remoteSocketPath := filepath.Join(tempDirUnixSocket(t), "remote.sock") + + var called int64 + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReverseUnixForwardingCallback: func(ctx Context, socketPath string) (net.Listener, error) { + atomic.AddInt64(&called, 1) + if socketPath != remoteSocketPath { + panic("unexpected socket path: " + socketPath) + } + return nil, ErrRejected + }, + }, nil) + defer cleanup() + + _, err := client.ListenUnix(remoteSocketPath) + if err == nil { + t.Fatalf("Expected error listening on %q but it succeeded", remoteSocketPath) + } + + if atomic.LoadInt64(&called) != 1 { + t.Fatalf("Expected callback to be called once but it was called %d times", called) + } +} diff --git a/tcpip.go b/tcpip.go index 335fda65..843704ad 100644 --- a/tcpip.go +++ b/tcpip.go @@ -1,6 +1,7 @@ package ssh import ( + "context" "io" "log" "net" @@ -53,16 +54,7 @@ func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewCh } go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(ch, dconn) - }() - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(dconn, ch) - }() + bicopy(ctx, ch, dconn) } type remoteForwardRequest struct { @@ -117,8 +109,14 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go // TODO: log listen failure return false, []byte{} } + + // If the bind port was port 0, we need to use the actual port in the + // listener map. _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) destPort, _ := strconv.Atoi(destPortStr) + if reqPayload.BindPort == 0 { + addr = net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(destPort)) + } h.Lock() h.forwards[addr] = ln h.Unlock() @@ -155,16 +153,7 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go return } go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer c.Close() - io.Copy(ch, c) - }() - go func() { - defer ch.Close() - defer c.Close() - io.Copy(c, ch) - }() + bicopy(ctx, ch, c) }() } h.Lock() @@ -191,3 +180,43 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go return false, nil } } + +// bicopy copies all of the data between the two connections and will close them +// after one or both of them are done writing. If the context is canceled, both +// of the connections will be closed. +func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + defer func() { + _ = c1.Close() + _ = c2.Close() + }() + + var wg sync.WaitGroup + copyFunc := func(dst io.WriteCloser, src io.Reader) { + defer func() { + wg.Done() + // If one side of the copy fails, ensure the other one exits as + // well. + cancel() + }() + _, _ = io.Copy(dst, src) + } + + wg.Add(2) + go copyFunc(c1, c2) + go copyFunc(c2, c1) + + // Convert waitgroup to a channel so we can also wait on the context. + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-ctx.Done(): + case <-done: + } +} diff --git a/tcpip_test.go b/tcpip_test.go index 4ddf40e5..525ca2d7 100644 --- a/tcpip_test.go +++ b/tcpip_test.go @@ -2,19 +2,22 @@ package ssh import ( "bytes" + "context" "io" "net" "strconv" "strings" + "sync/atomic" "testing" + "time" gossh "golang.org/x/crypto/ssh" ) var sampleServerResponse = []byte("Hello world") -func sampleSocketServer() net.Listener { - l := newLocalListener() +func sampleTCPSocketServer() net.Listener { + l := newLocalTCPListener() go func() { conn, err := l.Accept() @@ -29,7 +32,7 @@ func sampleSocketServer() net.Listener { } func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { - l := sampleSocketServer() + l := sampleTCPSocketServer() _, client, cleanup := newTestSession(t, &Server{ Handler: func(s Session) {}, @@ -81,3 +84,92 @@ func TestLocalPortForwardingRespectsCallback(t *testing.T) { t.Fatalf("Expected permission error but got %#v", err) } } + +func TestReverseTCPForwardingWorks(t *testing.T) { + t.Parallel() + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReversePortForwardingCallback: func(ctx Context, bindHost string, bindPort uint32) bool { + if bindHost != "127.0.0.1" { + panic("unexpected bindHost: " + bindHost) + } + if bindPort != 0 { + panic("unexpected bindPort: " + strconv.Itoa(int(bindPort))) + } + return true + }, + }, nil) + defer cleanup() + + l, err := client.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on a random TCP port over SSH: %v", err) + } + defer l.Close() + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + // Dial the listener that should've been created by the server. + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } + + // Close the listener and make sure that the port is no longer in use. + err = l.Close() + if err != nil { + t.Fatalf("failed to close remote listener: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + var d net.Dialer + _, err = d.DialContext(ctx, "tcp", l.Addr().String()) + if err == nil { + t.Fatalf("expected error connecting to %v but it succeeded", l.Addr().String()) + } +} + +func TestReverseTCPForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + var called int64 + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReversePortForwardingCallback: func(ctx Context, bindHost string, bindPort uint32) bool { + atomic.AddInt64(&called, 1) + if bindHost != "127.0.0.1" { + panic("unexpected bindHost: " + bindHost) + } + if bindPort != 0 { + panic("unexpected bindPort: " + strconv.Itoa(int(bindPort))) + } + return false + }, + }, nil) + defer cleanup() + + _, err := client.Listen("tcp", "127.0.0.1:0") + if err == nil { + t.Fatalf("Expected error listening on random port but it succeeded") + } + + if atomic.LoadInt64(&called) != 1 { + t.Fatalf("Expected callback to be called once but it was called %d times", called) + } +} From cd9af1de91d32f6be1e474323d4641e73b54068f Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 16 Mar 2026 14:48:37 +0000 Subject: [PATCH 2/8] streamlocal: replace Simple callbacks with secure forwarding options Replace SimpleUnixLocalForwardingCallback and SimpleUnixReverseForwardingCallback with configurable, secure-by-default alternatives: - NewLocalUnixForwardingCallback(opts) validates socket paths before dialing - NewReverseUnixForwardingCallback(opts) validates paths, applies restrictive socket permissions (default mode 0600 matching OpenSSH), and does not create parent directories Add UnixForwardingOptions with: - AllowAll: permits any absolute path (for ACL-gated deployments) - AllowedDirectories/DeniedPrefixes: directory-level allow/deny lists - BindUnlink: opt-in socket replacement (only unlinks actual sockets) - BindMask: configurable umask (default 0177, matching OpenSSH) - PathValidator: hook for custom per-request validation Add UserSocketDirectories(homeDir, uid) helper returning common socket directories (/tmp, ~, /run/user/) for easy integration. Surface descriptive error messages to SSH clients: path validation errors are sent as 'administratively prohibited' with the specific reason instead of a generic 'unix forwarding is disabled'. Updates gliderlabs/ssh#196 --- ssh.go | 24 +- streamlocal.go | 240 +++++++++++++++--- streamlocal_test.go | 586 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 811 insertions(+), 39 deletions(-) diff --git a/ssh.go b/ssh.go index 38c29d6b..2f996c43 100644 --- a/ssh.go +++ b/ssh.go @@ -71,19 +71,27 @@ type LocalPortForwardingCallback func(ctx Context, destinationHost string, desti type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool // LocalUnixForwardingCallback is a hook for allowing unix forwarding -// (direct-streamlocal@openssh.com). Returning ErrRejected will reject the -// request. The returned net.Conn will be closed by the server when no longer -// needed. +// (direct-streamlocal@openssh.com). The callback receives the client-requested +// socket path and returns a connection to the target socket, or an error. // -// Use SimpleUnixLocalForwardingCallback for a basic implementation. +// Returning ErrRejected (or an error wrapping it) rejects the request with +// "administratively prohibited" and the error message is sent to the client. +// Any other error rejects with "connection failed." +// +// Use NewLocalUnixForwardingCallback to create a callback with built-in path +// validation and security controls. type LocalUnixForwardingCallback func(ctx Context, socketPath string) (net.Conn, error) // ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding -// (streamlocal-forward@openssh.com). Returning ErrRejected will reject the -// request. The returned net.Listener will be closed by the server when no -// longer needed. +// (streamlocal-forward@openssh.com). The callback receives the client-requested +// socket path and returns a listener bound to that path, or an error. +// +// Returning ErrRejected (or an error wrapping it) rejects the request with +// "administratively prohibited" and the error message is sent to the client. +// Any other error rejects the request silently. // -// Use SimpleUnixReverseForwardingCallback for a basic implementation. +// Use NewReverseUnixForwardingCallback to create a callback with built-in path +// validation, permission controls, and security defaults matching OpenSSH. type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error) // ServerConfigCallback is a hook for creating custom default server configs diff --git a/streamlocal.go b/streamlocal.go index 2daa1a21..c20278a9 100644 --- a/streamlocal.go +++ b/streamlocal.go @@ -8,12 +8,18 @@ import ( "net" "os" "path/filepath" + "strings" "sync" "syscall" gossh "golang.org/x/crypto/ssh" ) +// maxSunPathLen is the maximum length of a Unix domain socket path on the +// current platform, derived from the kernel's sockaddr_un.sun_path field. +// This is 108 on Linux and 104 on macOS/BSD. +var maxSunPathLen = len(syscall.RawSockaddrUnix{}.Path) + const ( forwardedUnixChannelType = "forwarded-streamlocal@openssh.com" ) @@ -49,10 +55,10 @@ func DirectStreamLocalHandler(srv *Server, _ *gossh.ServerConn, newChan gossh.Ne dconn, err := srv.LocalUnixForwardingCallback(ctx, d.SocketPath) if err != nil { if errors.Is(err, ErrRejected) { - _ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled") + _ = newChan.Reject(gossh.Prohibited, rejectedMessage(err)) return } - _ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %+v", d.SocketPath, err.Error())) + _ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %v", d.SocketPath, err)) return } @@ -128,7 +134,7 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g ln, err := srv.ReverseUnixForwardingCallback(ctx, addr) if err != nil { if errors.Is(err, ErrRejected) { - return false, []byte("unix forwarding is disabled") + return false, []byte(rejectedMessage(err)) } // TODO: log unix listen failure return false, nil @@ -215,38 +221,214 @@ func unlink(path string) error { } } -// SimpleUnixLocalForwardingCallback provides a basic implementation for -// LocalUnixForwardingCallback. It will simply dial the requested socket using -// a context-aware dialer. -func SimpleUnixLocalForwardingCallback(ctx Context, socketPath string) (net.Conn, error) { - var d net.Dialer - return d.DialContext(ctx, "unix", socketPath) +// rejectedMessage returns a user-facing rejection message. If err is a bare +// ErrRejected (no wrapping context), it returns the generic "unix forwarding +// is disabled" for backward compatibility. Wrapped errors (e.g. rejectionError) +// return their descriptive message. +func rejectedMessage(err error) string { + if err == ErrRejected { //nolint:errorlint // intentional identity check + return "unix forwarding is disabled" + } + return err.Error() } -// SimpleUnixReverseForwardingCallback provides a basic implementation for -// ReverseUnixForwardingCallback. The parent directory will be created (with -// os.MkdirAll), and existing files with the same name will be removed. -func SimpleUnixReverseForwardingCallback(_ Context, socketPath string) (net.Listener, error) { - // Create socket parent dir if not exists. - parentDir := filepath.Dir(socketPath) - err := os.MkdirAll(parentDir, 0700) - if err != nil { - return nil, fmt.Errorf("failed to create parent directory %q for socket %q: %w", parentDir, socketPath, err) +// rejectionError wraps ErrRejected with a descriptive reason for the SSH +// client. It satisfies errors.Is(err, ErrRejected) so that handlers send +// the rejection as "administratively prohibited" with the descriptive message. +type rejectionError struct { + reason string +} + +func (e *rejectionError) Error() string { return e.reason } +func (e *rejectionError) Unwrap() error { return ErrRejected } + +// UnixForwardingOptions configures the behavior of +// NewLocalUnixForwardingCallback and NewReverseUnixForwardingCallback. +type UnixForwardingOptions struct { + // AllowAll, if true, permits any absolute socket path without directory + // restrictions. AllowedDirectories and DeniedPrefixes are ignored when + // set. Basic sanitization (absolute path, length, filepath.Clean) is + // still applied. + AllowAll bool + + // AllowedDirectories is the list of directory prefixes under which + // socket paths are permitted. Paths are cleaned with filepath.Clean + // before prefix matching. Ignored when AllowAll is true. + // When AllowAll is false and AllowedDirectories is empty, all + // requests are denied. + AllowedDirectories []string + + // DeniedPrefixes is an optional denylist applied after the allowlist. + // Useful for excluding sensitive sub-paths within allowed directories + // (e.g. /run/user/1000/systemd/ within /run/user/1000/). + // Ignored when AllowAll is true. + DeniedPrefixes []string + + // BindUnlink controls whether an existing socket file is removed + // before binding (reverse forwarding only). Only socket-type files + // are removed; regular files are left in place and the listen will + // fail with EADDRINUSE. Default: false. + // Matches OpenSSH's StreamLocalBindUnlink (default: no). + BindUnlink bool + + // BindMask is the umask applied when creating listening sockets + // (reverse forwarding only). The resulting socket permission is + // 0666 &^ BindMask. If nil, defaults to 0177 (socket permission + // 0600, owner read/write only). + // Matches OpenSSH's StreamLocalBindMask. + BindMask *os.FileMode + + // PathValidator is an optional additional validation function called + // after built-in checks pass. Return an error wrapping ErrRejected + // (or a *rejectionError) for "administratively prohibited" semantics, + // or any other error for "connection failed." + PathValidator func(ctx Context, socketPath string) error +} + +// validateSocketPath checks that socketPath is safe according to opts. +// It returns the cleaned path on success. Returned errors wrap ErrRejected +// so that handlers report them as "administratively prohibited" with a +// descriptive message. +func validateSocketPath(socketPath string, opts UnixForwardingOptions) (string, error) { + if !filepath.IsAbs(socketPath) { + return "", &rejectionError{reason: "socket path must be absolute"} } - // Remove existing socket if it exists. We do not use os.Remove() here - // so that directories are kept. Note that it's possible that we will - // overwrite a regular file here. Both of these behaviors match OpenSSH, - // however, which is why we unlink. - err = unlink(socketPath) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return nil, fmt.Errorf("failed to remove existing file in socket path %q: %w", socketPath, err) + cleaned := filepath.Clean(socketPath) + + if strings.ContainsRune(cleaned, 0) { + return "", &rejectionError{reason: "socket path contains NUL byte"} } - ln, err := net.Listen("unix", socketPath) - if err != nil { - return nil, fmt.Errorf("failed to listen on unix socket %q: %w", socketPath, err) + if len(cleaned) >= maxSunPathLen { + return "", &rejectionError{ + reason: fmt.Sprintf("socket path too long (%d >= %d)", len(cleaned), maxSunPathLen), + } + } + + if !opts.AllowAll { + if len(opts.AllowedDirectories) == 0 { + return "", &rejectionError{ + reason: fmt.Sprintf("socket path %q is not in an allowed directory", cleaned), + } + } + + allowed := false + for _, dir := range opts.AllowedDirectories { + prefix := filepath.Clean(dir) + if !strings.HasSuffix(prefix, string(filepath.Separator)) { + prefix += string(filepath.Separator) + } + if strings.HasPrefix(cleaned, prefix) { + allowed = true + break + } + } + if !allowed { + return "", &rejectionError{ + reason: fmt.Sprintf("socket path %q is not in an allowed directory", cleaned), + } + } + + for _, denied := range opts.DeniedPrefixes { + prefix := filepath.Clean(denied) + if cleaned == prefix || strings.HasPrefix(cleaned, prefix+string(filepath.Separator)) { + return "", &rejectionError{ + reason: fmt.Sprintf("socket path %q is denied", cleaned), + } + } + } + } + + return cleaned, nil +} + +// NewLocalUnixForwardingCallback returns a LocalUnixForwardingCallback that +// validates socket paths against the provided options before dialing. +// Path validation errors are reported to the SSH client as +// "administratively prohibited" rejections with descriptive messages. +func NewLocalUnixForwardingCallback(opts UnixForwardingOptions) LocalUnixForwardingCallback { + return func(ctx Context, socketPath string) (net.Conn, error) { + cleaned, err := validateSocketPath(socketPath, opts) + if err != nil { + return nil, err + } + if opts.PathValidator != nil { + if err := opts.PathValidator(ctx, cleaned); err != nil { + return nil, err + } + } + + var d net.Dialer + return d.DialContext(ctx, "unix", cleaned) } +} + +// NewReverseUnixForwardingCallback returns a ReverseUnixForwardingCallback +// that validates socket paths against the provided options before listening. +// +// Unlike a bare net.Listen, this callback: +// - Validates the socket path against allow/deny lists +// - Does not create parent directories +// - Applies a restrictive permission mask (default 0177 / mode 0600) +// - Only unlinks existing socket files when BindUnlink is true (not +// regular files or directories) +func NewReverseUnixForwardingCallback(opts UnixForwardingOptions) ReverseUnixForwardingCallback { + return func(ctx Context, socketPath string) (net.Listener, error) { + cleaned, err := validateSocketPath(socketPath, opts) + if err != nil { + return nil, err + } + if opts.PathValidator != nil { + if err := opts.PathValidator(ctx, cleaned); err != nil { + return nil, err + } + } - return ln, err + if opts.BindUnlink { + // Only unlink if the existing file is a socket or does + // not exist. Regular files and directories are left in + // place so that net.Listen fails with EADDRINUSE rather + // than silently deleting user data. + if info, serr := os.Lstat(cleaned); serr == nil { + if info.Mode().Type() == os.ModeSocket { + if uerr := unlink(cleaned); uerr != nil && !errors.Is(uerr, fs.ErrNotExist) { + return nil, fmt.Errorf("failed to unlink existing socket %q: %w", cleaned, uerr) + } + } + } + } + + lc := &net.ListenConfig{} + ln, err := lc.Listen(ctx, "unix", cleaned) + if err != nil { + return nil, fmt.Errorf("failed to listen on unix socket %q: %w", cleaned, err) + } + + // Apply socket permission mask. Default 0177 (mode 0600), + // matching OpenSSH's StreamLocalBindMask. + mask := os.FileMode(0177) + if opts.BindMask != nil { + mask = *opts.BindMask + } + mode := os.FileMode(0666) &^ mask + if err := os.Chmod(cleaned, mode); err != nil { + _ = ln.Close() + return nil, fmt.Errorf("failed to set permissions on socket %q: %w", cleaned, err) + } + + return ln, nil + } +} + +// UserSocketDirectories returns common socket directory prefixes for a user, +// suitable for use as UnixForwardingOptions.AllowedDirectories. The returned +// list includes the user's home directory, /tmp, and the XDG runtime +// directory (/run/user/). +func UserSocketDirectories(homeDir string, uid string) []string { + return []string{ + homeDir, + "/tmp", + filepath.Join("/run/user", uid), + } } diff --git a/streamlocal_test.go b/streamlocal_test.go index 41ae3c92..5de2c97a 100644 --- a/streamlocal_test.go +++ b/streamlocal_test.go @@ -2,6 +2,7 @@ package ssh import ( "bytes" + "errors" "fmt" "io" "net" @@ -70,6 +71,7 @@ func sampleUnixSocketServer(t *testing.T) net.Listener { func newTestSessionWithUnixForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { l := sampleUnixSocketServer(t) + allowAllCb := NewLocalUnixForwardingCallback(UnixForwardingOptions{AllowAll: true}) _, client, cleanup := newTestSession(t, &Server{ Handler: func(s Session) {}, LocalUnixForwardingCallback: func(ctx Context, socketPath string) (net.Conn, error) { @@ -79,7 +81,7 @@ func newTestSessionWithUnixForwarding(t *testing.T, forwardingEnabled bool) (net if !forwardingEnabled { return nil, ErrRejected } - return SimpleUnixLocalForwardingCallback(ctx, socketPath) + return allowAllCb(ctx, socketPath) }, }, nil) @@ -128,13 +130,17 @@ func TestReverseUnixForwardingWorks(t *testing.T) { remoteSocketPath := filepath.Join(tempDirUnixSocket(t), "remote.sock") + allowAllCb := NewReverseUnixForwardingCallback(UnixForwardingOptions{ + AllowAll: true, + BindUnlink: true, + }) _, client, cleanup := newTestSession(t, &Server{ Handler: func(s Session) {}, ReverseUnixForwardingCallback: func(ctx Context, socketPath string) (net.Listener, error) { if socketPath != remoteSocketPath { panic("unexpected socket path: " + socketPath) } - return SimpleUnixReverseForwardingCallback(ctx, socketPath) + return allowAllCb(ctx, socketPath) }, }, nil) defer cleanup() @@ -177,6 +183,582 @@ func TestReverseUnixForwardingWorks(t *testing.T) { } } +func TestValidateSocketPath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + opts UnixForwardingOptions + wantErr bool + wantClean string // expected cleaned path on success + errSubstr string // substring expected in error message + wantType error // expected error type (ErrRejected) + }{ + // Basic validation (applies to all modes). + { + name: "absolute path accepted with AllowAll", + path: "/tmp/test.sock", + opts: UnixForwardingOptions{AllowAll: true}, + wantClean: "/tmp/test.sock", + }, + { + name: "relative path rejected", + path: "relative/path.sock", + opts: UnixForwardingOptions{AllowAll: true}, + wantErr: true, + errSubstr: "must be absolute", + wantType: ErrRejected, + }, + { + name: "dot-relative path rejected", + path: "./local.sock", + opts: UnixForwardingOptions{AllowAll: true}, + wantErr: true, + errSubstr: "must be absolute", + wantType: ErrRejected, + }, + { + name: "empty path rejected", + path: "", + opts: UnixForwardingOptions{AllowAll: true}, + wantErr: true, + errSubstr: "must be absolute", + wantType: ErrRejected, + }, + { + name: "path with dot-dot cleaned and accepted", + path: "/tmp/foo/../bar/test.sock", + opts: UnixForwardingOptions{AllowAll: true}, + wantClean: "/tmp/bar/test.sock", + }, + { + name: "path with double slashes cleaned", + path: "/tmp//foo//test.sock", + opts: UnixForwardingOptions{AllowAll: true}, + wantClean: "/tmp/foo/test.sock", + }, + { + name: "path with trailing slash cleaned", + path: "/tmp/test.sock/", + opts: UnixForwardingOptions{AllowAll: true}, + wantClean: "/tmp/test.sock", + }, + { + name: "path at sun_path limit rejected", + path: "/" + strings.Repeat("a", maxSunPathLen-1), + opts: UnixForwardingOptions{AllowAll: true}, + wantErr: true, + errSubstr: "too long", + wantType: ErrRejected, + }, + { + name: "path just under sun_path limit accepted", + path: "/" + strings.Repeat("a", maxSunPathLen-3), + opts: UnixForwardingOptions{AllowAll: true}, + wantClean: "/" + strings.Repeat("a", maxSunPathLen-3), + }, + + // AllowedDirectories tests. + { + name: "path in allowed directory accepted", + path: "/tmp/ssh/agent.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantClean: "/tmp/ssh/agent.sock", + }, + { + name: "path in second allowed directory accepted", + path: "/home/user/.ssh/agent.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp", "/home/user"}}, + wantClean: "/home/user/.ssh/agent.sock", + }, + { + name: "path outside allowed directories rejected", + path: "/var/run/docker.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp", "/home/user"}}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + { + name: "empty allowed directories rejects all", + path: "/tmp/test.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{}}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + { + name: "nil allowed directories rejects all", + path: "/tmp/test.sock", + opts: UnixForwardingOptions{}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + { + name: "allowed directory itself is not a valid socket path", + path: "/tmp", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + { + name: "allowed directory with trailing slash works", + path: "/tmp/test.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp/"}}, + wantClean: "/tmp/test.sock", + }, + { + name: "dot-dot traversal out of allowed directory rejected", + path: "/tmp/../var/run/docker.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + { + name: "dot-dot traversal staying in allowed directory accepted", + path: "/tmp/foo/../bar/test.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantClean: "/tmp/bar/test.sock", + }, + { + name: "allowed directory prefix attack rejected", + path: "/tmpevil/test.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantErr: true, + errSubstr: "not in an allowed directory", + wantType: ErrRejected, + }, + + // DeniedPrefixes tests. + { + name: "path in denied prefix rejected", + path: "/run/user/1000/systemd/private.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/run/user/1000"}, DeniedPrefixes: []string{"/run/user/1000/systemd"}}, + wantErr: true, + errSubstr: "is denied", + wantType: ErrRejected, + }, + { + name: "exact denied path rejected", + path: "/var/run/docker.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/var/run"}, DeniedPrefixes: []string{"/var/run/docker.sock"}}, + wantErr: true, + errSubstr: "is denied", + wantType: ErrRejected, + }, + { + name: "path not matching denied prefix accepted", + path: "/run/user/1000/podman/podman.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/run/user/1000"}, DeniedPrefixes: []string{"/run/user/1000/systemd"}}, + wantClean: "/run/user/1000/podman/podman.sock", + }, + { + name: "denied prefix does not match partial directory names", + path: "/run/user/1000/systemd-resolved/test.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/run/user/1000"}, DeniedPrefixes: []string{"/run/user/1000/systemd"}}, + wantClean: "/run/user/1000/systemd-resolved/test.sock", + }, + + // AllowAll overrides AllowedDirectories/DeniedPrefixes. + { + name: "AllowAll ignores AllowedDirectories", + path: "/var/run/docker.sock", + opts: UnixForwardingOptions{AllowAll: true, AllowedDirectories: []string{"/tmp"}}, + wantClean: "/var/run/docker.sock", + }, + { + name: "AllowAll ignores DeniedPrefixes", + path: "/run/user/1000/systemd/private.sock", + opts: UnixForwardingOptions{AllowAll: true, DeniedPrefixes: []string{"/run/user/1000/systemd"}}, + wantClean: "/run/user/1000/systemd/private.sock", + }, + + // Real-world socket paths. + { + name: "podman socket path", + path: "/run/user/1000/podman/podman.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/run/user/1000"}}, + wantClean: "/run/user/1000/podman/podman.sock", + }, + { + name: "gpg agent socket", + path: "/home/user/.gnupg/S.gpg-agent", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/home/user"}}, + wantClean: "/home/user/.gnupg/S.gpg-agent", + }, + { + name: "gpg agent socket systemd path", + path: "/run/user/1000/gnupg/S.gpg-agent", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/run/user/1000"}}, + wantClean: "/run/user/1000/gnupg/S.gpg-agent", + }, + { + name: "vscode remote socket", + path: "/tmp/code-d0fd2e91-ed82-46dd-8394-87ac5cde31c3.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantClean: "/tmp/code-d0fd2e91-ed82-46dd-8394-87ac5cde31c3.sock", + }, + { + name: "ssh agent socket", + path: "/tmp/ssh-XXXXXXXXXX/agent.12345", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + wantClean: "/tmp/ssh-XXXXXXXXXX/agent.12345", + }, + { + name: "docker socket denied even when /var/run allowed", + path: "/var/run/docker.sock", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/var/run"}, DeniedPrefixes: []string{"/var/run/docker.sock"}}, + wantErr: true, + errSubstr: "is denied", + wantType: ErrRejected, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + cleaned, err := validateSocketPath(tt.path, tt.opts) + if tt.wantErr { + if err == nil { + t.Fatalf("validateSocketPath(%q) = %q, nil; want error containing %q", tt.path, cleaned, tt.errSubstr) + } + if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) { + t.Fatalf("validateSocketPath(%q) error = %q; want substring %q", tt.path, err.Error(), tt.errSubstr) + } + if tt.wantType != nil && !errors.Is(err, tt.wantType) { + t.Fatalf("validateSocketPath(%q) error type = %T; want errors.Is(%v)", tt.path, err, tt.wantType) + } + return + } + if err != nil { + t.Fatalf("validateSocketPath(%q) = error %q; want %q", tt.path, err, tt.wantClean) + } + if cleaned != tt.wantClean { + t.Fatalf("validateSocketPath(%q) = %q; want %q", tt.path, cleaned, tt.wantClean) + } + }) + } +} + +func TestRejectedMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want string + }{ + { + name: "bare ErrRejected gives generic message", + err: ErrRejected, + want: "unix forwarding is disabled", + }, + { + name: "rejectionError gives descriptive message", + err: &rejectionError{reason: "socket path must be absolute"}, + want: "socket path must be absolute", + }, + { + name: "wrapped ErrRejected gives wrapper message", + err: fmt.Errorf("custom reason: %w", ErrRejected), + want: "custom reason: ssh: rejected", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := rejectedMessage(tt.err) + if got != tt.want { + t.Fatalf("rejectedMessage(%v) = %q; want %q", tt.err, got, tt.want) + } + }) + } +} + +func TestUserSocketDirectories(t *testing.T) { + t.Parallel() + + dirs := UserSocketDirectories("/home/testuser", "1000") + want := []string{"/home/testuser", "/tmp", "/run/user/1000"} + + if len(dirs) != len(want) { + t.Fatalf("UserSocketDirectories returned %d dirs; want %d", len(dirs), len(want)) + } + for i, d := range dirs { + if d != want[i] { + t.Fatalf("UserSocketDirectories()[%d] = %q; want %q", i, d, want[i]) + } + } +} + +func TestNewLocalUnixForwardingCallbackValidation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts UnixForwardingOptions + path string + wantErr bool + errSubstr string + }{ + { + name: "AllowAll accepts any absolute path", + opts: UnixForwardingOptions{AllowAll: true}, + path: "/var/run/docker.sock", + }, + { + name: "AllowAll rejects relative path", + opts: UnixForwardingOptions{AllowAll: true}, + path: "relative.sock", + wantErr: true, + errSubstr: "must be absolute", + }, + { + name: "restricted rejects path outside allowed dirs", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + path: "/var/run/docker.sock", + wantErr: true, + errSubstr: "not in an allowed directory", + }, + { + name: "restricted accepts path in allowed dir", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + path: "/tmp/test.sock", + }, + { + name: "PathValidator is called and can reject", + opts: UnixForwardingOptions{ + AllowAll: true, + PathValidator: func(_ Context, _ string) error { + return &rejectionError{reason: "custom validator rejected"} + }, + }, + path: "/tmp/test.sock", + wantErr: true, + errSubstr: "custom validator rejected", + }, + { + name: "PathValidator receives cleaned path", + opts: UnixForwardingOptions{ + AllowAll: true, + PathValidator: func(_ Context, path string) error { + if path != "/tmp/test.sock" { + return fmt.Errorf("expected /tmp/test.sock, got %s", path) + } + return nil + }, + }, + path: "/tmp/foo/../test.sock", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := newContext(nil) + defer cancel() + + cb := NewLocalUnixForwardingCallback(tt.opts) + // The callback tries to dial the socket, which will fail for + // non-existent paths. We only care about the validation errors + // (which wrap ErrRejected), not dial errors. + _, err := cb(ctx, tt.path) + if tt.wantErr { + if err == nil { + t.Fatal("expected error but got nil") + } + if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) { + t.Fatalf("error = %q; want substring %q", err.Error(), tt.errSubstr) + } + if !errors.Is(err, ErrRejected) { + t.Fatalf("expected ErrRejected, got %T: %v", err, err) + } + return + } + // For valid paths, the error should either be nil (socket + // exists) or a dial error (socket doesn't exist), but NOT + // a validation/rejection error. + if err != nil && errors.Is(err, ErrRejected) { + t.Fatalf("unexpected rejection error: %v", err) + } + }) + } +} + +func TestNewReverseUnixForwardingCallbackValidation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts UnixForwardingOptions + path string + wantErr bool + errSubstr string + }{ + { + name: "rejects relative path", + opts: UnixForwardingOptions{AllowAll: true}, + path: "relative.sock", + wantErr: true, + errSubstr: "must be absolute", + }, + { + name: "rejects path outside allowed dirs", + opts: UnixForwardingOptions{AllowedDirectories: []string{"/tmp"}}, + path: "/var/run/test.sock", + wantErr: true, + errSubstr: "not in an allowed directory", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := newContext(nil) + defer cancel() + + cb := NewReverseUnixForwardingCallback(tt.opts) + _, err := cb(ctx, tt.path) + if tt.wantErr { + if err == nil { + t.Fatal("expected error but got nil") + } + if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) { + t.Fatalf("error = %q; want substring %q", err.Error(), tt.errSubstr) + } + if !errors.Is(err, ErrRejected) { + t.Fatalf("expected ErrRejected, got %T: %v", err, err) + } + return + } + if err != nil && errors.Is(err, ErrRejected) { + t.Fatalf("unexpected rejection error: %v", err) + } + }) + } +} + +func TestNewReverseUnixForwardingCallbackBindUnlink(t *testing.T) { + t.Parallel() + + dir := tempDirUnixSocket(t) + sockPath := filepath.Join(dir, "test.sock") + + // Create an existing socket. Keep the listener open so the socket + // file persists (Go's UnixListener.Close removes the file). + oldLn, err := net.Listen("unix", sockPath) + if err != nil { + t.Fatalf("failed to create socket: %v", err) + } + defer oldLn.Close() //nolint:errcheck + + // Without BindUnlink, listening should fail because socket exists. + cbNoUnlink := NewReverseUnixForwardingCallback(UnixForwardingOptions{ + AllowAll: true, + }) + _, err = cbNoUnlink(nil, sockPath) + if err == nil { + t.Fatal("expected listen to fail on existing socket without BindUnlink") + } + + // With BindUnlink, the old socket is removed and we can listen. + cbUnlink := NewReverseUnixForwardingCallback(UnixForwardingOptions{ + AllowAll: true, + BindUnlink: true, + }) + newLn, err := cbUnlink(nil, sockPath) + if err != nil { + t.Fatalf("expected listen to succeed with BindUnlink, got: %v", err) + } + _ = newLn.Close() +} + +func TestNewReverseUnixForwardingCallbackBindUnlinkSkipsNonSocket(t *testing.T) { + t.Parallel() + + dir := tempDirUnixSocket(t) + filePath := filepath.Join(dir, "regular.file") + + // Create a regular file at the path. + if err := os.WriteFile(filePath, []byte("data"), 0600); err != nil { + t.Fatalf("failed to create regular file: %v", err) + } + + // BindUnlink should NOT remove regular files. Listen should fail. + cb := NewReverseUnixForwardingCallback(UnixForwardingOptions{ + AllowAll: true, + BindUnlink: true, + }) + _, err := cb(nil, filePath) + if err == nil { + t.Fatal("expected listen to fail on regular file even with BindUnlink") + } + + // Regular file should still exist. + if _, err := os.Stat(filePath); err != nil { + t.Fatalf("regular file should not have been deleted: %v", err) + } +} + +func TestNewReverseUnixForwardingCallbackSocketPermissions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mask *os.FileMode + wantPerm os.FileMode + }{ + {name: "default mask 0177 gives mode 0600", mask: nil, wantPerm: 0600}, + {name: "custom mask 0117 gives mode 0660", mask: fileMode(0117), wantPerm: 0660}, + {name: "zero mask gives mode 0666", mask: fileMode(0), wantPerm: 0666}, + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Use a short /tmp path to stay under sun_path limits + // even when the test framework creates long temp paths. + dir, err := os.MkdirTemp("/tmp", "ssh-perm-") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + sockPath := filepath.Join(dir, fmt.Sprintf("p%d.s", i)) + + cb := NewReverseUnixForwardingCallback(UnixForwardingOptions{ + AllowAll: true, + BindMask: tt.mask, + }) + ln, err := cb(nil, sockPath) + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + defer ln.Close() //nolint:errcheck + + info, err := os.Stat(sockPath) + if err != nil { + t.Fatalf("failed to stat socket: %v", err) + } + perm := info.Mode().Perm() + if perm != tt.wantPerm { + t.Fatalf("socket permissions = %04o; want %04o", perm, tt.wantPerm) + } + }) + } +} + +func fileMode(m os.FileMode) *os.FileMode { return &m } + func TestReverseUnixForwardingRespectsCallback(t *testing.T) { t.Parallel() From 3306199eb88e0a4965818a3f466046cbb334cb14 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 13 Mar 2026 08:19:18 +0000 Subject: [PATCH 3/8] streamlocal: fix wire format for forwarded channel data The Reserved field in remoteUnixForwardChannelData is uint32 but the OpenSSH PROTOCOL spec (Section 2.4) and x/crypto/ssh forwardedStreamLocalPayload both define it as string. The current code works by coincidence because uint32(0) and string("") produce identical wire bytes, but this is a latent protocol bug. Also fix server_test.go rebase artifact: newLocalListener was renamed to newLocalTCPListener by the unix forwarding commit. Spec: https://github.com/openssh/openssh-portable/blob/master/PROTOCOL Ref: https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go Updates gliderlabs/ssh#196 --- server_test.go | 2 +- streamlocal.go | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/server_test.go b/server_test.go index 0dfd73c6..e60e7035 100644 --- a/server_test.go +++ b/server_test.go @@ -128,7 +128,7 @@ func TestServerClose(t *testing.T) { } func TestServerHandshakeTimeout(t *testing.T) { - l := newLocalListener() + l := newLocalTCPListener() s := &Server{ HandshakeTimeout: time.Millisecond, diff --git a/streamlocal.go b/streamlocal.go index c20278a9..95bf06aa 100644 --- a/streamlocal.go +++ b/streamlocal.go @@ -80,9 +80,15 @@ type remoteUnixForwardRequest struct { // remoteUnixForwardChannelData describes the data sent as the payload in the new // channel request when a Unix connection is accepted by the listener. +// +// See OpenSSH PROTOCOL, Section 2.4 "forwarded-streamlocal@openssh.com": +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL +// +// See also the client-side struct in x/crypto/ssh (forwardedStreamLocalPayload): +// https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go type remoteUnixForwardChannelData struct { SocketPath string - Reserved uint32 + Reserved string } // ForwardedUnixHandler can be enabled by creating a ForwardedUnixHandler and From e74718641af5c36ae9dabc78874522c94912a885 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 13 Mar 2026 08:20:03 +0000 Subject: [PATCH 4/8] streamlocal: address upstream review comments Address review feedback from gliderlabs/ssh#196: - Return true for duplicate forward requests instead of false. Returning false causes clients with ExitOnForwardFailure=yes to disconnect, which diverges from OpenSSH behavior. (mafredri, ge9) Ref: https://github.com/gliderlabs/ssh/pull/196#discussion_r1812621663 Ref: https://github.com/coder/coder/blob/b828412edd913bef6665cf8a0b2ca7ac93334012/agent/agentssh/forward.go#L76-L91 - Use context-aware net.ListenConfig in SimpleUnixReverseForwardingCallback instead of bare net.Listen, matching the dialer pattern used in SimpleUnixLocalForwardingCallback. (mafredri) Ref: https://github.com/gliderlabs/ssh/pull/196#discussion_r1812641672 Updates gliderlabs/ssh#196 --- streamlocal.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/streamlocal.go b/streamlocal.go index 95bf06aa..9b1bfcc7 100644 --- a/streamlocal.go +++ b/streamlocal.go @@ -133,8 +133,12 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g _, ok := h.forwards[addr] h.Unlock() if ok { - // TODO: log failure - return false, nil + // In cases where ExitOnForwardFailure=yes is set, returning + // false here will cause the connection to be closed. To avoid + // this, and to match OpenSSH behavior, we silently ignore + // the second forward request. + // TODO: log duplicate forward + return true, nil } ln, err := srv.ReverseUnixForwardingCallback(ctx, addr) From 3588f61f9b3ead679e8c09b0c7ca32e946d844e9 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 13 Mar 2026 08:20:59 +0000 Subject: [PATCH 5/8] streamlocal: fix concurrency issues in ForwardedUnixHandler Several concurrency issues exist in the shared ForwardedUnixHandler: - Key the forwards map by (sessionID, addr) instead of addr alone. A single handler instance is shared across SSH connections, so addr-only keys allow cross-session collisions and let any session cancel another session's forwards. This matches the approach used in coder/coder's production fork of this code. - Pass the parent context (connection-scoped) to bicopy instead of the derived context. The derived context is cancelled when the accept loop exits, which prematurely tears down active forwarded connections that are still transferring data. - Delete the map entry atomically in the cancel handler instead of relying on the accept-loop goroutine to clean up asynchronously. This prevents a timing window where the stale entry would reject legitimate re-forward requests. Updates gliderlabs/ssh#196 --- streamlocal.go | 54 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/streamlocal.go b/streamlocal.go index 9b1bfcc7..cf9208b7 100644 --- a/streamlocal.go +++ b/streamlocal.go @@ -91,6 +91,14 @@ type remoteUnixForwardChannelData struct { Reserved string } +// forwardKey identifies a forwarded Unix socket scoped to a specific +// SSH session, preventing cross-session collisions and ensuring that +// one session cannot cancel another session's forward. +type forwardKey struct { + sessionID string + addr string +} + // ForwardedUnixHandler can be enabled by creating a ForwardedUnixHandler and // adding the HandleSSHRequest callback to the server's RequestHandlers under // `streamlocal-forward@openssh.com` and @@ -100,13 +108,13 @@ type remoteUnixForwardChannelData struct { // not work on all Windows installations and is not tested on Windows. type ForwardedUnixHandler struct { sync.Mutex - forwards map[string]net.Listener + forwards map[forwardKey]net.Listener } func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { h.Lock() if h.forwards == nil { - h.forwards = make(map[string]net.Listener) + h.forwards = make(map[forwardKey]net.Listener) } h.Unlock() conn, ok := ctx.Value(ContextKeyConn).(*gossh.ServerConn) @@ -129,10 +137,17 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g } addr := reqPayload.SocketPath + key := forwardKey{ + sessionID: ctx.SessionID(), + addr: addr, + } + + // Use a nil sentinel to claim the key while the callback runs, + // preventing a concurrent request from racing past the check. h.Lock() - _, ok := h.forwards[addr] - h.Unlock() + _, ok := h.forwards[key] if ok { + h.Unlock() // In cases where ExitOnForwardFailure=yes is set, returning // false here will cause the connection to be closed. To avoid // this, and to match OpenSSH behavior, we silently ignore @@ -140,9 +155,14 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g // TODO: log duplicate forward return true, nil } + h.forwards[key] = nil // placeholder; claimed + h.Unlock() ln, err := srv.ReverseUnixForwardingCallback(ctx, addr) if err != nil { + h.Lock() + delete(h.forwards, key) + h.Unlock() if errors.Is(err, ErrRejected) { return false, []byte(rejectedMessage(err)) } @@ -150,15 +170,14 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g return false, nil } - // The listener needs to successfully start before it can be added to - // the map, so we don't have to worry about checking for an existing - // listener as you can't listen on the same socket twice. - // - // This is also what the TCP version of this code does. h.Lock() - h.forwards[addr] = ln + h.forwards[key] = ln h.Unlock() + // Use the connection-scoped context for bicopy so active data + // transfers survive listener shutdown. The derived context is + // only used for the accept loop lifecycle. + connCtx := ctx ctx, cancel := context.WithCancel(ctx) go func() { <-ctx.Done() @@ -184,14 +203,14 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g return } go gossh.DiscardRequests(reqs) - bicopy(ctx, ch, c) + bicopy(connCtx, ch, c) }() } h.Lock() - ln2, ok := h.forwards[addr] + ln2, ok := h.forwards[key] if ok && ln2 == ln { - delete(h.forwards, addr) + delete(h.forwards, key) } h.Unlock() _ = ln.Close() @@ -206,8 +225,15 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g // TODO: log parse failure return false, nil } + key := forwardKey{ + sessionID: ctx.SessionID(), + addr: reqPayload.SocketPath, + } h.Lock() - ln, ok := h.forwards[reqPayload.SocketPath] + ln, ok := h.forwards[key] + if ok { + delete(h.forwards, key) + } h.Unlock() if ok { _ = ln.Close() From 05598c6af8d331ee4ad2948ed9121de535be8c77 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 13 Mar 2026 08:21:47 +0000 Subject: [PATCH 6/8] streamlocal: fix minor code quality issues - Fix doc comment for ErrRejected: the comment said "ErrReject" but the variable is "ErrRejected". Also enumerate which callbacks honor it. - Use %v on error instead of %+v on err.Error() string in reject message. The %+v verb on a string is identical to %s; the stack trace behavior of %+v only works on the error value itself. - Fix incorrect os.Stat check in TestReverseUnixForwardingWorks. The condition "err == nil && !os.IsNotExist(err)" has a redundant second clause since os.IsNotExist(nil) is always false. Simplify to "err == nil". Updates gliderlabs/ssh#196 --- ssh.go | 5 ++++- streamlocal_test.go | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ssh.go b/ssh.go index 2f996c43..cc8a4a6d 100644 --- a/ssh.go +++ b/ssh.go @@ -30,7 +30,10 @@ const ( // DefaultHandler is the default Handler used by Serve. var DefaultHandler Handler -// ErrReject is returned by some callbacks to reject a request. +// ErrRejected may be returned by LocalUnixForwardingCallback or +// ReverseUnixForwardingCallback to reject a forwarding request. When +// returned, the server replies with "prohibited" rather than +// "connection failed." var ErrRejected = errors.New("ssh: rejected") // Option is a functional option handler for Server. diff --git a/streamlocal_test.go b/streamlocal_test.go index 5de2c97a..a43a309d 100644 --- a/streamlocal_test.go +++ b/streamlocal_test.go @@ -178,8 +178,8 @@ func TestReverseUnixForwardingWorks(t *testing.T) { t.Fatalf("failed to close remote listener: %v", err) } _, err = os.Stat(remoteSocketPath) - if err == nil && !os.IsNotExist(err) { - t.Fatalf("expected remote socket to be gone but it still exists: %v", err) + if err == nil { + t.Fatal("expected remote socket to be removed after close") } } From 92e9c68e7e4ef471ac44c91d27ca3f3b56ade6bc Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 13 Mar 2026 08:26:19 +0000 Subject: [PATCH 7/8] all: fix errcheck findings in new test code Silence golangci-lint errcheck warnings for unchecked return values in the new streamlocal and tcpip forwarding tests. Updates gliderlabs/ssh#196 --- streamlocal_test.go | 12 ++++++------ tcpip_test.go | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/streamlocal_test.go b/streamlocal_test.go index a43a309d..b04ba51e 100644 --- a/streamlocal_test.go +++ b/streamlocal_test.go @@ -61,8 +61,8 @@ func sampleUnixSocketServer(t *testing.T) net.Listener { if err != nil { return } - conn.Write(sampleServerResponse) - conn.Close() + _, _ = conn.Write(sampleServerResponse) + _ = conn.Close() }() return l @@ -87,7 +87,7 @@ func newTestSessionWithUnixForwarding(t *testing.T, forwardingEnabled bool) (net return l, client, func() { cleanup() - l.Close() + _ = l.Close() } } @@ -149,14 +149,14 @@ func TestReverseUnixForwardingWorks(t *testing.T) { if err != nil { t.Fatalf("failed to listen on a unix socket over SSH %q: %v", remoteSocketPath, err) } - defer l.Close() + defer l.Close() //nolint:errcheck go func() { conn, err := l.Accept() if err != nil { return } - conn.Write(sampleServerResponse) - conn.Close() + _, _ = conn.Write(sampleServerResponse) + _ = conn.Close() }() // Dial the listener that should've been created by the server. diff --git a/tcpip_test.go b/tcpip_test.go index 525ca2d7..f6a9fa9f 100644 --- a/tcpip_test.go +++ b/tcpip_test.go @@ -106,14 +106,14 @@ func TestReverseTCPForwardingWorks(t *testing.T) { if err != nil { t.Fatalf("failed to listen on a random TCP port over SSH: %v", err) } - defer l.Close() + defer l.Close() //nolint:errcheck go func() { conn, err := l.Accept() if err != nil { return } - conn.Write(sampleServerResponse) - conn.Close() + _, _ = conn.Write(sampleServerResponse) + _ = conn.Close() }() // Dial the listener that should've been created by the server. From 56325c30bd866754b0747da0824dacded9109dd3 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 13 Mar 2026 08:26:25 +0000 Subject: [PATCH 8/8] all: apply go modernize suggestions Use maps.Copy instead of manual for-range copy loops in ensureHandlers, as suggested by golang.org/x/tools/go/analysis/passes/modernize. --- server.go | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/server.go b/server.go index 8824cdfa..e2290b8a 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package ssh import ( "context" "errors" + "maps" "net" "sync" "time" @@ -100,21 +101,15 @@ func (srv *Server) ensureHandlers() { if srv.RequestHandlers == nil { srv.RequestHandlers = map[string]RequestHandler{} - for k, v := range DefaultRequestHandlers { - srv.RequestHandlers[k] = v - } + maps.Copy(srv.RequestHandlers, DefaultRequestHandlers) } if srv.ChannelHandlers == nil { srv.ChannelHandlers = map[string]ChannelHandler{} - for k, v := range DefaultChannelHandlers { - srv.ChannelHandlers[k] = v - } + maps.Copy(srv.ChannelHandlers, DefaultChannelHandlers) } if srv.SubsystemHandlers == nil { srv.SubsystemHandlers = map[string]SubsystemHandler{} - for k, v := range DefaultSubsystemHandlers { - srv.SubsystemHandlers[k] = v - } + maps.Copy(srv.SubsystemHandlers, DefaultSubsystemHandlers) } }