diff --git a/agentprotocol/Connection.go b/agentprotocol/Connection.go new file mode 100644 index 00000000..4eeb81a0 --- /dev/null +++ b/agentprotocol/Connection.go @@ -0,0 +1,18 @@ +package agentprotocol + +type Connection interface { + // Read reads data from the connection. The blocking nature of this call depends on the underlying communication medium + Read(p []byte) (n int, err error) + // Read reads data from the connection. The blocking nature of this call depends on the underlying communication medium + Write(data []byte) (n int, err error) + // Close requests to close an active connection + Close() error + // CloseImmediately closes an active connection without waiting for the other side to acknowledge + CloseImmediately() error + // Accept accepts a pending connection request. It must be called before any Read/Write functions can be called on the connection + Accept() error + // Reject rejects a pending connection request and closes the connection. + Reject() error + // Details returns the details of a connection request. It can be called to gain more information about a connection before an Accept/Reject action is made + Details() NewConnectionPayload +} diff --git a/agentprotocol/ForwardCtx.go b/agentprotocol/ForwardCtx.go new file mode 100644 index 00000000..47ca755b --- /dev/null +++ b/agentprotocol/ForwardCtx.go @@ -0,0 +1,69 @@ +package agentprotocol + +import "io" + +type ForwardCtx interface { + // NewConnectionTCP requests the other side to connect to a specified address/host and port combination and forward all data from the returned ReadWriteCloser to it. + // + // connectedAddress is the address that the connection requested to connect to, is it used by the receiving side to initiate the connection to the desired address. + // connectedPort is the port that the connection requested to connect to, is it used by the receiving side to initiate the connection to the desired port. + // origAddress is the originator address of the connection. It can be used by the receiving side to decide whether to accept this connection. + // origAddress is the originator port of the connection. It can be used by the receiving side to decide whether to accept this connection. + // closeFunc is a callback function that is called when the connection is called to perform cleanup of the backing connection. + NewConnectionTCP( + connectedAddress string, + connectedPort uint32, + origAddress string, + origPort uint32, + closeFunc func() error, + ) (io.ReadWriteCloser, error) + // NewConnectionUnix requests the other side to connect to a specified unix and forward all data from the returned ReadWriteCloser to it. + // + // path is the path of the unix socket to connect to. + // closeFunc is a callback function that is called when the connection is called to perform cleanup of the backing connection. + NewConnectionUnix( + path string, + closeFunc func() error, + ) (io.ReadWriteCloser, error) + + // StartServer initializes the ForwardCtx in server mode which waits for information from the other side about the function it needs to perform. It returns the connection type the other side requests and additional information in the setupPacket. Additionally, a Connection channel, connChan, is returned that provides connection requests from the other side of the connection. + StartServer() (connectionType uint32, setupPacket SetupPacket, connChan chan Connection, err error) + + // StartClientForward initializes the ForwardCtx in client mode and informs the server that the client is going to be the connection requestor (Direct Forward). A connection channel is returned that informs the server of connection requests by the client however in this mode it is a assumed that the server sends no connection request so the sane behaviour is to reject all connections. In this mode, the client can start new connections on the server using the NewConnection* function family. + StartClientForward() (chan Connection, error) + + // StartX11ForwardClient initializes the ForwardCtx in client mode and informs the server to start an X11 server and forward all X11 connections to the client. + // + // singleConnection is the X11 singleConnection parameter that requests to only accept the first connection (X11 window) and no more. + // screen is the X11 screen number. + // authProtocol is the X11 auth protocol. + // authCookie is the X11 auth cookie. + StartX11ForwardClient( + singleConnection bool, + screen string, + authProtocol string, + authCookie string, + ) (chan Connection, error) + + // StartReverseForwardClient initializes the ForwardCtx in client mode and informs the server to start listening for connections on the requested host and port. Once a connection is received a new connection is created and sent through the Connection channel. + // + // bindHost is the host to listen on connections on. + // bindPort is the port to listen on connections on. + // singleConnection is a flag that requests to stop listening for new connections after the first one. + StartReverseForwardClient(bindHost string, bindPort uint32, singleConnection bool) (chan Connection, error) + + // StartReverseForwardClient initializes the ForwardCtx in client mode and informs the server to start listening for connections on the requested unix socket. Once a connection is received a new connection is created and sent through the Connection channel. + // + // path is the path to the unix socket to listen on. + // singleConnection is a flag that requests to stop listening for new connections after the first one. + StartReverseForwardClientUnix(path string, singleConnection bool) (chan Connection, error) + + // NoMoreConnections informs the other side that it should not accept any more connection requests from it. It is used as a forwarding security feature in cases where it's clear there will only be one connection. + NoMoreConnections() error + + // WaitFinish blocks until NoMoreConnections has been received and all active connections have been closed. + WaitFinish() + + // Kill closes the ForwardCtx immediately and terminates all connections. + Kill() +} diff --git a/agentprotocol/NewForwardCtx.go b/agentprotocol/NewForwardCtx.go index e368bf9a..5320bd3c 100644 --- a/agentprotocol/NewForwardCtx.go +++ b/agentprotocol/NewForwardCtx.go @@ -3,13 +3,13 @@ package agentprotocol import ( "io" - log "go.containerssh.io/libcontainerssh/log" + log "go.containerssh.io/libcontainerssh/log" ) -func NewForwardCtx(fromBackend io.Reader, toBackend io.Writer, logger log.Logger) *ForwardCtx { - return &ForwardCtx{ +func NewForwardCtx(fromBackend io.Reader, toBackend io.Writer, logger log.Logger) ForwardCtx { + return &forwardCtx{ fromBackend: fromBackend, - toBackend: toBackend, - logger: logger, + toBackend: toBackend, + logger: logger, } -} \ No newline at end of file +} diff --git a/agentprotocol/Protocol.go b/agentprotocol/Protocol.go index 0d60035d..d73e1643 100644 --- a/agentprotocol/Protocol.go +++ b/agentprotocol/Protocol.go @@ -2,14 +2,14 @@ package agentprotocol const ( CONNECTION_TYPE_X11 = iota - CONNECTION_TYPE_PORT_FORWARD = iota - CONNECTION_TYPE_PORT_DIAL = iota - CONNECTION_TYPE_SOCKET_FORWARD = iota - CONNECTION_TYPE_SOCKET_DIAL = iota + CONNECTION_TYPE_PORT_FORWARD + CONNECTION_TYPE_PORT_DIAL + CONNECTION_TYPE_SOCKET_FORWARD + CONNECTION_TYPE_SOCKET_DIAL ) const ( - PROTOCOL_TCP string = "tcp" + PROTOCOL_TCP string = "tcp" PROTOCOL_UNIX string = "unix" ) @@ -24,10 +24,10 @@ const ( ) type SetupPacket struct { - ConnectionType uint32 - BindHost string - BindPort uint32 - Protocol string + ConnectionType uint32 + BindHost string + BindPort uint32 + Protocol string Screen string SingleConnection bool @@ -36,8 +36,8 @@ type SetupPacket struct { } type NewConnectionPayload struct { - Protocol string - + Protocol string + ConnectedAddress string ConnectedPort uint32 OriginatorAddress string @@ -45,7 +45,7 @@ type NewConnectionPayload struct { } type Packet struct { - Type int - ConnectionId uint64 - Payload []byte + Type int + ConnectionID uint64 + Payload []byte } diff --git a/agentprotocol/README.md b/agentprotocol/README.md new file mode 100644 index 00000000..c8c9b7e9 --- /dev/null +++ b/agentprotocol/README.md @@ -0,0 +1,27 @@ +# The ContainerSSH Agent protocol + +The ContainerSSH Agent protocol allows for forwarding and reverse-forwarding several types of connections within the container: X11, TCP, etc. The protocol is designed to be symmetrical in a way that both ends can request, accept/reject, and process connections and both ends have the same capabilities during a connection. However, there is a small client/server distinction while initializing the protocol in the initial exchange: The 'Client' (ContainerSSH) sends a `SetupPacket` to the 'Server' (Agent) that specificies the mode that the agent should initialize to (e.g. forward, reverse-forward, X11 forward etc). + +## Concepts + +The server and the client communicate over the standard input/output using the container APIs. (Docker and Kubernetes) The agent is running just like any other program would in the container. + +### Server + +The server in this context is the ContainerSSH agent. It waits for connection requests from the client (ContainerSSH) and opens the corresponding sockkets. + +### Client + +The client in this context is ContainerSSH, opening a connection by sending requests to the agent via the standard input/output using the container API (Docker or Kubernetes). + +### Connection + +A Connection is a bidirectional binary communication between the server and the client. Multiple number of connections can be active at any given time and both sides (client/server) have the capacity to request a new connection. Connections are identified by a ConnectionID and each packet includes the ConnectionID to associate it with a connection. The state of connections is detailed in the following flow graph where the nodes represent the valid connection states and the edges are the actions/packets that affect the connection state. + +![connection state diagram](./images/cssh-agent.png) + +When a connection is initiated it is in WAITINIT state until the other end issues either an Accept action, which results in a SUCCESS message and the connection starting or a Reject action whith results in an ERROR message and the connection closing. When a connection is in the STARTED state it can accept data and both sides can issue write() and read() calls to write and read from the connection. The blocking/non-blocking nature of these calls depends on the underlying communication medium. When a connection is closed from one end it is moved to the WAITCLOSE state until the other side acknowledges the close request. This is necessary to ensure that any leftover data sent after the close call is processed. Finally once the close request is acknowledged the connection is finally closed. + +## Protocol + +The protocol consists of messages sent in [CBOR-encoding](https://cbor.io/) in a back-to-back fashion. Other than the connection control packets described above there is additionally a 'No More Connections' packet that instructs the other side to stop accepting new connections. This is handled internally in the protocol library by closing the new connection channel. \ No newline at end of file diff --git a/agentprotocol/images/cssh-agent.png b/agentprotocol/images/cssh-agent.png new file mode 100644 index 00000000..8cb463e9 Binary files /dev/null and b/agentprotocol/images/cssh-agent.png differ diff --git a/agentprotocol/server.go b/agentprotocol/server.go index d0924358..50361ce0 100644 --- a/agentprotocol/server.go +++ b/agentprotocol/server.go @@ -6,9 +6,9 @@ import ( "sync" "time" - log "go.containerssh.io/libcontainerssh/log" - message "go.containerssh.io/libcontainerssh/message" "github.com/fxamacker/cbor/v2" + log "go.containerssh.io/libcontainerssh/log" + message "go.containerssh.io/libcontainerssh/message" ) const ( @@ -18,7 +18,7 @@ const ( CONNECTION_STATE_CLOSED ) -type Connection struct { +type connection struct { logger log.Logger lock sync.Mutex state int @@ -28,15 +28,15 @@ type Connection struct { details NewConnectionPayload bufferReader *io.PipeReader bufferWriter *io.PipeWriter - ctx *ForwardCtx + ctx *forwardCtx closeCallback func() error } -func (c *Connection) Read(p []byte) (n int, err error) { +func (c *connection) Read(p []byte) (n int, err error) { return c.bufferReader.Read(p) } -func (c *Connection) Write(data []byte) (n int, err error) { +func (c *connection) Write(data []byte) (n int, err error) { c.lock.Lock() defer c.lock.Unlock() L: @@ -59,7 +59,7 @@ L: packet := Packet{ Type: PACKET_DATA, - ConnectionId: c.id, + ConnectionID: c.id, Payload: data, } err = c.ctx.writePacket(&packet) @@ -74,7 +74,7 @@ L: return len(data), nil } -func (c *Connection) Close() error { +func (c *connection) Close() error { c.lock.Lock() switch c.state { @@ -86,7 +86,7 @@ func (c *Connection) Close() error { c.lock.Unlock() packet := Packet{ Type: PACKET_CLOSE_CONNECTION, - ConnectionId: c.id, + ConnectionID: c.id, } return c.ctx.writePacket(&packet) case CONNECTION_STATE_WAITCLOSE: @@ -99,7 +99,7 @@ func (c *Connection) Close() error { return fmt.Errorf("unknown state") } -func (c *Connection) CloseImm() error { +func (c *connection) CloseImmediately() error { c.lock.Lock() defer c.lock.Unlock() if c.state != CONNECTION_STATE_WAITINIT && c.state != CONNECTION_STATE_STARTED && c.state != CONNECTION_STATE_WAITCLOSE { @@ -116,7 +116,7 @@ func (c *Connection) CloseImm() error { return nil } -func (c *Connection) Accept() error { +func (c *connection) Accept() error { c.lock.Lock() defer c.lock.Unlock() if c.initiator { @@ -129,12 +129,12 @@ func (c *Connection) Accept() error { c.stateCond.Broadcast() packet := Packet{ Type: PACKET_SUCCESS, - ConnectionId: c.id, + ConnectionID: c.id, } return c.ctx.writePacket(&packet) } -func (c *Connection) Reject() error { +func (c *connection) Reject() error { c.lock.Lock() defer c.lock.Unlock() if c.initiator { @@ -147,32 +147,32 @@ func (c *Connection) Reject() error { c.stateCond.Broadcast() packet := Packet{ Type: PACKET_ERROR, - ConnectionId: c.id, + ConnectionID: c.id, } return c.ctx.writePacket(&packet) } -func (c *Connection) Details() NewConnectionPayload { +func (c *connection) Details() NewConnectionPayload { return c.details } -func (c *Connection) setState(state int) { +func (c *connection) setState(state int) { c.lock.Lock() c.state = state c.stateCond.Broadcast() c.lock.Unlock() } -type ForwardCtx struct { +type forwardCtx struct { fromBackend io.Reader toBackend io.Writer logger log.Logger - connectionChannel chan *Connection + connectionChannel chan Connection stopped bool connectionId uint64 connMapMu sync.RWMutex - connMap map[uint64]*Connection + connMap map[uint64]*connection encoderMu sync.Mutex encoder *cbor.Encoder decoder *cbor.Decoder @@ -180,23 +180,23 @@ type ForwardCtx struct { waitGroup sync.WaitGroup } -func (c *ForwardCtx) writePacket(packet *Packet) error { +func (c *forwardCtx) writePacket(packet *Packet) error { c.encoderMu.Lock() err := c.encoder.Encode(&packet) c.encoderMu.Unlock() return err } -func (c *ForwardCtx) handleData(packet *Packet) { +func (c *forwardCtx) handleData(packet *Packet) { c.connMapMu.RLock() - conn, ok := c.connMap[packet.ConnectionId] + conn, ok := c.connMap[packet.ConnectionID] c.connMapMu.RUnlock() if !ok { c.logger.Info( message.NewMessage( message.EAgentUnknownConnection, "Received data packet with unknown connection id %d", - packet.ConnectionId, + packet.ConnectionID, ), ) return @@ -232,24 +232,24 @@ func (c *ForwardCtx) handleData(packet *Packet) { } } -func (c *ForwardCtx) handleClose(packet *Packet) { +func (c *forwardCtx) handleClose(packet *Packet) { c.connMapMu.Lock() - conn, ok := c.connMap[packet.ConnectionId] + conn, ok := c.connMap[packet.ConnectionID] if !ok { c.logger.Info( message.NewMessage( message.EAgentUnknownConnection, "Received close packet with unknown connection id %d", - packet.ConnectionId, + packet.ConnectionID, ), ) return } c.connMapMu.Unlock() - err := conn.CloseImm() + err := conn.CloseImmediately() retPacket := Packet{ Type: PACKET_SUCCESS, - ConnectionId: conn.id, + ConnectionID: conn.id, } if err != nil { retPacket.Type = PACKET_ERROR @@ -257,16 +257,16 @@ func (c *ForwardCtx) handleClose(packet *Packet) { _ = c.writePacket(&retPacket) } -func (c *ForwardCtx) handleSuccess(packet *Packet) { +func (c *forwardCtx) handleSuccess(packet *Packet) { c.connMapMu.Lock() defer c.connMapMu.Unlock() - conn, ok := c.connMap[packet.ConnectionId] + conn, ok := c.connMap[packet.ConnectionID] if !ok { c.logger.Info( message.NewMessage( message.EAgentUnknownConnection, "Received success packet with unknown connection id %d", - packet.ConnectionId, + packet.ConnectionID, ), ) return @@ -276,7 +276,7 @@ func (c *ForwardCtx) handleSuccess(packet *Packet) { case CONNECTION_STATE_WAITINIT: conn.setState(CONNECTION_STATE_STARTED) case CONNECTION_STATE_WAITCLOSE: - _ = conn.CloseImm() + _ = conn.CloseImmediately() default: c.logger.Warning( message.NewMessage( @@ -287,16 +287,16 @@ func (c *ForwardCtx) handleSuccess(packet *Packet) { } } -func (c *ForwardCtx) handleError(packet *Packet) { +func (c *forwardCtx) handleError(packet *Packet) { c.connMapMu.Lock() defer c.connMapMu.Unlock() - conn, ok := c.connMap[packet.ConnectionId] + conn, ok := c.connMap[packet.ConnectionID] if !ok { c.logger.Info( message.NewMessage( message.EAgentUnknownConnection, "Received error packet with unknown connection id %d", - packet.ConnectionId, + packet.ConnectionID, ), ) return @@ -306,23 +306,23 @@ func (c *ForwardCtx) handleError(packet *Packet) { message.NewMessage( message.MAgentRemoteError, "Received error packet for connection %d from remote", - packet.ConnectionId, + packet.ConnectionID, ), ) - _ = conn.CloseImm() + _ = conn.CloseImmediately() } -func (c *ForwardCtx) handleNewConnection(packet *Packet) { +func (c *forwardCtx) handleNewConnection(packet *Packet) { newConnectionPacket, err := c.unmarshalNewConnection(packet.Payload) if err != nil { c.logger.Error("Error unmarshalling new connection payload", err) return } pipeReader, pipeWriter := io.Pipe() - connection := Connection{ + connection := connection{ state: CONNECTION_STATE_WAITINIT, - id: packet.ConnectionId, + id: packet.ConnectionID, details: newConnectionPacket, bufferReader: pipeReader, bufferWriter: pipeWriter, @@ -331,24 +331,24 @@ func (c *ForwardCtx) handleNewConnection(packet *Packet) { } connection.stateCond = sync.NewCond(&connection.lock) c.connMapMu.Lock() - if _, ok := c.connMap[packet.ConnectionId]; ok { + if _, ok := c.connMap[packet.ConnectionID]; ok { c.logger.Warning("Remote tried to open connection with re-used connectionId") // Cannot send reject here, might interfere with other connection ? c.connMapMu.Unlock() return } - if packet.ConnectionId <= c.connectionId { + if packet.ConnectionID <= c.connectionId { c.logger.Warning("Suspicious connection, id <= prev") // Can't send reject here either c.connMapMu.Unlock() return } - if packet.ConnectionId != c.connectionId+1 { + if packet.ConnectionID != c.connectionId+1 { c.logger.Warning("Suspicious connection, id not prev + 1") } - c.connectionId = packet.ConnectionId - c.connMap[packet.ConnectionId] = &connection + c.connectionId = packet.ConnectionID + c.connMap[packet.ConnectionID] = &connection c.waitGroup.Add(1) c.connMapMu.Unlock() @@ -361,7 +361,7 @@ func (c *ForwardCtx) handleNewConnection(packet *Packet) { c.connectionChannel <- &connection } -func (c *ForwardCtx) handleBackend() { +func (c *forwardCtx) handleBackend() { for { packet := Packet{} err := c.decoder.Decode(&packet) @@ -401,7 +401,7 @@ func (c *ForwardCtx) handleBackend() { } } -func (c *ForwardCtx) unmarshalSetup(payload []byte) (SetupPacket, error) { +func (c *forwardCtx) unmarshalSetup(payload []byte) (SetupPacket, error) { packet := SetupPacket{} err := cbor.Unmarshal(payload, &packet) if err != nil { @@ -410,7 +410,7 @@ func (c *ForwardCtx) unmarshalSetup(payload []byte) (SetupPacket, error) { return packet, nil } -func (c *ForwardCtx) unmarshalNewConnection(payload []byte) (NewConnectionPayload, error) { +func (c *forwardCtx) unmarshalNewConnection(payload []byte) (NewConnectionPayload, error) { packet := NewConnectionPayload{} err := cbor.Unmarshal(payload, &packet) if err != nil { @@ -419,7 +419,7 @@ func (c *ForwardCtx) unmarshalNewConnection(payload []byte) (NewConnectionPayloa return packet, nil } -func (c *ForwardCtx) NewConnectionTCP( +func (c *forwardCtx) NewConnectionTCP( connectedAddress string, connectedPort uint32, origAddress string, @@ -436,7 +436,7 @@ func (c *ForwardCtx) NewConnectionTCP( ) } -func (c *ForwardCtx) NewConnectionUnix( +func (c *forwardCtx) NewConnectionUnix( path string, closeFunc func() error, ) (io.ReadWriteCloser, error) { @@ -450,7 +450,7 @@ func (c *ForwardCtx) NewConnectionUnix( ) } -func (c *ForwardCtx) newConnection( +func (c *forwardCtx) newConnection( protocol string, connectedAddress string, connectedPort uint32, @@ -476,7 +476,7 @@ func (c *ForwardCtx) newConnection( } bufferReader, bufferWriter := io.Pipe() - conn := Connection{ + conn := connection{ state: CONNECTION_STATE_WAITINIT, initiator: true, bufferReader: bufferReader, @@ -498,7 +498,7 @@ func (c *ForwardCtx) newConnection( c.connMapMu.Unlock() err = c.writePacket(&Packet{ Type: PACKET_NEW_CONNECTION, - ConnectionId: conn.id, + ConnectionID: conn.id, Payload: marInfo, }) if err != nil { @@ -513,15 +513,15 @@ func (c *ForwardCtx) newConnection( return &conn, nil } -func (c *ForwardCtx) init() { - c.connMap = make(map[uint64]*Connection) - c.connectionChannel = make(chan *Connection) +func (c *forwardCtx) init() { + c.connMap = make(map[uint64]*connection) + c.connectionChannel = make(chan Connection) c.encoder = cbor.NewEncoder(c.toBackend) c.decoder = cbor.NewDecoder(c.fromBackend) } -func (c *ForwardCtx) StartClient() (connectionType uint32, setupPacket SetupPacket, connChan chan *Connection, err error) { +func (c *forwardCtx) StartServer() (connectionType uint32, setupPacket SetupPacket, connChan chan Connection, err error) { c.init() packet := Packet{} @@ -568,7 +568,7 @@ func (c *ForwardCtx) StartClient() (connectionType uint32, setupPacket SetupPack return setup.ConnectionType, setup, c.connectionChannel, nil } -func (c *ForwardCtx) StartServerForward() (chan *Connection, error) { +func (c *forwardCtx) StartClientForward() (chan Connection, error) { c.init() setupPacket := SetupPacket{ @@ -609,7 +609,7 @@ func (c *ForwardCtx) StartServerForward() (chan *Connection, error) { return c.connectionChannel, nil } -func (c *ForwardCtx) startReverseForwardingClient(setupPacket SetupPacket) (chan *Connection, error) { +func (c *forwardCtx) startReverseForwardingClient(setupPacket SetupPacket) (chan Connection, error) { c.init() mar, err := cbor.Marshal(&setupPacket) @@ -647,7 +647,7 @@ func (c *ForwardCtx) startReverseForwardingClient(setupPacket SetupPacket) (chan return c.connectionChannel, nil } -func (c *ForwardCtx) StartX11ForwardClient(singleConnection bool, screen string, authProtocol string, authCookie string) (chan *Connection, error) { +func (c *forwardCtx) StartX11ForwardClient(singleConnection bool, screen string, authProtocol string, authCookie string) (chan Connection, error) { setupPacket := SetupPacket{ ConnectionType: CONNECTION_TYPE_X11, Protocol: "tcp", @@ -660,7 +660,7 @@ func (c *ForwardCtx) StartX11ForwardClient(singleConnection bool, screen string, return c.startReverseForwardingClient(setupPacket) } -func (c *ForwardCtx) StartReverseForwardClient(bindHost string, bindPort uint32, singleConnection bool) (chan *Connection, error) { +func (c *forwardCtx) StartReverseForwardClient(bindHost string, bindPort uint32, singleConnection bool) (chan Connection, error) { setupPacket := SetupPacket{ ConnectionType: CONNECTION_TYPE_PORT_FORWARD, BindHost: bindHost, @@ -672,7 +672,7 @@ func (c *ForwardCtx) StartReverseForwardClient(bindHost string, bindPort uint32, return c.startReverseForwardingClient(setupPacket) } -func (c *ForwardCtx) StartReverseForwardClientUnix(path string, singleConnection bool) (chan *Connection, error) { +func (c *forwardCtx) StartReverseForwardClientUnix(path string, singleConnection bool) (chan Connection, error) { setupPacket := SetupPacket{ ConnectionType: CONNECTION_TYPE_PORT_FORWARD, BindHost: path, @@ -683,7 +683,7 @@ func (c *ForwardCtx) StartReverseForwardClientUnix(path string, singleConnection return c.startReverseForwardingClient(setupPacket) } -func (c *ForwardCtx) NoMoreConnections() error { +func (c *forwardCtx) NoMoreConnections() error { c.stopped = true close(c.connectionChannel) return c.writePacket( @@ -693,11 +693,11 @@ func (c *ForwardCtx) NoMoreConnections() error { ) } -func (c *ForwardCtx) WaitFinish() { +func (c *forwardCtx) WaitFinish() { c.waitGroup.Wait() } -func (c *ForwardCtx) Kill() { +func (c *forwardCtx) Kill() { if !c.stopped { _ = c.NoMoreConnections() } @@ -710,7 +710,7 @@ func (c *ForwardCtx) Kill() { case <-t: case <-time.After(5 * time.Second): for _, conn := range c.connMap { - _ = conn.CloseImm() + _ = conn.CloseImmediately() } } }() diff --git a/agentprotocol/server_test.go b/agentprotocol/server_test.go index b4f589ae..9d8a4e50 100644 --- a/agentprotocol/server_test.go +++ b/agentprotocol/server_test.go @@ -5,22 +5,21 @@ import ( "io" "testing" - proto "go.containerssh.io/libcontainerssh/agentprotocol" + proto "go.containerssh.io/libcontainerssh/agentprotocol" - log "go.containerssh.io/libcontainerssh/log" + log "go.containerssh.io/libcontainerssh/log" ) -// region Tests func TestConnectionSetup(t *testing.T) { - log := log.NewTestLogger(t) + logger := log.NewTestLogger(t) fromClientReader, fromClientWriter := io.Pipe() toClientReader, toClientWriter := io.Pipe() - clientCtx := proto.NewForwardCtx(toClientReader, fromClientWriter, log) - - serverCtx := proto.NewForwardCtx(fromClientReader, toClientWriter, log) + clientCtx := proto.NewForwardCtx(toClientReader, fromClientWriter, logger) + serverCtx := proto.NewForwardCtx(fromClientReader, toClientWriter, logger) closeChan := make(chan struct{}) + startedChan := make(chan struct{}) go func() { connChan, err := serverCtx.StartReverseForwardClient( @@ -32,30 +31,32 @@ func TestConnectionSetup(t *testing.T) { panic(err) } + close(startedChan) + testConServer := <-connChan err = testConServer.Accept() if err != nil { - log.Error("Error accept connection", err) + logger.Error("Error accept connection", err) } buf := make([]byte, 512) nBytes, err := testConServer.Read(buf) if err != nil { - log.Error("Failed to read from server") + logger.Error("Failed to read from server") } _, err = testConServer.Write(buf[:nBytes]) if err != nil { - log.Error("Failed to write to server") + logger.Error("Failed to write to server") } <-closeChan serverCtx.Kill() }() - conType, setup, connectionChan, err := clientCtx.StartClient() + conType, setup, connectionChan, err := clientCtx.StartServer() if err != nil { t.Fatal("Test failed with error", err) } if conType != proto.CONNECTION_TYPE_PORT_FORWARD { - panic(fmt.Errorf("Invalid connection type %d", conType)) + panic(fmt.Errorf("invalid connection type %d", conType)) } go func() { diff --git a/internal/agentforward/agentForwardImpl.go b/internal/agentforward/agentForwardImpl.go index 8b8445cc..b649646a 100644 --- a/internal/agentforward/agentForwardImpl.go +++ b/internal/agentforward/agentForwardImpl.go @@ -6,17 +6,17 @@ import ( "io" "sync" - protocol "go.containerssh.io/libcontainerssh/agentprotocol" - "go.containerssh.io/libcontainerssh/internal/sshserver" - "go.containerssh.io/libcontainerssh/log" + protocol "go.containerssh.io/libcontainerssh/agentprotocol" + "go.containerssh.io/libcontainerssh/internal/sshserver" + "go.containerssh.io/libcontainerssh/log" ) type agentForward struct { lock sync.Mutex - reverseForwards map[string]*protocol.ForwardCtx + reverseForwards map[string]protocol.ForwardCtx nX11Channels uint32 - x11Forward *protocol.ForwardCtx - directForward *protocol.ForwardCtx + x11Forward protocol.ForwardCtx + directForward protocol.ForwardCtx logger log.Logger } @@ -24,7 +24,7 @@ func NewAgentForward( logger log.Logger, ) AgentForward { return &agentForward{ - reverseForwards: make(map[string]*protocol.ForwardCtx), + reverseForwards: make(map[string]protocol.ForwardCtx), logger: logger, } } @@ -61,7 +61,7 @@ func serveConnection(log log.Logger, dst io.WriteCloser, src io.ReadCloser) { _ = src.Close() } -func (f *agentForward) serveX11(connChan chan *protocol.Connection, reverseHandler sshserver.ReverseForward) { +func (f *agentForward) serveX11(connChan chan protocol.Connection, reverseHandler sshserver.ReverseForward) { for { agentConn, ok := <-connChan if !ok { @@ -94,7 +94,7 @@ func (f *agentForward) serveX11(connChan chan *protocol.Connection, reverseHandl } } -func (f *agentForward) serveReverseForward(connChan chan *protocol.Connection, reverseHandler sshserver.ReverseForward) { +func (f *agentForward) serveReverseForward(connChan chan protocol.Connection, reverseHandler sshserver.ReverseForward) { for { agentConn, ok := <-connChan if !ok { @@ -120,7 +120,7 @@ func (f *agentForward) serveReverseForward(connChan chan *protocol.Connection, r } } -func (f *agentForward) serveReverseForwardUnix(connChan chan *protocol.Connection, reverseHandler sshserver.ReverseForward) { +func (f *agentForward) serveReverseForwardUnix(connChan chan protocol.Connection, reverseHandler sshserver.ReverseForward) { for { agentConn, ok := <-connChan if !ok { @@ -325,7 +325,7 @@ func (f *agentForward) setupDirectForward( return err } f.directForward = protocol.NewForwardCtx(fromAgent, toAgent, logger) - connChan, err := f.directForward.StartServerForward() + connChan, err := f.directForward.StartClientForward() if err != nil { return err }