Skip to content
Merged
2 changes: 1 addition & 1 deletion options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestPasswordAuth(t *testing.T) {

func TestPasswordAuthBadPass(t *testing.T) {
t.Parallel()
l := newLocalListener()
l := newLocalTCPListener()
srv := &Server{Handler: func(s Session) {}}
srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
return false
Expand Down
15 changes: 6 additions & 9 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ssh
import (
"context"
"errors"
"maps"
"net"
"sync"
"time"
Expand Down Expand Up @@ -47,6 +48,8 @@ type Server struct {
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
LocalUnixForwardingCallback LocalUnixForwardingCallback // callback for allowing local unix forwarding (direct-streamlocal@openssh.com), denies all if nil
ReverseUnixForwardingCallback ReverseUnixForwardingCallback // callback for allowing reverse unix forwarding (streamlocal-forward@openssh.com), denies all if nil
ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options
SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions

Expand Down Expand Up @@ -98,21 +101,15 @@ func (srv *Server) ensureHandlers() {

if srv.RequestHandlers == nil {
srv.RequestHandlers = map[string]RequestHandler{}
for k, v := range DefaultRequestHandlers {
srv.RequestHandlers[k] = v
}
maps.Copy(srv.RequestHandlers, DefaultRequestHandlers)
}
if srv.ChannelHandlers == nil {
srv.ChannelHandlers = map[string]ChannelHandler{}
for k, v := range DefaultChannelHandlers {
srv.ChannelHandlers[k] = v
}
maps.Copy(srv.ChannelHandlers, DefaultChannelHandlers)
}
if srv.SubsystemHandlers == nil {
srv.SubsystemHandlers = map[string]SubsystemHandler{}
for k, v := range DefaultSubsystemHandlers {
srv.SubsystemHandlers[k] = v
}
maps.Copy(srv.SubsystemHandlers, DefaultSubsystemHandlers)
}
}

Expand Down
6 changes: 3 additions & 3 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestAddHostKey(t *testing.T) {
}

func TestServerShutdown(t *testing.T) {
l := newLocalListener()
l := newLocalTCPListener()
testBytes := []byte("Hello world\n")
s := &Server{
Handler: func(s Session) {
Expand Down Expand Up @@ -82,7 +82,7 @@ func TestServerShutdown(t *testing.T) {
}

func TestServerClose(t *testing.T) {
l := newLocalListener()
l := newLocalTCPListener()
s := &Server{
Handler: func(s Session) {
time.Sleep(5 * time.Second)
Expand Down Expand Up @@ -128,7 +128,7 @@ func TestServerClose(t *testing.T) {
}

func TestServerHandshakeTimeout(t *testing.T) {
l := newLocalListener()
l := newLocalTCPListener()

s := &Server{
HandshakeTimeout: time.Millisecond,
Expand Down
19 changes: 15 additions & 4 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,25 @@ func (srv *Server) serveOnce(l net.Listener) error {
return e
}
srv.ChannelHandlers = map[string]ChannelHandler{
"session": DefaultSessionHandler,
"direct-tcpip": DirectTCPIPHandler,
"session": DefaultSessionHandler,
"direct-tcpip": DirectTCPIPHandler,
"direct-streamlocal@openssh.com": DirectStreamLocalHandler,
}

forwardedTCPHandler := &ForwardedTCPHandler{}
forwardedUnixHandler := &ForwardedUnixHandler{}
srv.RequestHandlers = map[string]RequestHandler{
"tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
"cancel-tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
"streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest,
"cancel-streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest,
}

srv.HandleConn(conn)
return nil
}

func newLocalListener() net.Listener {
func newLocalTCPListener() net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
Expand Down Expand Up @@ -65,7 +76,7 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g
}

func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
l := newLocalListener()
l := newLocalTCPListener()
go srv.serveOnce(l)
return newClientSession(t, l.Addr().String(), cfg)
}
Expand Down
31 changes: 31 additions & 0 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ssh

import (
"crypto/subtle"
"errors"
"net"

gossh "golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -29,6 +30,12 @@ const (
// DefaultHandler is the default Handler used by Serve.
var DefaultHandler Handler

// ErrRejected may be returned by LocalUnixForwardingCallback or
// ReverseUnixForwardingCallback to reject a forwarding request. When
// returned, the server replies with "prohibited" rather than
// "connection failed."
var ErrRejected = errors.New("ssh: rejected")

// Option is a functional option handler for Server.
type Option func(*Server) error

Expand Down Expand Up @@ -66,6 +73,30 @@ type LocalPortForwardingCallback func(ctx Context, destinationHost string, desti
// ReversePortForwardingCallback is a hook for allowing reverse port forwarding
type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool

// LocalUnixForwardingCallback is a hook for allowing unix forwarding
// (direct-streamlocal@openssh.com). The callback receives the client-requested
// socket path and returns a connection to the target socket, or an error.
//
// Returning ErrRejected (or an error wrapping it) rejects the request with
// "administratively prohibited" and the error message is sent to the client.
// Any other error rejects with "connection failed."
//
// Use NewLocalUnixForwardingCallback to create a callback with built-in path
// validation and security controls.
type LocalUnixForwardingCallback func(ctx Context, socketPath string) (net.Conn, error)

// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding
// (streamlocal-forward@openssh.com). The callback receives the client-requested
// socket path and returns a listener bound to that path, or an error.
//
// Returning ErrRejected (or an error wrapping it) rejects the request with
// "administratively prohibited" and the error message is sent to the client.
// Any other error rejects the request silently.
//
// Use NewReverseUnixForwardingCallback to create a callback with built-in path
// validation, permission controls, and security defaults matching OpenSSH.
type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error)

// ServerConfigCallback is a hook for creating custom default server configs
type ServerConfigCallback func(ctx Context) *gossh.ServerConfig

Expand Down
Loading