diff --git a/Build.act b/Build.act
new file mode 100644
index 0000000..d12b2b6
--- /dev/null
+++ b/Build.act
@@ -0,0 +1,9 @@
+name = "ssh"
+fingerprint = 0xee8dcc44255e0a2e
+zig_dependencies = {
+ "libssh": (
+ path="../acton-deps/libssh",
+ options={"WITH_SERVER": "true"},
+ artifacts=["ssh"]
+ )
+}
diff --git a/build.act.json b/build.act.json
index ec3c9b7..ac51969 100644
--- a/build.act.json
+++ b/build.act.json
@@ -2,8 +2,10 @@
"dependencies": {},
"zig_dependencies": {
"libssh": {
- "url": "https://github.com/actonlang/libssh/archive/refs/heads/zig-build.tar.gz",
- "hash": "122076b4676ec7d777e5efc24ae4b7f1ab4fa3df407ab9649fdd55f9fc594ca053ec",
+ "path": "../acton-deps/libssh",
+ "options": {
+ "WITH_SERVER": "true"
+ },
"artifacts": [
"ssh"
]
diff --git a/src/ssh.act b/src/ssh.act
index a8b0268..eb5b29e 100644
--- a/src/ssh.act
+++ b/src/ssh.act
@@ -1,69 +1,590 @@
+"""
+Acton SSH module
+
+This module exposes a callback-driven SSH client/server API built on libssh
+and integrated with Acton's libuv loop. The main actors are:
+
+- Client: connects to a server, reports on_connect(err) and on_close(reason),
+ and optionally on_hostkey(state, HostKeyInfo) for strict host key checks.
+- Channel: client-side data stream (stdout/stderr/exit/close callbacks).
+- RunCommand: convenience wrapper that buffers stdout/stderr and reports exit.
+- Server: listens for incoming connections and creates ServerSession actors.
+- ServerSession: per-connection state; emits on_auth(AuthRequest),
+ on_channel_open, and optional on_exec/on_subsystem.
+- ServerChannel: server-side data stream for a single channel.
+
+Usage sketch (client):
+ def on_connect(c, err):
+ if err is not None: ...
+ Channel(c, on_open, on_stdout, on_stderr, on_exit, on_close)
+
+ def on_hostkey(c, state, info):
+ if state == HOSTKEY_OK: c.accept_hostkey()
+ else: c.reject_hostkey("...")
+
+Usage sketch (server):
+ def on_auth(sess, req): sess.accept_auth()
+ def on_channel_open(sess): sess.accept_channel(ServerChannel(...))
+ def on_subsystem(sess, ch, name): ch.accept_request(); ch.write(...)
+
+Notes:
+ - All operations are async actor actions; callbacks receive errors as ?str.
+ - ?bytes callbacks use None to signal EOF.
+ - Timeouts and keepalive are configurable; host keys are generated in-memory
+ unless host_key_path is set.
+ - Server admission limits are configurable; values <= 0 disable a limit.
+ - For debugging, set ACTON_SSH_DEBUG and ACTON_SSH_LIBSSH_LOG.
+"""
+
import net
-import testing
+
def version() -> str:
"""Get the libssh version"""
NotImplemented
-def _test_version():
- testing.assertEqual("0.11.0", version())
+
+def _debug(msg: str) -> None:
+ """Internal debug logging (gated by ACTON_SSH_DEBUG)."""
+ NotImplemented
+
+
+HOSTKEY_OK = "ok"
+HOSTKEY_UNKNOWN = "unknown"
+HOSTKEY_NOT_FOUND = "not_found"
+HOSTKEY_CHANGED = "changed"
+HOSTKEY_OTHER = "other"
+HOSTKEY_ERROR = "error"
+
+class HostKeyInfo():
+ """Server host key details"""
+ key_type: str
+ fingerprint: str
+
+ def __init__(self, key_type: str, fingerprint: str):
+ self.key_type = key_type
+ self.fingerprint = fingerprint
+
+
+class AuthRequest():
+ """Authentication request details"""
+ method: str
+ user: str
+ password: ?str
+ pubkey: ?bytes
+
+ def __init__(self, method: str, user: str, password: ?str=None, pubkey: ?bytes=None):
+ self.method = method
+ self.user = user
+ self.password = password
+ self.pubkey = pubkey
+
actor Client(cap: net.TCPConnectCap,
host: str,
username: str,
- on_connect: action(Client) -> None,
+ on_connect: action(Client, ?str) -> None,
on_close: action(Client, str) -> None,
- key: ?str=None,
+ on_hostkey: ?action(Client, str, HostKeyInfo) -> None=None,
password: ?str=None,
port: u16=22,
+ known_hosts: ?str=None,
+ connect_timeout: float=10.0,
+ auth_timeout: float=10.0,
+ keepalive_interval: float=30.0,
+ keepalive_enabled: bool=True,
):
"""SSH Client"""
# haha, this is really a pointer :P
- var _ssh_session: u64 = 0
+ var _client: u64 = 0
+
+ proc def _pin_affinity() -> None:
+ NotImplemented
+ _pin_affinity()
proc def _init() -> None:
"""Initialize the SSH client"""
NotImplemented
_init()
- print("SSH Client connected")
-# action def close(on_close: action(TLSConnection) -> None) -> None:
-# """Close the connection"""
-# NotImplemented
-#
-# def reconnect():
-# close(_connect)
-#
-# def _connect(c):
-# NotImplemented
-#
+ action def accept_hostkey() -> None:
+ """Accept the currently pending host key (not persisted)"""
+ NotImplemented
+
+ action def reject_hostkey(reason: str) -> None:
+ """Reject the currently pending host key"""
+ NotImplemented
+
+ action def close() -> None:
+ """Close the SSH client"""
+ NotImplemented
+
+ proc def _cleanup_native() -> None:
+ NotImplemented
+
+ action def __cleanup__() -> None:
+ if _client != 0:
+ _cleanup_native()
+
+ # Internal channel operations (used by Channel actor)
+ action def channel_create(channel: Channel,
+ on_open: action(Channel, ?str) -> None,
+ on_stdout: action(Channel, ?bytes) -> None,
+ on_stderr: action(Channel, ?bytes) -> None,
+ on_exit: action(Channel, int, ?str) -> None,
+ on_close: action(Channel, str) -> None) -> None:
+ NotImplemented
+
+ action def channel_request_exec(channel: Channel, cmd: str) -> None:
+ NotImplemented
+
+ action def channel_request_shell(channel: Channel,
+ term: str,
+ cols: int,
+ rows: int,
+ width_px: int,
+ height_px: int,
+ with_pty: bool) -> None:
+ NotImplemented
+
+ action def channel_request_subsystem(channel: Channel, name: str) -> None:
+ NotImplemented
+
+ action def channel_write(channel: Channel, data: bytes) -> None:
+ NotImplemented
+
+ action def channel_send_eof(channel: Channel) -> None:
+ NotImplemented
+
+ action def channel_close(channel: Channel) -> None:
+ NotImplemented
+
+
+actor Channel(client: Client,
+ on_open: action(Channel, ?str) -> None,
+ on_stdout: action(Channel, ?bytes) -> None,
+ on_stderr: action(Channel, ?bytes) -> None,
+ on_exit: action(Channel, int, ?str) -> None,
+ on_close: action(Channel, str) -> None):
+ """SSH Channel"""
+
+ # pointer to internal channel state
+ var _channel_id: u64 = 0
+ var _on_open: action(Channel, ?str) -> None = on_open
+ var _on_stdout: action(Channel, ?bytes) -> None = on_stdout
+ var _on_stderr: action(Channel, ?bytes) -> None = on_stderr
+ var _on_exit: action(Channel, int, ?str) -> None = on_exit
+ var _on_close: action(Channel, str) -> None = on_close
+
+ proc def _init() -> None:
+ client.channel_create(self, _on_open, _on_stdout, _on_stderr, _on_exit, _on_close)
+ _init()
+
+ action def request_exec(cmd: str) -> None:
+ """Request to execute a command"""
+ client.channel_request_exec(self, cmd)
+
+ action def request_shell(term: str="xterm-256color",
+ cols: int=80,
+ rows: int=24,
+ width_px: int=0,
+ height_px: int=0,
+ with_pty: bool=True) -> None:
+ """Request an interactive shell"""
+ client.channel_request_shell(self, term, cols, rows, width_px, height_px, with_pty)
+
+ action def request_subsystem(name: str) -> None:
+ """Request a subsystem"""
+ client.channel_request_subsystem(self, name)
+
+ action def write(data: bytes) -> None:
+ """Write data to the channel"""
+ client.channel_write(self, data)
+
+ action def send_eof() -> None:
+ """Half-close writes while keeping the read side open"""
+ client.channel_send_eof(self)
+
+ action def close() -> None:
+ """Close the channel and discard unread inbound data"""
+ client.channel_close(self)
+
+ proc def _cleanup_native() -> None:
+ NotImplemented
+
+ action def __cleanup__() -> None:
+ if _channel_id != 0:
+ _cleanup_native()
+
+
+actor RunCommand(client: Client,
+ cmd: str,
+ on_exit: action(Channel, int, ?str, bytes, bytes, ?str) -> None,
+ timeout: ?float=None):
+ """Run a command and collect output"""
+
+ var out_buf = b""
+ var err_buf = b""
+ var _out_done = False
+ var _err_done = False
+ var _exited = False
+ var _exit_code = 0
+ var _exit_signal: ?str = None
+ var _error: ?str = None
+ var _done = False
+
+ def _finish(ch: Channel):
+ if _done:
+ return
+ _done = True
+ on_exit(ch, _exit_code, _exit_signal, out_buf, err_buf, _error)
+
+ def _on_open(ch: Channel, err: ?str):
+ if _done:
+ return
+ if err is not None:
+ _error = err
+ _finish(ch)
+ return
+ ch.request_exec(cmd)
+
+ def _on_stdout(ch: Channel, data: ?bytes):
+ if _done:
+ return
+ if data is not None:
+ out_buf += data
+ else:
+ _out_done = True
+ _check_done(ch)
+
+ def _on_stderr(ch: Channel, data: ?bytes):
+ if _done:
+ return
+ if data is not None:
+ err_buf += data
+ else:
+ _err_done = True
+ _check_done(ch)
+
+ def _on_exit(ch: Channel, code: int, sig: ?str):
+ if _done:
+ return
+ _exited = True
+ _exit_code = code
+ _exit_signal = sig
+ _check_done(ch)
+
+ def _on_close(ch: Channel, reason: str):
+ if _done:
+ return
+ if _error is None and reason != "closed":
+ _error = reason
+ if _error is not None:
+ _finish(ch)
+ return
+
+ def _check_done(ch: Channel):
+ if _out_done and _err_done and _exited and _error is None:
+ _finish(ch)
+
+ var _channel = Channel(client, _on_open, _on_stdout, _on_stderr, _on_exit, _on_close)
+
+ if timeout is not None:
+ def _on_timeout():
+ if _done:
+ return
+ _error = "timeout"
+ _finish(_channel)
+ _channel.close()
+ after timeout: _on_timeout()
+
+
+actor Server(cap: net.TCPListenCap,
+ host: str,
+ port: u16,
+ on_listen: action(Server, ?str) -> None,
+ on_close: action(Server, str) -> None,
+ on_session: action(ServerSession) -> None,
+ on_auth: action(ServerSession, AuthRequest) -> None,
+ on_channel_open: action(ServerSession) -> None,
+ on_exec: ?action(ServerSession, ServerChannel, str) -> None=None,
+ on_subsystem: ?action(ServerSession, ServerChannel, str) -> None=None,
+ on_session_close: ?action(ServerSession, str) -> None=None,
+ host_key_path: ?str=None,
+ host_key_type: str="ed25519",
+ host_key_bits: int=2048,
+ auth_timeout: float=10.0,
+ keepalive_interval: float=30.0,
+ keepalive_enabled: bool=True,
+ max_sessions: int=128,
+ max_channels_per_session: int=32,
+ ):
+ """SSH Server"""
+
+ var _server: u64 = 0
+ var _host: str = host
+ var _port: u16 = port
+ var _host_key_path: ?str = host_key_path
+ var _host_key_type: str = host_key_type
+ var _host_key_bits: int = host_key_bits
+ var _auth_timeout: float = auth_timeout
+ var _keepalive_interval: float = keepalive_interval
+ var _keepalive_enabled: bool = keepalive_enabled
+ var _max_sessions: int = max_sessions
+ var _max_channels_per_session: int = max_channels_per_session
+ var _on_listen: action(Server, ?str) -> None = on_listen
+ var _on_close: action(Server, str) -> None = on_close
+
+ proc def _pin_affinity() -> None:
+ NotImplemented
+ _pin_affinity()
+
+ proc def _init() -> None:
+ """Initialize the SSH server"""
+ NotImplemented
+ _init()
+
+ action def close() -> None:
+ """Close the SSH server"""
+ NotImplemented
-# TODO: implement support for channels
-# AFAIK, all things over ssh are done via channels, so need some channel
-# primitive, maybe an actor per channel that then multiplexes into the Client
-# session? Prolly need some higher level wrappers for common things like
-# starting a shell or running a single command. SFTP / SCP would be nice too,
-# but for sometime in the future. Custom subsystems need to be supported too.
+ proc def _cleanup_native() -> None:
+ NotImplemented
+ action def __cleanup__() -> None:
+ if _server != 0:
+ _cleanup_native()
+ action def on_session_pending(session_id: u64) -> None:
+ _debug("server on_session_pending: " + str(session_id))
+ ServerSession(self,
+ session_id,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close)
+ _debug("server on_session_pending: created session actor for " + str(session_id))
+
+ action def on_session_ready(session: ServerSession) -> None:
+ on_session(session)
+
+
+actor ServerSession(server: Server,
+ session_id: u64,
+ on_auth: action(ServerSession, AuthRequest) -> None,
+ on_channel_open: action(ServerSession) -> None,
+ on_exec: ?action(ServerSession, ServerChannel, str) -> None=None,
+ on_subsystem: ?action(ServerSession, ServerChannel, str) -> None=None,
+ on_close: ?action(ServerSession, str) -> None=None):
+ """SSH Server Session"""
+
+ var _session_id: u64 = 0
+ var _on_auth: action(ServerSession, AuthRequest) -> None = on_auth
+ var _on_channel_open: action(ServerSession) -> None = on_channel_open
+ var _on_exec: ?action(ServerSession, ServerChannel, str) -> None = on_exec
+ var _on_subsystem: ?action(ServerSession, ServerChannel, str) -> None = on_subsystem
+ var _on_close: ?action(ServerSession, str) -> None = on_close
+
+ proc def _pin_affinity() -> None:
+ NotImplemented
+ _pin_affinity()
+
+ proc def _attach(session_id: u64) -> None:
+ NotImplemented
+
+ proc def _drive_attached() -> None:
+ NotImplemented
+
+ action def _attach_ready() -> None:
+ _debug("server session attach ready: " + str(session_id))
+ _attach(session_id)
+ _debug("server session attach: _attach returned for " + str(session_id))
+ if _session_id != 0:
+ server.on_session_ready(self)
+ _debug("server session attach: on_session_ready sent for " + str(session_id))
+ _drive_attached()
+
+ after 0: _attach_ready()
+
+ action def accept_auth() -> None:
+ """Accept the pending authentication request"""
+ NotImplemented
+
+ action def reject_auth(reason: str) -> None:
+ """Reject the pending authentication request"""
+ NotImplemented
+
+ action def accept_channel(channel: ServerChannel) -> None:
+ """Accept the pending channel open request"""
+ channel.accept_open()
+
+ action def accept_channel_open(channel: ServerChannel,
+ on_data: action(ServerChannel, ?bytes) -> None,
+ on_stderr: action(ServerChannel, ?bytes) -> None,
+ on_close: action(ServerChannel, str) -> None) -> None:
+ NotImplemented
+
+ action def reject_channel(reason: str) -> None:
+ """Reject the pending channel open request"""
+ NotImplemented
+
+ action def close() -> None:
+ """Close the SSH session"""
+ NotImplemented
+
+ proc def _cleanup_native() -> None:
+ NotImplemented
+
+ action def __cleanup__() -> None:
+ if _session_id != 0:
+ _cleanup_native()
+
+ # Internal channel operations (used by ServerChannel actor)
+ action def channel_accept_request(channel: ServerChannel) -> None:
+ NotImplemented
+
+ action def channel_reject_request(channel: ServerChannel, reason: str) -> None:
+ NotImplemented
+
+ action def channel_write(channel: ServerChannel, data: bytes) -> None:
+ NotImplemented
+
+ action def channel_write_stderr(channel: ServerChannel, data: bytes) -> None:
+ NotImplemented
+
+ action def channel_send_eof(channel: ServerChannel) -> None:
+ NotImplemented
+
+ action def channel_send_exit_status(channel: ServerChannel, status: int) -> None:
+ NotImplemented
+
+ action def channel_close(channel: ServerChannel) -> None:
+ NotImplemented
+
+
+actor ServerChannel(session: ServerSession,
+ on_data: action(ServerChannel, ?bytes) -> None,
+ on_stderr: action(ServerChannel, ?bytes) -> None,
+ on_close: action(ServerChannel, str) -> None):
+ """SSH Server Channel"""
+
+ var _channel_id: u64 = 0
+ var _on_data: action(ServerChannel, ?bytes) -> None = on_data
+ var _on_stderr: action(ServerChannel, ?bytes) -> None = on_stderr
+ var _on_close: action(ServerChannel, str) -> None = on_close
+
+ action def accept_request() -> None:
+ """Accept the pending channel request"""
+ session.channel_accept_request(self)
+
+ action def accept_open() -> None:
+ """Accept the pending channel open request"""
+ session.accept_channel_open(self, _on_data, _on_stderr, _on_close)
+
+ action def reject_request(reason: str) -> None:
+ """Reject the pending channel request"""
+ session.channel_reject_request(self, reason)
+
+ action def write(data: bytes) -> None:
+ """Write data to the channel (stdout)"""
+ session.channel_write(self, data)
+
+ action def write_stderr(data: bytes) -> None:
+ """Write data to the channel (stderr)"""
+ session.channel_write_stderr(self, data)
+
+ action def send_eof() -> None:
+ """Half-close writes while keeping the read side open"""
+ session.channel_send_eof(self)
+
+ action def send_exit_status(status: int) -> None:
+ """Send an exit status for the channel"""
+ session.channel_send_exit_status(self, status)
+
+ action def close() -> None:
+ """Close the channel and discard unread inbound data"""
+ session.channel_close(self)
+
+ proc def _cleanup_native() -> None:
+ NotImplemented
+
+ action def __cleanup__() -> None:
+ if _channel_id != 0:
+ _cleanup_native()
actor main(env):
- def on_connect(client: Client):
- print("Connected")
+ NETCONF_HELLO = b'\n' + \
+ b'\n' + \
+ b' \n' + \
+ b' urn:ietf:params:netconf:base:1.0\n' + \
+ b' \n' + \
+ b']]>]]>'
+ var out_buf = b""
+ var c: ?Client = None
- def on_close(client: Client, error: str):
- print("Error", error)
+ def on_connect(client: Client, err: ?str):
+ if err is not None:
+ print("SSH error", err)
+ env.exit(1)
+ return
+ print("SSH connected")
- print(version())
- c = Client(
+ def on_close(client: Client, reason: str):
+ print("SSH closed", reason)
+ env.exit(0)
+
+ def on_hostkey(client: Client, state: str, info: HostKeyInfo):
+ print("Host key", state, info.key_type, info.fingerprint)
+ # WARNING: accepting any host key is insecure; do this only for testing
+ client.accept_hostkey()
+
+ def ch_open(ch: Channel, err: ?str):
+ if err is not None:
+ print("Channel open error", err)
+ if c is not None:
+ c.close()
+ return
+ ch.request_subsystem("netconf")
+ ch.write(NETCONF_HELLO)
+
+ def ch_out(ch: Channel, data: ?bytes):
+ if data is not None:
+ out_buf += data
+ if out_buf.find(b"]]>]]>") >= 0:
+ print(out_buf)
+ ch.close()
+
+ def ch_err(ch: Channel, data: ?bytes):
+ if data is not None:
+ print(data)
+
+ def ch_exit(ch: Channel, code: int, sig: ?str):
+ print("exit", code, sig)
+
+ def ch_close(ch: Channel, reason: str):
+ print("channel closed", reason)
+ if c is not None:
+ c.close()
+
+ client = Client(
net.TCPConnectCap(net.TCPCap(net.NetCap(env.cap))),
- "localhost",
- "foo",
+ "127.0.0.1",
+ "admin",
on_connect,
on_close,
- password="bar",
- port=2223,
+ on_hostkey,
+ password="admin",
+ port=42830,
+ known_hosts="/tmp/acton_ssh_known_hosts",
)
- env.exit(0)
+
+ c = client
+ Channel(client, ch_open, ch_out, ch_err, ch_exit, ch_close)
+
+ def _on_close_timeout():
+ if c is not None:
+ c.close()
+ after 10.0: _on_close_timeout()
diff --git a/src/ssh.ext.c b/src/ssh.ext.c
index 2112598..0b3ea10 100644
--- a/src/ssh.ext.c
+++ b/src/ssh.ext.c
@@ -1,120 +1,3941 @@
+/*
+ * Acton <-> libssh integration overview
+ *
+ * This file is the external-C glue that drives libssh from Acton's libuv loop
+ * and exposes it to Acton actors. The core goals are:
+ * - Nonblocking SSH I/O integrated with libuv (no blocking syscalls).
+ * - Actor-safe, async callback-driven API in Acton.
+ * - GC-safe memory: libssh allocations use Acton's allocator.
+ *
+ * Event loop integration
+ * - Each libssh session is created nonblocking.
+ * - We attach a uv_poll watcher to the libssh socket fd.
+ * - On poll events we call ssh_session_handle_poll() (via
+ * session_apply_poll_events), then drive a small state machine
+ * (connect/auth/ready for client, keyex/auth/ready for server).
+ * - ssh_get_poll_flags()/ssh_get_status() decide which poll events to arm.
+ *
+ * Buffered data + SSH_AGAIN (why we keep driving without fd readability)
+ * - libssh maintains its own internal buffers. After a poll callback, libssh
+ * may have already read bytes into those buffers even though the socket is
+ * no longer readable at the OS level.
+ * - When a nonblocking API returns SSH_AGAIN and ssh_get_status() includes
+ * SSH_READ_PENDING, it means "call again, there is buffered data to
+ * process" even if the fd will not trigger another readable event.
+ * - If we only wait for uv_poll readability, we can deadlock:
+ * 1) uv_poll READABLE fires; ssh_session_handle_poll() drains the fd.
+ * 2) ssh_connect()/ssh_handle_key_exchange()/ssh_userauth_password()
+ * returns SSH_AGAIN.
+ * 3) No more kernel readability events happen, but libssh still has
+ * buffered protocol bytes (SSH_READ_PENDING).
+ * 4) We wait for an event that never comes and eventually time out.
+ * - The fix is to keep driving the state machine in a bounded loop while
+ * SSH_READ_PENDING is set, even without fd readability.
+ * We cap iterations with SSH_IO_PUMP_LIMIT to avoid CPU spin.
+ *
+ * Channel I/O
+ * - SSH channels carry two streams: "data" and "extended data". We expose
+ * these as stdout/stderr callbacks (client on_stdout/on_stderr, server
+ * on_data/on_stderr). This is protocol-level stdout/stderr, not host OS
+ * process stdio.
+ * - Channels install libssh callbacks for data/extended-data/EOF/close.
+ * - Inbound data always flows through these callbacks. As we drive libssh
+ * (via ssh_session_handle_poll), libssh invokes the registered C callback
+ * functions, and those callbacks call the corresponding Acton action
+ * methods (foo->$class->on_stdout/on_stderr/on_close, etc.). We do not run
+ * manual read loops; libssh owns buffering and read state.
+ * - Channel writes are queued and flushed when libssh reports write pending.
+ *
+ * Actor/GC/threading model
+ * - Client and ServerSession actors own libssh state and are pinned to a
+ * worker thread. Channel actors invoke action methods on their owning
+ * Client/ServerSession actor for all operations; there is no hidden
+ * cross-actor C magic.
+ * - We replace libssh allocators with Acton's GC allocator so libuv/GC
+ * roots remain visible (libssh structures can reference GC memory).
+ *
+ * Config & filesystem
+ * - libssh config processing is disabled by default; known_hosts is only
+ * read if explicitly configured by the Acton API.
+ * - Server host keys are generated in-memory unless a path is provided.
+ */
+#include
+#include
#include
-#include
-// TODO: figure out how to include rts/log so we get access to log_error etc
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
-void noop_free(void *ptr) {
+#include "rts/log.h"
+
+uv_loop_t *get_uv_loop(void);
+extern struct $Cont $Done$instance;
+
+#define SSH_READ_BUFSIZE 4096
+#define SSH_IO_PUMP_LIMIT 128
+#define SSH_ATTACH_TIMEOUT_SEC 5.0
+#define SSH_KEYEX_TIMEOUT_SEC 2.0
+#define SSH_SERVER_ACCEPT_LIMIT 64
+static int ssh_debug_enabled = 0;
+static int ssh_libssh_log_level = SSH_LOG_NOLOG;
+
+typedef enum {
+ CLIENT_STATE_INIT = 0,
+ CLIENT_STATE_CONNECTING,
+ CLIENT_STATE_HOSTKEY,
+ CLIENT_STATE_HOSTKEY_WAIT,
+ CLIENT_STATE_AUTH,
+ CLIENT_STATE_READY,
+ CLIENT_STATE_ERROR,
+ CLIENT_STATE_CLOSING,
+ CLIENT_STATE_CLOSED,
+} client_state_t;
+
+typedef enum {
+ CHAN_STATE_INIT = 0,
+ CHAN_STATE_OPENING,
+ CHAN_STATE_OPEN,
+ CHAN_STATE_RUNNING,
+ CHAN_STATE_CLOSING,
+ CHAN_STATE_CLOSED,
+ CHAN_STATE_ERROR,
+} channel_state_t;
+
+typedef enum {
+ CHAN_REQ_NONE = 0,
+ CHAN_REQ_SHELL,
+ CHAN_REQ_EXEC,
+ CHAN_REQ_SUBSYSTEM,
+} channel_request_t;
+
+typedef struct write_chunk {
+ B_bytes data;
+ size_t offset;
+ struct write_chunk *next;
+} write_chunk_t;
+
+typedef struct ssh_channel_ctx {
+ struct ssh_channel_ctx *next;
+ ssh_channel channel;
+ struct ssh_client_ctx *client;
+ struct ssh_channel_callbacks_struct *callbacks;
+ sshQ_Channel actor;
+ channel_state_t state;
+ channel_request_t pending_req;
+ int pty_pending;
+ int pty_done;
+ B_str exec_cmd;
+ B_str subsystem;
+ B_str term;
+ int cols;
+ int rows;
+ int width_px;
+ int height_px;
+ int send_eof;
+ int eof_sent;
+ int close_requested;
+ int close_sent;
+ int remote_close_seen;
+ int write_wontblock;
+ int stdout_eof;
+ int stderr_eof;
+ int exit_sent;
+ int open_notified;
+ int open_succeeded;
+ int close_notified;
+ $action2 on_open;
+ $action2 on_close;
+ $action2 on_stdout;
+ $action2 on_stderr;
+ $action3 on_exit;
+ write_chunk_t *write_head;
+ write_chunk_t *write_tail;
+} ssh_channel_ctx;
+
+typedef struct ssh_client_ctx {
+ GC_hidden_pointer actor;
+ ssh_session session;
+ uv_poll_t *poll;
+ int poll_events;
+ uv_timer_t *connect_timer;
+ uv_timer_t *auth_timer;
+ uv_timer_t *keepalive_timer;
+ double connect_timeout;
+ double auth_timeout;
+ double keepalive_interval;
+ int keepalive_enabled;
+ client_state_t state;
+ int fd;
+ int connect_notified;
+ int connected_ok;
+ int close_notified;
+ int close_finalized;
+ int close_force;
+ int write_ready;
+ char *close_reason;
+ enum ssh_known_hosts_e hostkey_state;
+ ssh_channel_ctx *channels;
+ ssh_channel_ctx *retired_channels;
+ $action2 on_connect;
+ $action2 on_close;
+ $action3 on_hostkey;
+} ssh_client_ctx;
+
+typedef enum {
+ SERVER_STATE_INIT = 0,
+ SERVER_STATE_LISTENING,
+ SERVER_STATE_ERROR,
+ SERVER_STATE_CLOSING,
+ SERVER_STATE_CLOSED,
+} server_state_t;
+
+typedef enum {
+ SESSION_STATE_PENDING = 0,
+ SESSION_STATE_KEYEX,
+ SESSION_STATE_AUTH,
+ SESSION_STATE_READY,
+ SESSION_STATE_ERROR,
+ SESSION_STATE_CLOSING,
+ SESSION_STATE_CLOSED,
+} session_state_t;
+
+typedef enum {
+ SCHAN_STATE_OPEN = 0,
+ SCHAN_STATE_CLOSING,
+ SCHAN_STATE_CLOSED,
+ SCHAN_STATE_ERROR,
+} schan_state_t;
+
+typedef enum {
+ SCHAN_REQ_NONE = 0,
+ SCHAN_REQ_EXEC,
+ SCHAN_REQ_SUBSYSTEM,
+} schan_req_t;
+
+typedef struct server_write_chunk {
+ B_bytes data;
+ size_t offset;
+ int is_stderr;
+ struct server_write_chunk *next;
+} server_write_chunk_t;
+
+typedef struct ssh_server_channel_ctx {
+ struct ssh_server_channel_ctx *next;
+ ssh_channel channel;
+ struct ssh_channel_callbacks_struct *callbacks;
+ struct ssh_server_session_ctx *session;
+ sshQ_ServerChannel actor;
+ schan_state_t state;
+ int send_eof;
+ int close_requested;
+ int close_sent;
+ int remote_close_seen;
+ int write_wontblock;
+ int eof_sent;
+ int stdout_eof;
+ int stderr_eof;
+ int close_notified;
+ ssh_message pending_req;
+ schan_req_t pending_req_type;
+ $action2 on_data;
+ $action2 on_stderr;
+ $action2 on_close;
+ server_write_chunk_t *write_head;
+ server_write_chunk_t *write_tail;
+} ssh_server_channel_ctx;
+
+typedef struct ssh_server_session_ctx {
+ struct ssh_server_session_ctx *next;
+ sshQ_ServerSession actor;
+ struct ssh_server_ctx *server;
+ ssh_session session;
+ uv_poll_t *poll;
+ int poll_events;
+ uv_timer_t *attach_timer;
+ uv_timer_t *auth_timer;
+ uv_timer_t *keepalive_timer;
+ double auth_timeout;
+ double keepalive_interval;
+ int keepalive_enabled;
+ session_state_t state;
+ int attached;
+ int fd;
+ int owner_wt;
+ int write_ready;
+ int close_notified;
+ int close_finalized;
+ int close_force;
+ uint64_t pending_id;
+ char *close_reason;
+ ssh_message pending_auth;
+ ssh_message pending_channel_open;
+ ssh_server_channel_ctx *channels;
+ ssh_server_channel_ctx *retired_channels;
+ $action2 on_auth;
+ $action on_channel_open;
+ $action3 on_exec;
+ $action3 on_subsystem;
+ $action2 on_close;
+} ssh_server_session_ctx;
+
+typedef struct ssh_server_ctx {
+ GC_hidden_pointer actor;
+ ssh_bind bind;
+ ssh_key hostkey;
+ uv_poll_t *poll;
+ int fd;
+ server_state_t state;
+ int max_sessions;
+ int max_channels_per_session;
+ int listen_notified;
+ int listen_ok;
+ int close_notified;
+ int close_finalized;
+ char *close_reason;
+ ssh_server_session_ctx *sessions;
+ $action2 on_listen;
+ $action2 on_close;
+} ssh_server_ctx;
+
+static void client_drive(ssh_client_ctx *c);
+static void client_update_poll(ssh_client_ctx *c);
+static void client_close_internal(ssh_client_ctx *c, const char *reason, int force_close);
+static void client_finalize(ssh_client_ctx *c);
+static void client_finish_close(ssh_client_ctx *c);
+static void client_maybe_release(ssh_client_ctx *c);
+static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch);
+static int client_needs_write(ssh_client_ctx *c);
+static void client_pump_io(ssh_client_ctx *c);
+
+static void server_accept(ssh_server_ctx *s);
+static void server_close_internal(ssh_server_ctx *s, const char *reason);
+static void server_finalize(ssh_server_ctx *s);
+static void server_remove_session(ssh_server_ctx *s, ssh_server_session_ctx *sess);
+static void server_maybe_release(ssh_server_ctx *s);
+static void session_drive(ssh_server_session_ctx *s);
+static void session_update_poll(ssh_server_session_ctx *s);
+static void session_pump_io(ssh_server_session_ctx *s);
+static void session_close_internal(ssh_server_session_ctx *s, const char *reason, int force_close);
+static void session_finalize(ssh_server_session_ctx *s);
+static void session_finish_close(ssh_server_session_ctx *s);
+static void session_maybe_release(ssh_server_session_ctx *s);
+static void session_fail(ssh_server_session_ctx *s, const char *msg);
+static void session_start_attach_timer(ssh_server_session_ctx *s);
+static void session_start_keyex_timer(ssh_server_session_ctx *s);
+static int session_start_poll(ssh_server_session_ctx *s, char *errmsg, size_t errmsg_len);
+static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch);
+static int session_needs_write(ssh_server_session_ctx *s);
+static void session_poll_cb(uv_poll_t *handle, int status, int events);
+static void server_poll_cb(uv_poll_t *handle, int status, int events);
+static void session_auth_timeout_cb(uv_timer_t *timer);
+static void session_keepalive_cb(uv_timer_t *timer);
+static void client_poll_close_cb(uv_handle_t *handle);
+static void client_timer_close_cb(uv_handle_t *handle);
+static void server_poll_close_cb(uv_handle_t *handle);
+static void session_poll_close_cb(uv_handle_t *handle);
+static void session_timer_close_cb(uv_handle_t *handle);
+
+#define STORE_HIDDEN_PTR(slot, ptr) \
+ ((slot) = (ptr) ? (GC_hidden_pointer)GC_HIDE_POINTER(ptr) : (GC_hidden_pointer)0)
+#define LOAD_HIDDEN_PTR(type, slot) \
+ ((slot) ? (type)GC_REVEAL_POINTER(slot) : NULL)
+
+static sshQ_Client client_actor_ref(const ssh_client_ctx *c) {
+ return c ? LOAD_HIDDEN_PTR(sshQ_Client, c->actor) : NULL;
+}
+
+static sshQ_Channel channel_actor_ref(const ssh_channel_ctx *ch) {
+ return ch ? ch->actor : NULL;
+}
+
+static sshQ_Server server_actor_ref(const ssh_server_ctx *s) {
+ return s ? LOAD_HIDDEN_PTR(sshQ_Server, s->actor) : NULL;
+}
+
+static sshQ_ServerSession session_actor_ref(const ssh_server_session_ctx *s) {
+ return s ? s->actor : NULL;
+}
+
+static sshQ_ServerChannel server_channel_actor_ref(const ssh_server_channel_ctx *ch) {
+ return ch ? ch->actor : NULL;
+}
+
+static void ssh_debug_log(const char *fmt, ...) {
+ if (!ssh_debug_enabled)
+ return;
+ va_list ap;
+ va_start(ap, fmt);
+ vfprintf(stderr, fmt, ap);
+ va_end(ap);
+ fprintf(stderr, "\n");
+ fflush(stderr);
+}
+
+static int parse_libssh_log_level(const char *value) {
+ if (value == NULL || value[0] == '\0')
+ return SSH_LOG_NOLOG;
+ char *endptr = NULL;
+ long lvl = strtol(value, &endptr, 10);
+ if (endptr != value && endptr && *endptr == '\0') {
+ if (lvl < 0)
+ return SSH_LOG_NOLOG;
+ if (lvl > SSH_LOG_TRACE)
+ lvl = SSH_LOG_TRACE;
+ return (int)lvl;
+ }
+ if (strcasecmp(value, "warn") == 0 || strcasecmp(value, "warning") == 0)
+ return SSH_LOG_WARN;
+ if (strcasecmp(value, "info") == 0 || strcasecmp(value, "protocol") == 0)
+ return SSH_LOG_INFO;
+ if (strcasecmp(value, "debug") == 0 || strcasecmp(value, "packet") == 0)
+ return SSH_LOG_DEBUG;
+ if (strcasecmp(value, "trace") == 0 || strcasecmp(value, "functions") == 0)
+ return SSH_LOG_TRACE;
+ if (strcasecmp(value, "none") == 0 || strcasecmp(value, "off") == 0)
+ return SSH_LOG_NOLOG;
+ return SSH_LOG_NOLOG;
+}
+
+static void ssh_log_cb(int priority, const char *function, const char *buffer, void *userdata) {
+ (void)userdata;
+ if (buffer == NULL)
+ return;
+ fprintf(stderr, "libssh[%d] %s: %s\n", priority, function ? function : "", buffer);
+ fflush(stderr);
+}
+
+static void ssh_configure_libssh_logging(void) {
+ if (ssh_libssh_log_level <= SSH_LOG_NOLOG)
+ return;
+ ssh_set_log_callback(ssh_log_cb);
+ ssh_set_log_level(ssh_libssh_log_level);
+}
+
+static int session_apply_poll_events(ssh_session session, int events) {
+ if (session == NULL)
+ return -1;
+ int revents = 0;
+ if (events & UV_READABLE)
+ revents |= POLLIN;
+ if (events & UV_WRITABLE)
+ revents |= POLLOUT;
+#ifdef UV_DISCONNECT
+ if (events & UV_DISCONNECT)
+ revents |= POLLHUP;
+#endif
+#ifdef UV_PRIORITIZED
+ if (events & UV_PRIORITIZED)
+ revents |= POLLPRI;
+#endif
+ if (revents == 0)
+ return 0;
+ if (ssh_session_handle_poll(session, revents) != SSH_OK)
+ return -1;
+ return 0;
+}
+
+static const char *hostkey_state_str(enum ssh_known_hosts_e state) {
+ switch (state) {
+ case SSH_KNOWN_HOSTS_OK:
+ return "ok";
+ case SSH_KNOWN_HOSTS_UNKNOWN:
+ return "unknown";
+ case SSH_KNOWN_HOSTS_NOT_FOUND:
+ return "not_found";
+ case SSH_KNOWN_HOSTS_CHANGED:
+ return "changed";
+ case SSH_KNOWN_HOSTS_OTHER:
+ return "other";
+ case SSH_KNOWN_HOSTS_ERROR:
+ default:
+ return "error";
+ }
+}
+
+static ssh_client_ctx *client_from_actor(sshQ_Client self) {
+ if (self == NULL)
+ return NULL;
+ if (self->_client == NULL)
+ return NULL;
+ unsigned long ptr = fromB_u64(self->_client);
+ if (ptr == 0)
+ return NULL;
+ return (ssh_client_ctx *)ptr;
+}
+
+static ssh_channel_ctx *channel_from_actor(sshQ_Channel channel) {
+ if (channel == NULL)
+ return NULL;
+ if (channel->_channel_id == NULL)
+ return NULL;
+ unsigned long ptr = fromB_u64(channel->_channel_id);
+ if (ptr == 0)
+ return NULL;
+ return (ssh_channel_ctx *)ptr;
+}
+
+static ssh_server_ctx *server_from_actor(sshQ_Server self) {
+ if (self == NULL)
+ return NULL;
+ if (self->_server == NULL)
+ return NULL;
+ unsigned long ptr = fromB_u64(self->_server);
+ if (ptr == 0)
+ return NULL;
+ return (ssh_server_ctx *)ptr;
+}
+
+static ssh_server_session_ctx *session_from_pending_token(ssh_server_ctx *server, B_u64 session_id) {
+ if (server == NULL)
+ return NULL;
+ uint64_t token = fromB_u64(session_id);
+ if (token == 0)
+ return NULL;
+ ssh_server_session_ctx *cur = server->sessions;
+ while (cur != NULL) {
+ if (cur->pending_id == token)
+ return cur;
+ cur = cur->next;
+ }
+ return NULL;
+}
+
+static ssh_server_session_ctx *session_from_actor(sshQ_ServerSession self) {
+ if (self == NULL)
+ return NULL;
+ if (self->_session_id == NULL)
+ return NULL;
+ unsigned long ptr = fromB_u64(self->_session_id);
+ if (ptr == 0)
+ return NULL;
+ return (ssh_server_session_ctx *)ptr;
+}
+
+static ssh_server_channel_ctx *server_channel_from_actor(sshQ_ServerChannel channel) {
+ if (channel == NULL)
+ return NULL;
+ if (channel->_channel_id == NULL)
+ return NULL;
+ unsigned long ptr = fromB_u64(channel->_channel_id);
+ if (ptr == 0)
+ return NULL;
+ return (ssh_server_channel_ctx *)ptr;
+}
+
+static ssh_server_channel_ctx *server_channel_from_ssh(ssh_server_session_ctx *s, ssh_channel chan) {
+ if (s == NULL || chan == NULL)
+ return NULL;
+ ssh_server_channel_ctx *cur = s->channels;
+ while (cur != NULL) {
+ if (cur->channel == chan)
+ return cur;
+ cur = cur->next;
+ }
+ return NULL;
+}
+
+static void close_poll(uv_poll_t **poll, uv_close_cb close_cb) {
+ if (*poll != NULL) {
+ if (uv_is_closing((uv_handle_t *)*poll))
+ return;
+ uv_poll_stop(*poll);
+ uv_close((uv_handle_t *)*poll, close_cb);
+ }
+}
+
+static void stop_timer(uv_timer_t **timer, uv_close_cb close_cb) {
+ if (*timer != NULL) {
+ if (uv_is_closing((uv_handle_t *)*timer))
+ return;
+ uv_timer_stop(*timer);
+ uv_close((uv_handle_t *)*timer, close_cb);
+ }
+}
+
+static int fd_has_data(int fd) {
+ if (fd < 0)
+ return 0;
+ char byte;
+ ssize_t rc;
+ do {
+ rc = recv(fd, &byte, 1, MSG_PEEK | MSG_DONTWAIT);
+ } while (rc < 0 && errno == EINTR);
+ if (rc > 0)
+ return 1;
+ if (rc == 0)
+ return 1;
+ if (errno == EAGAIN || errno == EWOULDBLOCK)
+ return 0;
+ if (errno == ECONNRESET || errno == ECONNABORTED || errno == ENOTCONN)
+ return 1;
+ if (errno == EBADF || errno == ENOTSOCK || errno == EINVAL)
+ return 0;
+ return 1;
+}
+
+static int fd_can_write(int fd) {
+ if (fd < 0)
+ return 0;
+ struct pollfd pfd;
+ pfd.fd = fd;
+ pfd.events = POLLOUT;
+ pfd.revents = 0;
+ int rc;
+ do {
+ rc = poll(&pfd, 1, 0);
+ } while (rc < 0 && errno == EINTR);
+ if (rc <= 0)
+ return 0;
+ if (pfd.revents & POLLOUT)
+ return 1;
+ if (pfd.revents & POLLNVAL)
+ return 0;
+ return 0;
+}
+
+static int fd_set_nonblocking(int fd) {
+ if (fd < 0)
+ return -1;
+ int flags = fcntl(fd, F_GETFL, 0);
+ if (flags < 0)
+ return -1;
+ if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0)
+ return -1;
+ return 0;
+}
+
+static void format_session_error(ssh_session session, const char *prefix,
+ char *buf, size_t buflen) {
+ const char *err = NULL;
+ if (session != NULL)
+ err = ssh_get_error(session);
+ if (err != NULL && err[0] != '\0')
+ snprintf(buf, buflen, "%s: %s", prefix, err);
+ else
+ snprintf(buf, buflen, "%s", prefix);
+}
+
+static int session_has_pending_write(ssh_session session) {
+ if (session == NULL)
+ return 0;
+ int pending = ssh_get_status(session) | ssh_get_poll_flags(session);
+ return (pending & SSH_WRITE_PENDING) != 0;
+}
+
+static uint64_t next_pending_session_id = 1;
+
+static uint64_t alloc_pending_session_id(void) {
+ return __atomic_fetch_add(&next_pending_session_id, 1, __ATOMIC_RELAXED);
+}
+
+static void client_retire_channel(ssh_client_ctx *c, ssh_channel_ctx *ch) {
+ if (c == NULL || ch == NULL)
+ return;
+ ch->next = c->retired_channels;
+ c->retired_channels = ch;
+}
+
+static void client_free_retired_channels(ssh_client_ctx *c) {
+ if (c == NULL)
+ return;
+ ssh_channel_ctx *ch = c->retired_channels;
+ c->retired_channels = NULL;
+ while (ch != NULL) {
+ ssh_channel_ctx *next = ch->next;
+ ch->next = NULL;
+ acton_free(ch);
+ ch = next;
+ }
+}
+
+static void session_retire_channel(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) {
+ if (s == NULL || ch == NULL)
+ return;
+ ch->next = s->retired_channels;
+ s->retired_channels = ch;
+}
+
+static void session_free_retired_channels(ssh_server_session_ctx *s) {
+ if (s == NULL)
+ return;
+ ssh_server_channel_ctx *ch = s->retired_channels;
+ s->retired_channels = NULL;
+ while (ch != NULL) {
+ ssh_server_channel_ctx *next = ch->next;
+ ch->next = NULL;
+ acton_free(ch);
+ ch = next;
+ }
+}
+
+static void client_poll_close_cb(uv_handle_t *handle) {
+ ssh_client_ctx *c = (ssh_client_ctx *)handle->data;
+ if (c == NULL) {
+ acton_free(handle);
+ return;
+ }
+ if (c->poll == (uv_poll_t *)handle) {
+ c->poll = NULL;
+ c->poll_events = 0;
+ }
+ if (c->state == CLIENT_STATE_CLOSING) {
+ client_finalize(c);
+ }
+ acton_free(handle);
+}
+
+static void client_timer_close_cb(uv_handle_t *handle) {
+ ssh_client_ctx *c = (ssh_client_ctx *)handle->data;
+ if (c == NULL) {
+ acton_free(handle);
+ return;
+ }
+ if ((uv_timer_t *)handle == c->connect_timer)
+ c->connect_timer = NULL;
+ if ((uv_timer_t *)handle == c->auth_timer)
+ c->auth_timer = NULL;
+ if ((uv_timer_t *)handle == c->keepalive_timer)
+ c->keepalive_timer = NULL;
+ client_maybe_release(c);
+ acton_free(handle);
+}
+
+static void server_poll_close_cb(uv_handle_t *handle) {
+ ssh_server_ctx *s = (ssh_server_ctx *)handle->data;
+ if (s == NULL) {
+ acton_free(handle);
+ return;
+ }
+ if (s->poll == (uv_poll_t *)handle)
+ s->poll = NULL;
+ if (s->state == SERVER_STATE_CLOSING) {
+ server_finalize(s);
+ }
+ acton_free(handle);
+}
+
+static void session_poll_close_cb(uv_handle_t *handle) {
+ ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data;
+ if (s == NULL) {
+ acton_free(handle);
+ return;
+ }
+ if (s->poll == (uv_poll_t *)handle) {
+ s->poll = NULL;
+ s->poll_events = 0;
+ }
+ if (s->state == SESSION_STATE_CLOSING) {
+ session_finalize(s);
+ }
+ acton_free(handle);
+}
+
+static void session_timer_close_cb(uv_handle_t *handle) {
+ ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data;
+ if (s == NULL) {
+ acton_free(handle);
+ return;
+ }
+ if ((uv_timer_t *)handle == s->attach_timer)
+ s->attach_timer = NULL;
+ if ((uv_timer_t *)handle == s->auth_timer)
+ s->auth_timer = NULL;
+ if ((uv_timer_t *)handle == s->keepalive_timer)
+ s->keepalive_timer = NULL;
+ session_maybe_release(s);
+ acton_free(handle);
+}
+
+static void client_maybe_release(ssh_client_ctx *c) {
+ if (c == NULL || !c->close_finalized)
+ return;
+ if (c->poll != NULL || c->connect_timer != NULL ||
+ c->auth_timer != NULL || c->keepalive_timer != NULL)
+ return;
+ if (c->close_reason != NULL) {
+ acton_free(c->close_reason);
+ c->close_reason = NULL;
+ }
+ acton_free(c);
+}
+
+static void server_maybe_release(ssh_server_ctx *s) {
+ if (s == NULL || !s->close_finalized)
+ return;
+ if (s->poll != NULL || s->sessions != NULL)
+ return;
+ if (s->close_reason != NULL) {
+ acton_free(s->close_reason);
+ s->close_reason = NULL;
+ }
+ acton_free(s);
+}
+
+static void session_maybe_release(ssh_server_session_ctx *s) {
+ if (s == NULL || !s->close_finalized)
+ return;
+ if (s->poll != NULL || s->attach_timer != NULL ||
+ s->auth_timer != NULL || s->keepalive_timer != NULL)
+ return;
+ if (s->close_reason != NULL) {
+ acton_free(s->close_reason);
+ s->close_reason = NULL;
+ }
+ acton_free(s);
+}
+
+static int server_session_count(ssh_server_ctx *s) {
+ int count = 0;
+ if (s == NULL)
+ return 0;
+ ssh_server_session_ctx *sess = s->sessions;
+ while (sess != NULL) {
+ count++;
+ sess = sess->next;
+ }
+ return count;
+}
+
+static int session_channel_count(ssh_server_session_ctx *s) {
+ int count = 0;
+ if (s == NULL)
+ return 0;
+ ssh_server_channel_ctx *ch = s->channels;
+ while (ch != NULL) {
+ count++;
+ ch = ch->next;
+ }
+ return count;
+}
+
+static int server_session_limit_reached(ssh_server_ctx *s) {
+ if (s == NULL || s->max_sessions <= 0)
+ return 0;
+ return server_session_count(s) >= s->max_sessions;
+}
+
+static int session_channel_limit_reached(ssh_server_session_ctx *s) {
+ if (s == NULL || s->server == NULL || s->server->max_channels_per_session <= 0)
+ return 0;
+ return session_channel_count(s) >= s->server->max_channels_per_session;
+}
+
+static void client_notify_connect(ssh_client_ctx *c, const char *err) {
+ if (c == NULL)
+ return;
+ if (c->connect_notified)
+ return;
+ sshQ_Client actor = client_actor_ref(c);
+ if (c->on_connect) {
+ $action2 f = ($action2)c->on_connect;
+ f->$class->__asyn__(f, actor, err ? to$str((char *)err) : B_None);
+ }
+ c->connect_notified = 1;
+ if (err == NULL)
+ c->connected_ok = 1;
+}
+
+static void client_notify_close(ssh_client_ctx *c, const char *reason) {
+ if (c->close_notified)
+ return;
+ if (!c->connected_ok)
+ return;
+ sshQ_Client actor = client_actor_ref(c);
+ if (c->on_close) {
+ $action2 f = ($action2)c->on_close;
+ f->$class->__asyn__(f, actor, to$str((char *)reason));
+ }
+ c->close_notified = 1;
+}
+
+static void client_fail(ssh_client_ctx *c, const char *msg) {
+ if (c == NULL || c->state == CLIENT_STATE_CLOSED || c->state == CLIENT_STATE_ERROR)
+ return;
+ c->state = CLIENT_STATE_ERROR;
+ if (!c->connected_ok)
+ client_notify_connect(c, msg);
+ client_close_internal(c, msg, 1);
+}
+
+static void channel_notify_open(ssh_channel_ctx *ch, const char *err) {
+ if (ch->open_notified)
+ return;
+ sshQ_Channel actor = channel_actor_ref(ch);
+ if (ch->on_open) {
+ $action2 f = ($action2)ch->on_open;
+ f->$class->__asyn__(f, actor, err ? to$str((char *)err) : B_None);
+ }
+ ch->open_notified = 1;
+ if (err == NULL)
+ ch->open_succeeded = 1;
+}
+
+static void channel_notify_close(ssh_channel_ctx *ch, const char *reason) {
+ if (ch->close_notified)
+ return;
+ if (!ch->open_succeeded) {
+ ch->close_notified = 1;
+ return;
+ }
+ sshQ_Channel actor = channel_actor_ref(ch);
+ if (ch->on_close) {
+ $action2 f = ($action2)ch->on_close;
+ f->$class->__asyn__(f, actor, to$str((char *)reason));
+ }
+ ch->close_notified = 1;
+}
+
+static void channel_notify_error(ssh_channel_ctx *ch, const char *msg) {
+ if (ch->open_succeeded)
+ channel_notify_close(ch, msg);
+ else
+ channel_notify_open(ch, msg);
+}
+
+static void channel_notify_exit(ssh_channel_ctx *ch, int exit_status, B_str signal) {
+ if (ch->exit_sent)
+ return;
+ sshQ_Channel actor = channel_actor_ref(ch);
+ if (ch->on_exit) {
+ $action3 f = ($action3)ch->on_exit;
+ f->$class->__asyn__(f, actor, toB_int(exit_status), signal);
+ }
+ ch->exit_sent = 1;
+}
+
+static int client_channel_data_cb(ssh_session session, ssh_channel channel, void *data,
+ uint32_t len, int is_stderr, void *userdata) {
+ ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata;
+ (void)session;
+ (void)channel;
+ if (ch == NULL || ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR)
+ return 0;
+ if (len == 0)
+ return 0;
+ B_bytes out = to$bytesD_len((char *)data, (size_t)len);
+ sshQ_Channel actor = channel_actor_ref(ch);
+ if (is_stderr) {
+ if (ch->on_stderr) {
+ $action2 f = ($action2)ch->on_stderr;
+ f->$class->__asyn__(f, actor, out);
+ }
+ } else {
+ if (ch->on_stdout) {
+ $action2 f = ($action2)ch->on_stdout;
+ f->$class->__asyn__(f, actor, out);
+ }
+ }
+ return (int)len;
+}
+
+static void client_channel_eof_cb(ssh_session session, ssh_channel channel, void *userdata) {
+ ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata;
+ (void)session;
+ (void)channel;
+ if (ch == NULL)
+ return;
+ sshQ_Channel actor = channel_actor_ref(ch);
+ if (!ch->stdout_eof && ch->on_stdout) {
+ $action2 f = ($action2)ch->on_stdout;
+ f->$class->__asyn__(f, actor, B_None);
+ ch->stdout_eof = 1;
+ }
+ if (!ch->stderr_eof && ch->on_stderr) {
+ $action2 f = ($action2)ch->on_stderr;
+ f->$class->__asyn__(f, actor, B_None);
+ ch->stderr_eof = 1;
+ }
+}
+
+static void client_channel_close_cb(ssh_session session, ssh_channel channel, void *userdata) {
+ ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata;
+ (void)session;
+ (void)channel;
+ if (ch == NULL)
+ return;
+ ch->remote_close_seen = 1;
+}
+
+static int client_channel_write_wontblock_cb(ssh_session session, ssh_channel channel,
+ uint32_t bytes, void *userdata) {
+ ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata;
+ (void)session;
+ (void)channel;
+ if (ch == NULL || ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR)
+ return 0;
+ ch->write_wontblock = bytes > 0 ? 1 : 0;
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client channel write_wontblock: bytes=%u ch=%p", bytes, (void *)ch);
+ }
+ return 0;
+}
+
+static int client_channel_setup_callbacks(ssh_channel_ctx *ch) {
+ if (ch == NULL || ch->channel == NULL)
+ return SSH_ERROR;
+ if (ch->callbacks != NULL)
+ return SSH_OK;
+ struct ssh_channel_callbacks_struct *cb = acton_calloc(1, sizeof(*cb));
+ ssh_callbacks_init(cb);
+ cb->userdata = ch;
+ cb->channel_data_function = client_channel_data_cb;
+ cb->channel_eof_function = client_channel_eof_cb;
+ cb->channel_close_function = client_channel_close_cb;
+ cb->channel_write_wontblock_function = client_channel_write_wontblock_cb;
+ if (ssh_add_channel_callbacks(ch->channel, cb) != SSH_OK) {
+ acton_free(cb);
+ return SSH_ERROR;
+ }
+ ch->callbacks = cb;
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client channel callbacks set ch=%p", (void *)ch);
+ }
+ return SSH_OK;
+}
+
+static void channel_notify_eof(ssh_channel_ctx *ch) {
+ if (ch->channel == NULL)
+ return;
+ if (ssh_channel_is_eof(ch->channel)) {
+ sshQ_Channel actor = channel_actor_ref(ch);
+ if (!ch->stdout_eof && ch->on_stdout) {
+ $action2 f = ($action2)ch->on_stdout;
+ f->$class->__asyn__(f, actor, B_None);
+ ch->stdout_eof = 1;
+ }
+ if (!ch->stderr_eof && ch->on_stderr) {
+ $action2 f = ($action2)ch->on_stderr;
+ f->$class->__asyn__(f, actor, B_None);
+ ch->stderr_eof = 1;
+ }
+ }
+}
+
+static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) {
+ int exit_status = -1;
+ B_str exit_signal = B_None;
+ sshQ_Channel actor = channel_actor_ref(ch);
+
+ while (ch->write_head != NULL) {
+ write_chunk_t *chunk = ch->write_head;
+ ch->write_head = chunk->next;
+ acton_free(chunk);
+ }
+ ch->write_tail = NULL;
+ if (ch->pending_req != CHAN_REQ_NONE) {
+ ch->pending_req = CHAN_REQ_NONE;
+ channel_notify_error(ch, "SSH channel request failed: channel closed");
+ }
+ if (ch->channel != NULL) {
+ if (ssh_channel_is_closed(ch->channel)) {
+ uint32_t exit_code = 0;
+ char *signal = NULL;
+ int core_dumped = 0;
+ int rc = ssh_channel_get_exit_state(ch->channel, &exit_code, &signal, &core_dumped);
+ if (rc == SSH_OK) {
+ exit_status = (int)exit_code;
+ if (signal != NULL)
+ exit_signal = to$str(signal);
+ }
+ if (signal)
+ ssh_string_free_char(signal);
+ (void)core_dumped;
+ }
+ if (ch->callbacks) {
+ ssh_remove_channel_callbacks(ch->channel, ch->callbacks);
+ acton_free(ch->callbacks);
+ ch->callbacks = NULL;
+ }
+ ssh_channel_free(ch->channel);
+ ch->channel = NULL;
+ }
+ ch->state = CHAN_STATE_CLOSED;
+ if (actor)
+ actor->_channel_id = toB_u64(0);
+
+ channel_notify_exit(ch, exit_status, exit_signal);
+ if (!ch->stdout_eof && ch->on_stdout) {
+ $action2 f = ($action2)ch->on_stdout;
+ f->$class->__asyn__(f, actor, B_None);
+ ch->stdout_eof = 1;
+ }
+ if (!ch->stderr_eof && ch->on_stderr) {
+ $action2 f = ($action2)ch->on_stderr;
+ f->$class->__asyn__(f, actor, B_None);
+ ch->stderr_eof = 1;
+ }
+ channel_notify_close(ch, "closed");
+ ch->actor = NULL;
+ (void)c;
+}
+
+static void channel_fail(ssh_client_ctx *c, ssh_channel_ctx *ch, const char *msg) {
+ if (ch->state == CHAN_STATE_ERROR || ch->state == CHAN_STATE_CLOSED)
+ return;
+ ch->state = CHAN_STATE_ERROR;
+ channel_notify_error(ch, msg);
+ if (ch->channel) {
+ ssh_channel_close(ch->channel);
+ }
+ channel_finalize(c, ch);
+}
+
+static void channel_queue_write(ssh_channel_ctx *ch, B_bytes data) {
+ write_chunk_t *chunk = acton_calloc(1, sizeof(write_chunk_t));
+ chunk->data = data;
+ chunk->offset = 0;
+ chunk->next = NULL;
+ if (ch->write_tail) {
+ ch->write_tail->next = chunk;
+ } else {
+ ch->write_head = chunk;
+ }
+ ch->write_tail = chunk;
+}
+
+static void channel_try_write(ssh_client_ctx *c, ssh_channel_ctx *ch) {
+ while (ch->write_head != NULL && ch->write_head->data->nbytes == ch->write_head->offset) {
+ write_chunk_t *chunk = ch->write_head;
+ ch->write_head = chunk->next;
+ acton_free(chunk);
+ if (ch->write_head == NULL)
+ ch->write_tail = NULL;
+ }
+
+ if (ch->write_head == NULL || !ch->write_wontblock || session_has_pending_write(c->session))
+ return;
+
+ write_chunk_t *chunk = ch->write_head;
+ size_t remaining = chunk->data->nbytes - chunk->offset;
+ ch->write_wontblock = 0;
+ int rc = ssh_channel_write(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining);
+ if (rc > 0) {
+ chunk->offset += (size_t)rc;
+ if (chunk->offset >= chunk->data->nbytes) {
+ ch->write_head = chunk->next;
+ acton_free(chunk);
+ if (ch->write_head == NULL)
+ ch->write_tail = NULL;
+ }
+ } else if (rc == 0 || rc == SSH_AGAIN) {
+ c->write_ready = 0;
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH channel write error: %s", ssh_get_error(c->session));
+ channel_fail(c, ch, errmsg);
+ return;
+ }
+}
+
+static int channel_read_stream(ssh_client_ctx *c, ssh_channel_ctx *ch, int is_stderr) {
+ char buf[SSH_READ_BUFSIZE];
+ int read_any = 0;
+ for (;;) {
+ int n = ssh_channel_read_buffered(ch->channel, buf, sizeof(buf), is_stderr);
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client channel read: rc=%d stderr=%d", n, is_stderr);
+ }
+ if (n > 0) {
+ read_any = 1;
+ B_bytes out = to$bytesD_len(buf, n);
+ sshQ_Channel actor = channel_actor_ref(ch);
+ if (is_stderr) {
+ if (ch->on_stderr) {
+ $action2 f = ($action2)ch->on_stderr;
+ f->$class->__asyn__(f, actor, out);
+ }
+ } else {
+ if (ch->on_stdout) {
+ $action2 f = ($action2)ch->on_stdout;
+ f->$class->__asyn__(f, actor, out);
+ }
+ }
+ continue;
+ }
+ if (n == 0 || n == SSH_AGAIN) {
+ break;
+ }
+ if (n == SSH_EOF) {
+ break;
+ }
+ if (n == SSH_ERROR) {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH channel read error: %s", ssh_get_error(c->session));
+ channel_fail(c, ch, errmsg);
+ break;
+ }
+ }
+ return read_any;
+}
+
+static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) {
+ if (ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR)
+ return;
+
+ if (ch->state == CHAN_STATE_INIT) {
+ ch->channel = ssh_channel_new(c->session);
+ if (ch->channel == NULL) {
+ channel_fail(c, ch, "Failed to create SSH channel");
+ return;
+ }
+ if (client_channel_setup_callbacks(ch) != SSH_OK) {
+ channel_fail(c, ch, "Failed to set SSH channel callbacks");
+ return;
+ }
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client channel new ch=%p callbacks=%p", (void *)ch, (void *)ch->callbacks);
+ }
+ ch->state = CHAN_STATE_OPENING;
+ }
+
+ if (ch->state == CHAN_STATE_OPENING) {
+ int rc = ssh_channel_open_session(ch->channel);
+ if (rc == SSH_OK) {
+ ch->state = CHAN_STATE_OPEN;
+ channel_notify_open(ch, NULL);
+ } else if (rc == SSH_AGAIN) {
+ c->write_ready = 0;
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "Failed to open SSH channel: %s", ssh_get_error(c->session));
+ channel_fail(c, ch, errmsg);
+ return;
+ }
+ }
+
+ if (ch->state == CHAN_STATE_OPEN || ch->state == CHAN_STATE_RUNNING) {
+ if (ch->pty_pending && !ch->pty_done) {
+ const char *term = ch->term ? (const char *)fromB_str(ch->term) : "xterm-256color";
+ int rc = ssh_channel_request_pty_size(ch->channel, term, ch->cols, ch->rows);
+ if (rc == SSH_OK) {
+ ch->pty_done = 1;
+ } else if (rc == SSH_AGAIN) {
+ c->write_ready = 0;
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "Failed to request PTY: %s", ssh_get_error(c->session));
+ channel_fail(c, ch, errmsg);
+ return;
+ }
+ }
+
+ if (ch->pending_req != CHAN_REQ_NONE) {
+ int rc = SSH_ERROR;
+ if (ch->pending_req == CHAN_REQ_SHELL) {
+ rc = ssh_channel_request_shell(ch->channel);
+ } else if (ch->pending_req == CHAN_REQ_EXEC) {
+ rc = ssh_channel_request_exec(ch->channel, (const char *)fromB_str(ch->exec_cmd));
+ } else if (ch->pending_req == CHAN_REQ_SUBSYSTEM) {
+ rc = ssh_channel_request_subsystem(ch->channel, (const char *)fromB_str(ch->subsystem));
+ }
+
+ if (rc == SSH_OK) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client channel request ok");
+ }
+ ch->pending_req = CHAN_REQ_NONE;
+ ch->state = CHAN_STATE_RUNNING;
+ } else if (rc == SSH_AGAIN) {
+ c->write_ready = 0;
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client channel request again");
+ }
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH channel request failed: %s", ssh_get_error(c->session));
+ channel_fail(c, ch, errmsg);
+ return;
+ }
+ }
+ }
+
+ if (ch->state == CHAN_STATE_OPEN || ch->state == CHAN_STATE_RUNNING) {
+ if (ch->write_head != NULL) {
+ channel_try_write(c, ch);
+ if (ch->state == CHAN_STATE_ERROR)
+ return;
+ }
+
+ if (ch->send_eof && !ch->eof_sent && ch->write_head == NULL &&
+ !session_has_pending_write(c->session)) {
+ int rc = ssh_channel_send_eof(ch->channel);
+ if (rc == SSH_OK) {
+ ch->eof_sent = 1;
+ } else if (rc == SSH_AGAIN) {
+ c->write_ready = 0;
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "Failed to send EOF: %s", ssh_get_error(c->session));
+ channel_fail(c, ch, errmsg);
+ return;
+ }
+ }
+
+ if (ch->close_requested && !ch->close_sent && ch->write_head == NULL &&
+ (!ch->send_eof || ch->eof_sent) &&
+ !session_has_pending_write(c->session)) {
+ int rc = ssh_channel_close(ch->channel);
+ if (rc == SSH_OK) {
+ ch->close_sent = 1;
+ ch->state = CHAN_STATE_CLOSING;
+ } else if (rc == SSH_AGAIN) {
+ c->write_ready = 0;
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "Failed to close channel: %s", ssh_get_error(c->session));
+ channel_fail(c, ch, errmsg);
+ return;
+ }
+ }
+
+ if (ch->callbacks == NULL) {
+ for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) {
+ int did = 0;
+ did |= channel_read_stream(c, ch, 0);
+ did |= channel_read_stream(c, ch, 1);
+ if (!did)
+ break;
+ }
+ }
+ }
+
+ if (ch->channel != NULL && ch->remote_close_seen &&
+ ssh_channel_is_closed(ch->channel)) {
+ channel_finalize(c, ch);
+ }
+}
+
+static void client_drive_channels(ssh_client_ctx *c) {
+ ssh_channel_ctx *prev = NULL;
+ ssh_channel_ctx *ch = c->channels;
+ while (ch != NULL) {
+ ssh_channel_ctx *next = ch->next;
+ channel_drive(c, ch);
+ if (ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR) {
+ if (prev != NULL) {
+ prev->next = next;
+ } else {
+ c->channels = next;
+ }
+ client_retire_channel(c, ch);
+ } else {
+ prev = ch;
+ }
+ ch = next;
+ }
+}
+
+static int client_get_hostkey_info(ssh_client_ctx *c, B_str *key_type_out, B_str *fingerprint_out) {
+ ssh_key key = NULL;
+ unsigned char *hash = NULL;
+ size_t hash_len = 0;
+ char *fingerprint = NULL;
+
+ int rc = ssh_get_server_publickey(c->session, &key);
+ if (rc != SSH_OK || key == NULL)
+ return -1;
+
+ enum ssh_keytypes_e key_type = ssh_key_type(key);
+ const char *key_type_str = ssh_key_type_to_char(key_type);
+
+ rc = ssh_get_publickey_hash(key, SSH_PUBLICKEY_HASH_SHA256, &hash, &hash_len);
+ if (rc != SSH_OK) {
+ ssh_key_free(key);
+ return -1;
+ }
+
+ fingerprint = ssh_get_fingerprint_hash(SSH_PUBLICKEY_HASH_SHA256, hash, hash_len);
+
+ if (key_type_out)
+ *key_type_out = to$str((char *)(key_type_str ? key_type_str : ""));
+ if (fingerprint_out)
+ *fingerprint_out = to$str((char *)(fingerprint ? fingerprint : ""));
+
+ ssh_clean_pubkey_hash(&hash);
+ ssh_key_free(key);
+ if (fingerprint)
+ ssh_string_free_char(fingerprint);
+
+ return 0;
+}
+
+static int client_check_hostkey(ssh_client_ctx *c) {
+ enum ssh_known_hosts_e state = SSH_KNOWN_HOSTS_UNKNOWN;
+ int use_known_hosts = 0;
+ sshQ_Client actor = client_actor_ref(c);
+
+ if (actor != NULL && actor->known_hosts != B_None)
+ use_known_hosts = 1;
+
+ if (use_known_hosts) {
+ state = ssh_session_is_known_server(c->session);
+ if (state == SSH_KNOWN_HOSTS_OK)
+ return 0;
+
+ if (state == SSH_KNOWN_HOSTS_ERROR) {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "Host key check error: %s", ssh_get_error(c->session));
+ client_fail(c, errmsg);
+ return -1;
+ }
+ }
+
+ c->hostkey_state = state;
+
+ if (!c->on_hostkey) {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "Host key not accepted (%s)", hostkey_state_str(state));
+ client_fail(c, errmsg);
+ return -1;
+ }
+
+ B_str key_type = to$str((char *)"");
+ B_str fingerprint = to$str((char *)"");
+ if (client_get_hostkey_info(c, &key_type, &fingerprint) != 0) {
+ key_type = to$str((char *)"");
+ fingerprint = to$str((char *)"");
+ }
+
+ sshQ_HostKeyInfo info = sshQ_HostKeyInfoG_new(key_type, fingerprint);
+ $action3 f = ($action3)c->on_hostkey;
+ f->$class->__asyn__(f, actor, to$str((char *)hostkey_state_str(state)), info);
+ return 1;
+}
+
+static void connect_timeout_cb(uv_timer_t *timer) {
+ ssh_client_ctx *c = (ssh_client_ctx *)timer->data;
+ if (c == NULL)
+ return;
+ if (c->state == CLIENT_STATE_CONNECTING || c->state == CLIENT_STATE_HOSTKEY || c->state == CLIENT_STATE_HOSTKEY_WAIT) {
+ client_fail(c, "SSH connect timeout");
+ }
+}
+
+static void auth_timeout_cb(uv_timer_t *timer) {
+ ssh_client_ctx *c = (ssh_client_ctx *)timer->data;
+ if (c == NULL)
+ return;
+ if (c->state == CLIENT_STATE_AUTH) {
+ client_fail(c, "SSH authentication timeout");
+ }
+}
+
+static void keepalive_cb(uv_timer_t *timer) {
+ ssh_client_ctx *c = (ssh_client_ctx *)timer->data;
+ if (c == NULL)
+ return;
+ if (c->state != CLIENT_STATE_READY || c->session == NULL)
+ return;
+ int rc = ssh_send_ignore(c->session, "keepalive");
+ if (rc == SSH_AGAIN) {
+ client_update_poll(c);
+ return;
+ }
+ if (rc != SSH_OK) {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH keepalive failed: %s", ssh_get_error(c->session));
+ client_fail(c, errmsg);
+ return;
+ }
+ client_update_poll(c);
+}
+
+static void client_start_connect_timer(ssh_client_ctx *c) {
+ if (c->connect_timeout <= 0.0 || c->connect_timer != NULL)
+ return;
+ c->connect_timer = acton_calloc(1, sizeof(uv_timer_t));
+ c->connect_timer->data = c;
+ uv_timer_init(get_uv_loop(), c->connect_timer);
+ uv_timer_start(c->connect_timer, connect_timeout_cb, (uint64_t)(c->connect_timeout * 1000), 0);
+}
+
+static void client_start_auth_timer(ssh_client_ctx *c) {
+ if (c->auth_timeout <= 0.0 || c->auth_timer != NULL)
+ return;
+ c->auth_timer = acton_calloc(1, sizeof(uv_timer_t));
+ c->auth_timer->data = c;
+ uv_timer_init(get_uv_loop(), c->auth_timer);
+ uv_timer_start(c->auth_timer, auth_timeout_cb, (uint64_t)(c->auth_timeout * 1000), 0);
+}
+
+static void client_start_keepalive(ssh_client_ctx *c) {
+ if (!c->keepalive_enabled || c->keepalive_interval <= 0.0 || c->keepalive_timer != NULL)
+ return;
+ c->keepalive_timer = acton_calloc(1, sizeof(uv_timer_t));
+ c->keepalive_timer->data = c;
+ uv_timer_init(get_uv_loop(), c->keepalive_timer);
+ uint64_t interval_ms = (uint64_t)(c->keepalive_interval * 1000);
+ uv_timer_start(c->keepalive_timer, keepalive_cb, interval_ms, interval_ms);
+}
+
+static void poll_cb(uv_poll_t *handle, int status, int events) {
+ ssh_client_ctx *c = (ssh_client_ctx *)handle->data;
+ if (c == NULL)
+ return;
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client poll: status=%d events=0x%x state=%d", status, events, c->state);
+ }
+ if (status < 0) {
+ char errmsg[256] = {0};
+ uv_strerror_r(status, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg));
+ client_fail(c, errmsg);
+ return;
+ }
+ int libssh_events = 0;
+ if ((events & UV_READABLE) && fd_has_data(c->fd)) {
+ ssh_set_fd_toread(c->session);
+ libssh_events |= UV_READABLE;
+ }
+#ifdef UV_DISCONNECT
+ if (events & UV_DISCONNECT) {
+ ssh_set_fd_toread(c->session);
+ libssh_events |= UV_DISCONNECT;
+ }
+#endif
+ if ((events & UV_WRITABLE) && fd_can_write(c->fd)) {
+ c->write_ready = 1;
+ ssh_set_fd_towrite(c->session);
+ libssh_events |= UV_WRITABLE;
+ }
+ if (session_apply_poll_events(c->session, libssh_events) != 0) {
+ char errmsg[256] = {0};
+ format_session_error(c->session, "SSH poll callback error", errmsg, sizeof(errmsg));
+ client_fail(c, errmsg);
+ return;
+ }
+ client_drive(c);
+ client_pump_io(c);
+ c->write_ready = 0;
+}
+
+static void client_update_poll(ssh_client_ctx *c) {
+ if (c->poll == NULL || c->session == NULL)
+ return;
+ if (c->state == CLIENT_STATE_CLOSED)
+ return;
+ if (uv_is_closing((uv_handle_t *)c->poll))
+ return;
+ int status = ssh_get_status(c->session);
+ if (status & SSH_CLOSED_ERROR) {
+ client_fail(c, "SSH session closed with error");
+ return;
+ }
+ if (status & SSH_CLOSED) {
+ client_close_internal(c, "SSH session closed", 1);
+ return;
+ }
+ int flags = ssh_get_poll_flags(c->session);
+ int pending = flags | status;
+ int events = UV_READABLE;
+#ifdef UV_DISCONNECT
+ events |= UV_DISCONNECT;
+#endif
+ if (pending & SSH_WRITE_PENDING)
+ events |= UV_WRITABLE;
+ if ((events & UV_WRITABLE) == 0 && client_needs_write(c))
+ events |= UV_WRITABLE;
+ if (ssh_debug_enabled && (events != c->poll_events || (pending & SSH_WRITE_PENDING))) {
+ ssh_debug_log("client update poll: status=0x%x flags=0x%x pending=0x%x events=0x%x state=%d",
+ status, flags, pending, events, c->state);
+ }
+ if (events != c->poll_events) {
+ int uv_rc = uv_poll_start(c->poll, events, poll_cb);
+ if (uv_rc != 0) {
+ char errmsg[256] = {0};
+ uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg));
+ client_fail(c, errmsg);
+ return;
+ }
+ c->poll_events = events;
+ }
+}
+
+static void client_pump_io(ssh_client_ctx *c) {
+ if (c == NULL || c->session == NULL)
+ return;
+ int i;
+ for (i = 0; i < SSH_IO_PUMP_LIMIT; i++) {
+ int did = 0;
+ if (c->session == NULL)
+ return;
+ int has_data = fd_has_data(c->fd);
+ if (has_data) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client pump: readable");
+ }
+ ssh_set_fd_toread(c->session);
+ if (session_apply_poll_events(c->session, UV_READABLE) != 0) {
+ char errmsg[256] = {0};
+ format_session_error(c->session, "SSH poll callback error", errmsg, sizeof(errmsg));
+ client_fail(c, errmsg);
+ return;
+ }
+ client_drive(c);
+ did = 1;
+ }
+ if (c->session == NULL)
+ return;
+ if (!did) {
+ int status = ssh_get_status(c->session);
+ if (status & SSH_READ_PENDING) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client pump: buffered read pending=0x%x", status);
+ }
+ client_drive(c);
+ /* Avoid spinning when only buffered data remains. */
+ break;
+ }
+ }
+ if (!did)
+ break;
+ }
+ if (ssh_debug_enabled && i >= SSH_IO_PUMP_LIMIT) {
+ int status = ssh_get_status(c->session);
+ int flags = ssh_get_poll_flags(c->session);
+ ssh_debug_log("client pump: hit limit status=0x%x flags=0x%x", status, flags);
+ }
+}
+
+static int client_needs_write(ssh_client_ctx *c) {
+ if (c == NULL)
+ return 0;
+ ssh_channel_ctx *ch = c->channels;
+ while (ch != NULL) {
+ if (ch->pending_req != CHAN_REQ_NONE || ch->write_head != NULL)
+ return 1;
+ if (ch->send_eof && !ch->eof_sent)
+ return 1;
+ if (ch->close_requested && !ch->close_sent)
+ return 1;
+ ch = ch->next;
+ }
+ return 0;
+}
+
+static void client_on_ready(ssh_client_ctx *c) {
+ c->state = CLIENT_STATE_READY;
+ stop_timer(&c->connect_timer, client_timer_close_cb);
+ stop_timer(&c->auth_timer, client_timer_close_cb);
+ client_notify_connect(c, NULL);
+ client_start_keepalive(c);
+ client_drive_channels(c);
+ client_update_poll(c);
+}
+
+static void client_drive(ssh_client_ctx *c) {
+ if (c == NULL)
+ return;
+ if (c->state == CLIENT_STATE_ERROR || c->state == CLIENT_STATE_CLOSED)
+ return;
+ if (c->state == CLIENT_STATE_CLOSING) {
+ client_drive_channels(c);
+ client_finish_close(c);
+ return;
+ }
+
+ int spin = 0;
+ while (1) {
+ if (c->state == CLIENT_STATE_CONNECTING) {
+ int rc = ssh_connect(c->session);
+ if (ssh_debug_enabled) {
+ log_debug("ssh_connect rc=%d state=%d", rc, c->state);
+ }
+ if (rc == SSH_OK) {
+ stop_timer(&c->connect_timer, client_timer_close_cb);
+ if (fd_set_nonblocking(c->fd) != 0) {
+ client_fail(c, "Failed to restore SSH session fd nonblocking");
+ return;
+ }
+ c->state = CLIENT_STATE_HOSTKEY;
+ continue;
+ } else if (rc == SSH_AGAIN) {
+ int status = ssh_get_status(c->session);
+ if (status & SSH_WRITE_PENDING)
+ c->write_ready = 0;
+ if ((status & SSH_READ_PENDING) && spin++ < SSH_IO_PUMP_LIMIT) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client drive: buffered read pending during connect");
+ }
+ continue;
+ }
+ client_update_poll(c);
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH connect failed: %s", ssh_get_error(c->session));
+ client_fail(c, errmsg);
+ return;
+ }
+ }
+
+ if (c->state == CLIENT_STATE_HOSTKEY) {
+ int rc = client_check_hostkey(c);
+ if (rc == 0) {
+ c->state = CLIENT_STATE_AUTH;
+ client_start_auth_timer(c);
+ continue;
+ } else if (rc == 1) {
+ c->state = CLIENT_STATE_HOSTKEY_WAIT;
+ return;
+ } else {
+ return;
+ }
+ }
+
+ if (c->state == CLIENT_STATE_HOSTKEY_WAIT) {
+ client_update_poll(c);
+ return;
+ }
+
+ if (c->state == CLIENT_STATE_AUTH) {
+ sshQ_Client actor = client_actor_ref(c);
+ if (actor == NULL || actor->password == B_None) {
+ client_fail(c, "Password auth requested but no password provided");
+ return;
+ }
+ int rc = ssh_userauth_password(c->session, NULL, (const char *)fromB_str(actor->password));
+ if (rc == SSH_AUTH_SUCCESS) {
+ client_on_ready(c);
+ return;
+ } else if (rc == SSH_AUTH_AGAIN) {
+ int status = ssh_get_status(c->session);
+ if (status & SSH_WRITE_PENDING)
+ c->write_ready = 0;
+ if ((status & SSH_READ_PENDING) && spin++ < SSH_IO_PUMP_LIMIT) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client drive: buffered read pending during auth");
+ }
+ continue;
+ }
+ client_update_poll(c);
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH auth failed: %s", ssh_get_error(c->session));
+ client_fail(c, errmsg);
+ return;
+ }
+ }
+
+ if (c->state == CLIENT_STATE_READY) {
+ client_drive_channels(c);
+ client_update_poll(c);
+ return;
+ }
+
+ return;
+ }
+}
+
+static void client_finalize(ssh_client_ctx *c) {
+ if (c == NULL || c->close_finalized)
+ return;
+ c->close_finalized = 1;
+
+ if (c->session != NULL) {
+ ssh_disconnect(c->session);
+ ssh_free(c->session);
+ c->session = NULL;
+ }
+ client_free_retired_channels(c);
+
+ client_notify_close(c, c->close_reason ? c->close_reason : "closed");
+ c->state = CLIENT_STATE_CLOSED;
+ sshQ_Client actor = client_actor_ref(c);
+ if (actor)
+ actor->_client = toB_u64(0);
+ STORE_HIDDEN_PTR(c->actor, NULL);
+ client_maybe_release(c);
+}
+
+static void client_abort_channels(ssh_client_ctx *c, int notify_channel_error) {
+ ssh_channel_ctx *ch = c->channels;
+ while (ch != NULL) {
+ ssh_channel_ctx *next = ch->next;
+ if (notify_channel_error)
+ channel_notify_error(ch, "Session closed");
+ channel_notify_eof(ch);
+ channel_finalize(c, ch);
+ client_retire_channel(c, ch);
+ ch = next;
+ }
+ c->channels = NULL;
+}
+
+static void client_request_channel_close(ssh_client_ctx *c) {
+ ssh_channel_ctx *ch = c->channels;
+ while (ch != NULL) {
+ if (ch->state != CHAN_STATE_CLOSED && ch->state != CHAN_STATE_ERROR) {
+ ch->send_eof = 1;
+ ch->close_requested = 1;
+ }
+ ch = ch->next;
+ }
+}
+
+static void client_finish_close(ssh_client_ctx *c) {
+ if (c == NULL || c->state != CLIENT_STATE_CLOSING)
+ return;
+ if (c->close_force) {
+ if (c->poll != NULL) {
+ close_poll(&c->poll, client_poll_close_cb);
+ c->poll_events = 0;
+ return;
+ }
+ client_finalize(c);
+ return;
+ }
+ if (c->channels != NULL) {
+ client_update_poll(c);
+ return;
+ }
+ if (c->session != NULL && session_has_pending_write(c->session)) {
+ client_update_poll(c);
+ return;
+ }
+ if (c->poll != NULL) {
+ close_poll(&c->poll, client_poll_close_cb);
+ c->poll_events = 0;
+ return;
+ }
+ client_finalize(c);
+}
+
+static void client_close_internal(ssh_client_ctx *c, const char *reason, int force_close) {
+ if (c == NULL || c->state == CLIENT_STATE_CLOSED)
+ return;
+ if (!force_close && c->state != CLIENT_STATE_READY)
+ force_close = 1;
+
+ if (!c->connected_ok && !c->connect_notified) {
+ client_notify_connect(c, reason ? reason : "closed");
+ }
+ if (reason != NULL && c->close_reason == NULL)
+ c->close_reason = acton_strdup(reason);
+
+ if (c->state == CLIENT_STATE_CLOSING) {
+ if (force_close && !c->close_force) {
+ c->close_force = 1;
+ client_abort_channels(c, 1);
+ }
+ client_finish_close(c);
+ return;
+ }
+
+ stop_timer(&c->connect_timer, client_timer_close_cb);
+ stop_timer(&c->auth_timer, client_timer_close_cb);
+ stop_timer(&c->keepalive_timer, client_timer_close_cb);
+
+ c->state = CLIENT_STATE_CLOSING;
+ c->close_force = force_close;
+ if (force_close) {
+ client_abort_channels(c, 1);
+ client_finish_close(c);
+ return;
+ }
+
+ client_request_channel_close(c);
+ client_drive(c);
+}
+
+void sshQ___ext_init__() {
+ const char *dbg_env = getenv("ACTON_SSH_DEBUG");
+ const char *log_env = getenv("ACTON_SSH_LIBSSH_LOG");
+ if (dbg_env != NULL && dbg_env[0] != '\0')
+ ssh_debug_enabled = 1;
+ if (log_env != NULL && log_env[0] != '\0') {
+ ssh_libssh_log_level = parse_libssh_log_level(log_env);
+ }
+ int r = libssh_replace_allocator(acton_malloc,
+ acton_realloc,
+ acton_calloc,
+ acton_free,
+ acton_strdup,
+ acton_strndup);
+ if (r != SSH_OK) {
+ log_warn("SSH allocator replacement failed");
+ }
+ r = ssh_threads_set_callbacks(ssh_threads_get_default());
+ if (r != SSH_OK) {
+ log_warn("SSH thread callbacks setup failed");
+ }
+ r = ssh_init();
+ if (r != SSH_OK) {
+ log_warn("SSH init failed");
+ }
+}
+
+B_str sshQ_version() {
+ return to$str((char *)ssh_version(0));
+}
+
+B_NoneType sshQ__debug(B_str msg) {
+ if (ssh_debug_enabled) {
+ log_info("%s", fromB_str(msg));
+ }
+ return B_None;
+}
+
+$R sshQ_ClientD__pin_affinityG_local(sshQ_Client self, $Cont c$cont) {
+ pin_actor_affinity();
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ClientD__initG_local(sshQ_Client self, $Cont c$cont) {
+ ssh_configure_libssh_logging();
+ ssh_client_ctx *c = acton_calloc(1, sizeof(ssh_client_ctx));
+ STORE_HIDDEN_PTR(c->actor, self);
+ c->on_connect = ($action2)self->on_connect;
+ c->on_close = ($action2)self->on_close;
+ c->on_hostkey = ($action3)self->on_hostkey;
+ c->connect_timeout = fromB_float(self->connect_timeout);
+ c->auth_timeout = fromB_float(self->auth_timeout);
+ c->keepalive_interval = fromB_float(self->keepalive_interval);
+ c->keepalive_enabled = fromB_bool(self->keepalive_enabled) ? 1 : 0;
+
+ self->_client = toB_u64((unsigned long)c);
+
+ c->session = ssh_new();
+ if (c->session == NULL) {
+ client_fail(c, "Failed to create SSH session");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ int rc;
+ int strict = 1;
+ int port = (int)fromB_u16(self->port);
+ bool process_config = false;
+
+ rc = ssh_options_set(c->session, SSH_OPTIONS_PROCESS_CONFIG, &process_config);
+ if (rc != SSH_OK) {
+ client_fail(c, "Failed to disable SSH config processing");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ rc = ssh_options_set(c->session, SSH_OPTIONS_HOST, fromB_str(self->host));
+ if (rc != SSH_OK) {
+ client_fail(c, "Failed to set SSH host");
+ return $R_CONT(c$cont, B_None);
+ }
+ rc = ssh_options_set(c->session, SSH_OPTIONS_PORT, &port);
+ if (rc != SSH_OK) {
+ client_fail(c, "Failed to set SSH port");
+ return $R_CONT(c$cont, B_None);
+ }
+ rc = ssh_options_set(c->session, SSH_OPTIONS_USER, fromB_str(self->username));
+ if (rc != SSH_OK) {
+ client_fail(c, "Failed to set SSH username");
+ return $R_CONT(c$cont, B_None);
+ }
+ if (self->known_hosts != B_None) {
+ const char *known_hosts = (const char *)fromB_str(self->known_hosts);
+ rc = ssh_options_set(c->session, SSH_OPTIONS_KNOWNHOSTS, known_hosts);
+ if (rc != SSH_OK) {
+ client_fail(c, "Failed to set SSH known_hosts path");
+ return $R_CONT(c$cont, B_None);
+ }
+ rc = ssh_options_set(c->session, SSH_OPTIONS_GLOBAL_KNOWNHOSTS, known_hosts);
+ if (rc != SSH_OK) {
+ client_fail(c, "Failed to set SSH global known_hosts path");
+ return $R_CONT(c$cont, B_None);
+ }
+ }
+ rc = ssh_options_set(c->session, SSH_OPTIONS_STRICTHOSTKEYCHECK, &strict);
+ if (rc != SSH_OK) {
+ client_fail(c, "Failed to set SSH strict host key checking");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ ssh_set_blocking(c->session, 0);
+
+ c->state = CLIENT_STATE_CONNECTING;
+ rc = ssh_connect(c->session);
+ if (rc == SSH_OK) {
+ c->state = CLIENT_STATE_HOSTKEY;
+ } else if (rc == SSH_AGAIN) {
+ c->state = CLIENT_STATE_CONNECTING;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH connect failed: %s", ssh_get_error(c->session));
+ client_fail(c, errmsg);
+ return $R_CONT(c$cont, B_None);
+ }
+
+ c->fd = ssh_get_fd(c->session);
+ if (c->fd < 0) {
+ client_fail(c, "Failed to get SSH session fd");
+ return $R_CONT(c$cont, B_None);
+ }
+ if (c->state != CLIENT_STATE_CONNECTING && fd_set_nonblocking(c->fd) != 0) {
+ client_fail(c, "Failed to set SSH session fd nonblocking");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ c->poll = acton_calloc(1, sizeof(uv_poll_t));
+ c->poll->data = c;
+ int uv_rc = uv_poll_init(get_uv_loop(), c->poll, c->fd);
+ if (uv_rc != 0) {
+ char errmsg[256] = {0};
+ uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg));
+ client_fail(c, errmsg);
+ return $R_CONT(c$cont, B_None);
+ }
+ c->poll_events = UV_READABLE | UV_WRITABLE;
+ uv_rc = uv_poll_start(c->poll, c->poll_events, poll_cb);
+ if (uv_rc != 0) {
+ char errmsg[256] = {0};
+ uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg));
+ client_fail(c, errmsg);
+ return $R_CONT(c$cont, B_None);
+ }
+
+ client_start_connect_timer(c);
+ client_drive(c);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ClientD_accept_hostkeyG_local(sshQ_Client self, $Cont c$cont) {
+ ssh_client_ctx *c = client_from_actor(self);
+ if (c == NULL)
+ return $R_CONT(c$cont, B_None);
+ if (c->state != CLIENT_STATE_HOSTKEY_WAIT)
+ return $R_CONT(c$cont, B_None);
+
+ c->state = CLIENT_STATE_AUTH;
+ client_start_auth_timer(c);
+ client_drive(c);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ClientD_reject_hostkeyG_local(sshQ_Client self, $Cont c$cont, B_str reason) {
+ ssh_client_ctx *c = client_from_actor(self);
+ if (c == NULL)
+ return $R_CONT(c$cont, B_None);
+ if (c->state != CLIENT_STATE_HOSTKEY_WAIT)
+ return $R_CONT(c$cont, B_None);
+
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "Host key rejected: %s", fromB_str(reason));
+ client_fail(c, errmsg);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ClientD_closeG_local(sshQ_Client self, $Cont c$cont) {
+ ssh_client_ctx *c = client_from_actor(self);
+ if (c == NULL)
+ return $R_CONT(c$cont, B_None);
+ client_close_internal(c, "closed", 0);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ClientD__cleanup_nativeG_local(sshQ_Client self, $Cont c$cont) {
+ ssh_client_ctx *c = client_from_actor(self);
+ if (c != NULL)
+ client_close_internal(c, "collected", 1);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ClientD_channel_createG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel,
+ $action on_open,
+ $action on_stdout,
+ $action on_stderr,
+ $action on_exit,
+ $action on_close) {
+ ssh_client_ctx *c = client_from_actor(self);
+ if (c == NULL) {
+ if (on_open) {
+ $action2 f = ($action2)on_open;
+ f->$class->__asyn__(f, channel, to$str((char *)"Client not initialized"));
+ }
+ return $R_CONT(c$cont, B_None);
+ }
+
+ ssh_channel_ctx *ch = acton_calloc(1, sizeof(ssh_channel_ctx));
+ ch->client = c;
+ ch->actor = channel;
+ ch->callbacks = NULL;
+ ch->state = CHAN_STATE_INIT;
+ ch->pending_req = CHAN_REQ_NONE;
+ ch->pty_pending = 0;
+ ch->pty_done = 0;
+ ch->term = to$str((char *)"xterm-256color");
+ ch->cols = 80;
+ ch->rows = 24;
+ ch->width_px = 0;
+ ch->height_px = 0;
+ ch->on_open = ($action2)on_open;
+ ch->on_close = ($action2)on_close;
+ ch->on_stdout = ($action2)on_stdout;
+ ch->on_stderr = ($action2)on_stderr;
+ ch->on_exit = ($action3)on_exit;
+
+ ch->next = c->channels;
+ c->channels = ch;
+
+ channel->_channel_id = toB_u64((unsigned long)ch);
+
+ if (c->state == CLIENT_STATE_READY)
+ client_drive(c);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+static int channel_validate(ssh_client_ctx *c, ssh_channel_ctx *ch) {
+ if (c == NULL || ch == NULL) {
+ return 1;
+ }
+ if (c->state != CLIENT_STATE_READY) {
+ return 1;
+ }
+ if (ch->client != c) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("client channel ownership mismatch: client=%p owner=%p ch=%p",
+ (void *)c, (void *)ch->client, (void *)ch);
+ }
+ return 1;
+ }
+ if (ch->state == CHAN_STATE_ERROR || ch->state == CHAN_STATE_CLOSED) {
+ return 2;
+ }
+ return 0;
+}
+
+static int server_channel_validate(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) {
+ if (s == NULL || ch == NULL) {
+ return -1;
+ }
+ if (s->state != SESSION_STATE_READY) {
+ return -1;
+ }
+ if (ch->session != s) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server channel ownership mismatch: session=%p owner=%p ch=%p",
+ (void *)s, (void *)ch->session, (void *)ch);
+ }
+ return -1;
+ }
+ if (ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) {
+ return -1;
+ }
+ return 0;
+}
+
+$R sshQ_ClientD_channel_request_execG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str cmd) {
+ ssh_client_ctx *c = client_from_actor(self);
+ ssh_channel_ctx *ch = channel_from_actor(channel);
+ int valid = channel_validate(c, ch);
+ if (valid != 0) {
+ if (valid == 2 && ch != NULL)
+ channel_notify_error(ch, "Channel not ready");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ if (ch->pending_req != CHAN_REQ_NONE) {
+ channel_notify_error(ch, "Channel already has a pending request");
+ return $R_CONT(c$cont, B_None);
+ }
+ if (ch->state == CHAN_STATE_RUNNING || ch->state == CHAN_STATE_CLOSING) {
+ channel_notify_error(ch, "Channel already running");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ ch->exec_cmd = cmd;
+ ch->pending_req = CHAN_REQ_EXEC;
+ client_drive(c);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ClientD_channel_request_shellG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str term, B_int cols, B_int rows, B_int width_px, B_int height_px, B_bool with_pty) {
+ ssh_client_ctx *c = client_from_actor(self);
+ ssh_channel_ctx *ch = channel_from_actor(channel);
+ int valid = channel_validate(c, ch);
+ if (valid != 0) {
+ if (valid == 2 && ch != NULL)
+ channel_notify_error(ch, "Channel not ready");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ if (ch->pending_req != CHAN_REQ_NONE) {
+ channel_notify_error(ch, "Channel already has a pending request");
+ return $R_CONT(c$cont, B_None);
+ }
+ if (ch->state == CHAN_STATE_RUNNING || ch->state == CHAN_STATE_CLOSING) {
+ channel_notify_error(ch, "Channel already running");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ ch->term = term;
+ ch->cols = fromB_int(cols);
+ ch->rows = fromB_int(rows);
+ ch->width_px = fromB_int(width_px);
+ ch->height_px = fromB_int(height_px);
+ ch->pty_pending = fromB_bool(with_pty) ? 1 : 0;
+
+ ch->pending_req = CHAN_REQ_SHELL;
+ client_drive(c);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ClientD_channel_request_subsystemG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str name) {
+ ssh_client_ctx *c = client_from_actor(self);
+ ssh_channel_ctx *ch = channel_from_actor(channel);
+ int valid = channel_validate(c, ch);
+ if (valid != 0) {
+ if (valid == 2 && ch != NULL)
+ channel_notify_error(ch, "Channel not ready");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ if (ch->pending_req != CHAN_REQ_NONE) {
+ channel_notify_error(ch, "Channel already has a pending request");
+ return $R_CONT(c$cont, B_None);
+ }
+ if (ch->state == CHAN_STATE_RUNNING || ch->state == CHAN_STATE_CLOSING) {
+ channel_notify_error(ch, "Channel already running");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ ch->subsystem = name;
+ ch->pending_req = CHAN_REQ_SUBSYSTEM;
+ client_drive(c);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ClientD_channel_writeG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_bytes data) {
+ ssh_client_ctx *c = client_from_actor(self);
+ ssh_channel_ctx *ch = channel_from_actor(channel);
+ int valid = channel_validate(c, ch);
+ if (valid != 0) {
+ if (valid == 2 && ch != NULL)
+ channel_notify_error(ch, "Channel not ready");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ channel_queue_write(ch, data);
+ client_drive(c);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ClientD_channel_send_eofG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel) {
+ ssh_client_ctx *c = client_from_actor(self);
+ ssh_channel_ctx *ch = channel_from_actor(channel);
+ if (channel_validate(c, ch) != 0)
+ return $R_CONT(c$cont, B_None);
+
+ ch->send_eof = 1;
+ client_drive(c);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ClientD_channel_closeG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel) {
+ ssh_client_ctx *c = client_from_actor(self);
+ ssh_channel_ctx *ch = channel_from_actor(channel);
+ if (channel_validate(c, ch) != 0)
+ return $R_CONT(c$cont, B_None);
+
+ ch->send_eof = 1;
+ ch->close_requested = 1;
+ client_drive(c);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ChannelD__cleanup_nativeG_local(sshQ_Channel self, $Cont c$cont) {
+ ssh_channel_ctx *ch = channel_from_actor(self);
+ if (ch == NULL || ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR)
+ return $R_CONT(c$cont, B_None);
+ ssh_client_ctx *c = ch->client;
+ if (c == NULL)
+ return $R_CONT(c$cont, B_None);
+ ch->send_eof = 1;
+ ch->close_requested = 1;
+ client_drive(c);
+ return $R_CONT(c$cont, B_None);
+}
+
+// --- Server implementation
+
+static void server_notify_listen(ssh_server_ctx *s, const char *err) {
+ if (s == NULL)
+ return;
+ if (s->listen_notified)
+ return;
+ sshQ_Server actor = server_actor_ref(s);
+ if (s->on_listen) {
+ $action2 f = ($action2)s->on_listen;
+ f->$class->__asyn__(f, actor, err ? to$str((char *)err) : B_None);
+ }
+ s->listen_notified = 1;
+ if (err == NULL)
+ s->listen_ok = 1;
+}
+
+static void server_notify_close(ssh_server_ctx *s, const char *reason) {
+ if (s->close_notified)
+ return;
+ if (!s->listen_ok)
+ return;
+ sshQ_Server actor = server_actor_ref(s);
+ if (s->on_close) {
+ $action2 f = ($action2)s->on_close;
+ f->$class->__asyn__(f, actor, to$str((char *)reason));
+ }
+ s->close_notified = 1;
+}
+
+static void session_notify_close(ssh_server_session_ctx *s, const char *reason) {
+ if (s->close_notified)
+ return;
+ sshQ_ServerSession actor = session_actor_ref(s);
+ if (s->on_close) {
+ $action2 f = ($action2)s->on_close;
+ f->$class->__asyn__(f, actor, to$str((char *)reason));
+ }
+ s->close_notified = 1;
+}
+
+static void server_channel_notify_close(ssh_server_channel_ctx *ch, const char *reason) {
+ if (ch->close_notified)
+ return;
+ sshQ_ServerChannel actor = server_channel_actor_ref(ch);
+ if (ch->on_close) {
+ $action2 f = ($action2)ch->on_close;
+ f->$class->__asyn__(f, actor, to$str((char *)reason));
+ }
+ ch->close_notified = 1;
+}
+
+static int server_channel_data_cb(ssh_session session, ssh_channel channel, void *data,
+ uint32_t len, int is_stderr, void *userdata) {
+ ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata;
+ (void)session;
+ (void)channel;
+ if (ch == NULL || ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR)
+ return 0;
+ if (len == 0)
+ return 0;
+ B_bytes out = to$bytesD_len((char *)data, (size_t)len);
+ sshQ_ServerChannel actor = server_channel_actor_ref(ch);
+ if (is_stderr) {
+ if (ch->on_stderr) {
+ $action2 f = ($action2)ch->on_stderr;
+ f->$class->__asyn__(f, actor, out);
+ }
+ } else {
+ if (ch->on_data) {
+ $action2 f = ($action2)ch->on_data;
+ f->$class->__asyn__(f, actor, out);
+ }
+ }
+ return (int)len;
+}
+
+static void server_channel_eof_cb(ssh_session session, ssh_channel channel, void *userdata) {
+ ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata;
+ (void)session;
+ (void)channel;
+ if (ch == NULL)
+ return;
+ sshQ_ServerChannel actor = server_channel_actor_ref(ch);
+ if (!ch->stdout_eof && ch->on_data) {
+ $action2 f = ($action2)ch->on_data;
+ f->$class->__asyn__(f, actor, B_None);
+ ch->stdout_eof = 1;
+ }
+ if (!ch->stderr_eof && ch->on_stderr) {
+ $action2 f = ($action2)ch->on_stderr;
+ f->$class->__asyn__(f, actor, B_None);
+ ch->stderr_eof = 1;
+ }
+}
+
+static void server_channel_close_cb(ssh_session session, ssh_channel channel, void *userdata) {
+ ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata;
+ (void)session;
+ (void)channel;
+ if (ch == NULL)
+ return;
+ ch->remote_close_seen = 1;
+}
+
+static int server_channel_write_wontblock_cb(ssh_session session, ssh_channel channel,
+ uint32_t bytes, void *userdata) {
+ ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata;
+ (void)session;
+ (void)channel;
+ if (ch == NULL || ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR)
+ return 0;
+ ch->write_wontblock = bytes > 0 ? 1 : 0;
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server channel write_wontblock: bytes=%u ch=%p", bytes, (void *)ch);
+ }
+ return 0;
+}
+
+static int server_channel_setup_callbacks(ssh_server_channel_ctx *ch) {
+ if (ch == NULL || ch->channel == NULL)
+ return SSH_ERROR;
+ if (ch->callbacks != NULL)
+ return SSH_OK;
+ struct ssh_channel_callbacks_struct *cb = acton_calloc(1, sizeof(*cb));
+ ssh_callbacks_init(cb);
+ cb->userdata = ch;
+ cb->channel_data_function = server_channel_data_cb;
+ cb->channel_eof_function = server_channel_eof_cb;
+ cb->channel_close_function = server_channel_close_cb;
+ cb->channel_write_wontblock_function = server_channel_write_wontblock_cb;
+ if (ssh_add_channel_callbacks(ch->channel, cb) != SSH_OK) {
+ acton_free(cb);
+ return SSH_ERROR;
+ }
+ ch->callbacks = cb;
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server channel callbacks set ch=%p", (void *)ch);
+ }
+ return SSH_OK;
+}
+
+static void server_channel_finalize(ssh_server_channel_ctx *ch) {
+ sshQ_ServerChannel actor = server_channel_actor_ref(ch);
+ while (ch->write_head != NULL) {
+ server_write_chunk_t *chunk = ch->write_head;
+ ch->write_head = chunk->next;
+ acton_free(chunk);
+ }
+ ch->write_tail = NULL;
+ if (ch->pending_req) {
+ ssh_message_reply_default(ch->pending_req);
+ ssh_message_free(ch->pending_req);
+ ch->pending_req = NULL;
+ ch->pending_req_type = SCHAN_REQ_NONE;
+ }
+ if (ch->channel != NULL) {
+ if (ch->callbacks) {
+ ssh_remove_channel_callbacks(ch->channel, ch->callbacks);
+ acton_free(ch->callbacks);
+ ch->callbacks = NULL;
+ }
+ ssh_channel_free(ch->channel);
+ ch->channel = NULL;
+ }
+ ch->state = SCHAN_STATE_CLOSED;
+ if (actor)
+ actor->_channel_id = toB_u64(0);
+ if (!ch->stdout_eof && ch->on_data) {
+ $action2 f = ($action2)ch->on_data;
+ f->$class->__asyn__(f, actor, B_None);
+ ch->stdout_eof = 1;
+ }
+ if (!ch->stderr_eof && ch->on_stderr) {
+ $action2 f = ($action2)ch->on_stderr;
+ f->$class->__asyn__(f, actor, B_None);
+ ch->stderr_eof = 1;
+ }
+ server_channel_notify_close(ch, "closed");
+ ch->actor = NULL;
+}
+
+static void server_channel_queue_write(ssh_server_channel_ctx *ch, B_bytes data, int is_stderr) {
+ server_write_chunk_t *chunk = acton_calloc(1, sizeof(server_write_chunk_t));
+ chunk->data = data;
+ chunk->offset = 0;
+ chunk->is_stderr = is_stderr;
+ chunk->next = NULL;
+ if (ch->write_tail) {
+ ch->write_tail->next = chunk;
+ } else {
+ ch->write_head = chunk;
+ }
+ ch->write_tail = chunk;
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server channel queue write: %zu bytes", data ? data->nbytes : 0);
+ }
+}
+
+static void server_channel_try_write(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) {
+ while (ch->write_head != NULL && ch->write_head->data->nbytes == ch->write_head->offset) {
+ server_write_chunk_t *chunk = ch->write_head;
+ ch->write_head = chunk->next;
+ acton_free(chunk);
+ if (ch->write_head == NULL)
+ ch->write_tail = NULL;
+ }
+
+ if (ch->write_head == NULL || !ch->write_wontblock || session_has_pending_write(s->session))
+ return;
+
+ server_write_chunk_t *chunk = ch->write_head;
+ size_t remaining = chunk->data->nbytes - chunk->offset;
+ ch->write_wontblock = 0;
+ int rc;
+ if (chunk->is_stderr) {
+ rc = ssh_channel_write_stderr(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining);
+ } else {
+ rc = ssh_channel_write(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining);
+ }
+ if (rc > 0) {
+ chunk->offset += (size_t)rc;
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server channel wrote %d bytes", rc);
+ }
+ if (chunk->offset >= chunk->data->nbytes) {
+ ch->write_head = chunk->next;
+ acton_free(chunk);
+ if (ch->write_head == NULL)
+ ch->write_tail = NULL;
+ }
+ } else if (rc == 0 || rc == SSH_AGAIN) {
+ s->write_ready = 0;
+ if (ssh_debug_enabled) {
+ unsigned int window = ssh_channel_window_size(ch->channel);
+ ssh_debug_log("server channel write pending rc=%d remaining=%zu window=%u",
+ rc, remaining, window);
+ }
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH server channel write error: %s", ssh_get_error(s->session));
+ server_channel_notify_close(ch, errmsg);
+ ch->state = SCHAN_STATE_ERROR;
+ return;
+ }
+}
+
+static int server_channel_read_stream(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch, int is_stderr) {
+ char buf[SSH_READ_BUFSIZE];
+ int read_any = 0;
+ for (;;) {
+ int n = ssh_channel_read_buffered(ch->channel, buf, sizeof(buf), is_stderr);
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server channel read: rc=%d stderr=%d", n, is_stderr);
+ }
+ if (n > 0) {
+ read_any = 1;
+ B_bytes out = to$bytesD_len(buf, n);
+ sshQ_ServerChannel actor = server_channel_actor_ref(ch);
+ if (is_stderr) {
+ if (ch->on_stderr) {
+ $action2 f = ($action2)ch->on_stderr;
+ f->$class->__asyn__(f, actor, out);
+ }
+ } else {
+ if (ch->on_data) {
+ $action2 f = ($action2)ch->on_data;
+ f->$class->__asyn__(f, actor, out);
+ }
+ }
+ continue;
+ }
+ if (n == 0 || n == SSH_AGAIN) {
+ break;
+ }
+ if (n == SSH_EOF) {
+ break;
+ }
+ if (n == SSH_ERROR) {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH server channel read error: %s", ssh_get_error(s->session));
+ server_channel_notify_close(ch, errmsg);
+ ch->state = SCHAN_STATE_ERROR;
+ break;
+ }
+ }
+ return read_any;
+}
+
+static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) {
+ if (ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR)
+ return;
+
+ if (ch->write_head != NULL) {
+ server_channel_try_write(s, ch);
+ if (ch->state == SCHAN_STATE_ERROR)
+ return;
+ }
+
+ if (ch->send_eof && !ch->eof_sent && ch->write_head == NULL &&
+ !session_has_pending_write(s->session)) {
+ int rc = ssh_channel_send_eof(ch->channel);
+ if (rc == SSH_OK) {
+ ch->eof_sent = 1;
+ } else if (rc == SSH_AGAIN) {
+ s->write_ready = 0;
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH server send EOF failed: %s", ssh_get_error(s->session));
+ server_channel_notify_close(ch, errmsg);
+ ch->state = SCHAN_STATE_ERROR;
+ return;
+ }
+ }
+
+ if (ch->close_requested && !ch->close_sent && ch->write_head == NULL &&
+ (!ch->send_eof || ch->eof_sent) &&
+ !session_has_pending_write(s->session)) {
+ int rc = ssh_channel_close(ch->channel);
+ if (rc == SSH_OK) {
+ ch->close_sent = 1;
+ ch->state = SCHAN_STATE_CLOSING;
+ } else if (rc == SSH_AGAIN) {
+ s->write_ready = 0;
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH server channel close failed: %s", ssh_get_error(s->session));
+ server_channel_notify_close(ch, errmsg);
+ ch->state = SCHAN_STATE_ERROR;
+ return;
+ }
+ }
+
+ if (ch->callbacks == NULL) {
+ for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) {
+ int did = 0;
+ did |= server_channel_read_stream(s, ch, 0);
+ did |= server_channel_read_stream(s, ch, 1);
+ if (!did)
+ break;
+ }
+ }
+
+ if (ch->channel != NULL && ch->remote_close_seen &&
+ ssh_channel_is_closed(ch->channel)) {
+ server_channel_finalize(ch);
+ }
+}
+
+static void session_drive_channels(ssh_server_session_ctx *s) {
+ ssh_server_channel_ctx *prev = NULL;
+ ssh_server_channel_ctx *ch = s->channels;
+ while (ch != NULL) {
+ ssh_server_channel_ctx *next = ch->next;
+ server_channel_drive(s, ch);
+ if (ch->state == SCHAN_STATE_ERROR) {
+ server_channel_finalize(ch);
+ }
+ if (ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) {
+ if (prev != NULL) {
+ prev->next = next;
+ } else {
+ s->channels = next;
+ }
+ session_retire_channel(s, ch);
+ } else {
+ prev = ch;
+ }
+ ch = next;
+ }
+}
+
+static void session_fail(ssh_server_session_ctx *s, const char *msg) {
+ if (s == NULL || s->state == SESSION_STATE_CLOSED || s->state == SESSION_STATE_ERROR)
+ return;
+ s->state = SESSION_STATE_ERROR;
+ session_close_internal(s, msg, 1);
+}
+
+static int session_check_reply_rc(ssh_server_session_ctx *s, int rc, const char *context) {
+ if (rc == SSH_OK || rc == SSH_AGAIN)
+ return 0;
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "%s: %s",
+ context, s != NULL && s->session != NULL ? ssh_get_error(s->session) : "unknown error");
+ session_fail(s, errmsg);
+ return -1;
+}
+
+static void session_restart_auth_timer(ssh_server_session_ctx *s, double timeout_sec) {
+ if (s == NULL)
+ return;
+ if (timeout_sec <= 0.0)
+ return;
+ if (s->auth_timer == NULL) {
+ s->auth_timer = acton_calloc(1, sizeof(uv_timer_t));
+ s->auth_timer->data = s;
+ uv_timer_init(get_uv_loop(), s->auth_timer);
+ } else {
+ uv_timer_stop(s->auth_timer);
+ }
+ uv_timer_start(s->auth_timer, session_auth_timeout_cb,
+ (uint64_t)(timeout_sec * 1000.0), 0);
+}
+
+static void session_start_attach_timer(ssh_server_session_ctx *s) {
+ if (s == NULL || s->attach_timer != NULL)
+ return;
+ s->attach_timer = acton_calloc(1, sizeof(uv_timer_t));
+ s->attach_timer->data = s;
+ uv_timer_init(get_uv_loop(), s->attach_timer);
+ uv_timer_start(s->attach_timer, session_auth_timeout_cb,
+ (uint64_t)(SSH_ATTACH_TIMEOUT_SEC * 1000.0), 0);
+}
+
+static void session_start_keyex_timer(ssh_server_session_ctx *s) {
+ if (s == NULL)
+ return;
+ double timeout = s->auth_timeout > SSH_KEYEX_TIMEOUT_SEC ?
+ s->auth_timeout : SSH_KEYEX_TIMEOUT_SEC;
+ session_restart_auth_timer(s, timeout);
+}
+
+static void session_start_auth_timer(ssh_server_session_ctx *s) {
+ session_restart_auth_timer(s, s != NULL ? s->auth_timeout : 0.0);
+}
+
+static void session_start_keepalive(ssh_server_session_ctx *s) {
+ if (s == NULL || !s->keepalive_enabled || s->keepalive_interval <= 0.0 || s->keepalive_timer != NULL)
+ return;
+ s->keepalive_timer = acton_calloc(1, sizeof(uv_timer_t));
+ s->keepalive_timer->data = s;
+ uv_timer_init(get_uv_loop(), s->keepalive_timer);
+ uv_timer_start(s->keepalive_timer, session_keepalive_cb, (uint64_t)(s->keepalive_interval * 1000.0), (uint64_t)(s->keepalive_interval * 1000.0));
+}
+
+static void server_fail(ssh_server_ctx *s, const char *msg) {
+ if (s == NULL || s->state == SERVER_STATE_CLOSED || s->state == SERVER_STATE_ERROR)
+ return;
+ s->state = SERVER_STATE_ERROR;
+ if (!s->listen_ok)
+ server_notify_listen(s, msg);
+ server_close_internal(s, msg);
+}
+
+static void session_auth_timeout_cb(uv_timer_t *timer) {
+ ssh_server_session_ctx *s = (ssh_server_session_ctx *)timer->data;
+ if (s == NULL)
+ return;
+ if (!s->attached && s->state == SESSION_STATE_KEYEX) {
+ session_fail(s, "SSH session attach timeout");
+ return;
+ }
+ if (s->attached && s->state == SESSION_STATE_KEYEX) {
+ session_fail(s, "SSH key exchange timeout");
+ return;
+ }
+ if (s->state == SESSION_STATE_AUTH) {
+ session_fail(s, "SSH authentication timeout");
+ }
}
-void sshQ___ext_init__() {
- // TODO: can we avoid custom malloc in libssh? like let libssh use stock
- // malloc and instead we would explicitly call free() from a finalizer()
- // All things related to buffers for receiving data and similarly would have
- // to be allocated on the GC-heap though since that data is passed outside
- // of the SSH actor
- libssh_replace_allocator(
- acton_gc_malloc,
- acton_gc_realloc,
- acton_gc_calloc,
- noop_free,
- acton_gc_strdup,
- acton_gc_strndup);
- int r = ssh_init();
- printf("SSH extension initialized %d\n", r);
-}
-
-B_str sshQ_version () {
- if (LIBSSH_VERSION_MINOR != 11)
- return to$str("invalid");
- return to$str("0.11.0");
-}
-
-// TODO: crap function for test, to be replaced with something
-int show_remote_processes(ssh_session session)
-{
- ssh_channel channel;
- int rc;
- char buffer[256];
- int nbytes;
-
- channel = ssh_channel_new(session);
- if (channel == NULL)
- return SSH_ERROR;
-
- rc = ssh_channel_open_session(channel);
- if (rc != SSH_OK)
- {
- ssh_channel_free(channel);
- return rc;
- }
-
- rc = ssh_channel_request_exec(channel, "ps aux");
- if (rc != SSH_OK)
- {
- ssh_channel_close(channel);
- ssh_channel_free(channel);
- return rc;
- }
-
- nbytes = ssh_channel_read(channel, buffer, sizeof(buffer), 0);
- while (nbytes > 0)
- {
- if (write(1, buffer, nbytes) != (unsigned int) nbytes)
- {
- ssh_channel_close(channel);
- ssh_channel_free(channel);
- return SSH_ERROR;
- }
- nbytes = ssh_channel_read(channel, buffer, sizeof(buffer), 0);
- }
-
- if (nbytes < 0)
- {
- ssh_channel_close(channel);
- ssh_channel_free(channel);
- return SSH_ERROR;
- }
-
- ssh_channel_send_eof(channel);
- ssh_channel_close(channel);
- ssh_channel_free(channel);
-
- return SSH_OK;
-}
-
-$R sshQ_ClientD__initG_local (sshQ_Client self, $Cont c$cont) {
- ssh_session session = ssh_new();
- if (session == NULL) {
- //log_error("Failed to create SSH session");
- return $R_CONT(c$cont, B_None);
- }
- printf("session: %p\n", session);
- self->_ssh_session = toB_u64((unsigned long)session);
- printf("init self->session: %p\n", self->_ssh_session);
-
- ssh_options_set(session, SSH_OPTIONS_HOST, fromB_str(self->host));
- ssh_options_set(session, SSH_OPTIONS_PORT, &self->port->val);
- ssh_options_set(session, SSH_OPTIONS_USER, fromB_str(self->username));
-
- ssh_set_blocking(session, 1);
- printf("Connecting to \n");
- int rc = ssh_connect(session);
+static int session_start_poll(ssh_server_session_ctx *s, char *errmsg, size_t errmsg_len) {
+ if (s == NULL || s->session == NULL || s->fd < 0) {
+ snprintf(errmsg, errmsg_len, "Failed to start SSH session poll");
+ return -1;
+ }
+ if (s->poll != NULL)
+ return 0;
+
+ s->poll = acton_calloc(1, sizeof(uv_poll_t));
+ s->poll->data = s;
+ int uv_rc = uv_poll_init(get_uv_loop(), s->poll, s->fd);
+ if (uv_rc != 0) {
+ uv_strerror_r(uv_rc, errmsg, errmsg_len);
+ acton_free(s->poll);
+ s->poll = NULL;
+ return -1;
+ }
+ s->poll_events = UV_READABLE | UV_WRITABLE;
+ uv_rc = uv_poll_start(s->poll, s->poll_events, session_poll_cb);
+ if (uv_rc != 0) {
+ uv_strerror_r(uv_rc, errmsg, errmsg_len);
+ return -1;
+ }
+ return 0;
+}
+
+static void session_keepalive_cb(uv_timer_t *timer) {
+ ssh_server_session_ctx *s = (ssh_server_session_ctx *)timer->data;
+ if (s == NULL || s->session == NULL)
+ return;
+ if (s->state != SESSION_STATE_READY)
+ return;
+ int rc = ssh_send_ignore(s->session, "keepalive");
+ if (rc == SSH_AGAIN) {
+ session_update_poll(s);
+ return;
+ }
+ if (rc != SSH_OK) {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH server keepalive failed: %s", ssh_get_error(s->session));
+ session_fail(s, errmsg);
+ return;
+ }
+ session_update_poll(s);
+}
+
+static void session_update_poll(ssh_server_session_ctx *s) {
+ if (s->poll == NULL || s->session == NULL)
+ return;
+ if (s->state == SESSION_STATE_CLOSED)
+ return;
+ if (uv_is_closing((uv_handle_t *)s->poll))
+ return;
+ int status = ssh_get_status(s->session);
+ if (status & SSH_CLOSED_ERROR) {
+ session_fail(s, "SSH session closed with error");
+ return;
+ }
+ if (status & SSH_CLOSED) {
+ session_close_internal(s, "SSH session closed", 1);
+ return;
+ }
+ int flags = ssh_get_poll_flags(s->session);
+ int pending = flags | status;
+ int events = UV_READABLE;
+#ifdef UV_DISCONNECT
+ events |= UV_DISCONNECT;
+#endif
+ if (pending & SSH_WRITE_PENDING)
+ events |= UV_WRITABLE;
+ if ((events & UV_WRITABLE) == 0 && session_needs_write(s))
+ events |= UV_WRITABLE;
+ if (ssh_debug_enabled && (events != s->poll_events || (pending & SSH_WRITE_PENDING))) {
+ ssh_debug_log("server update poll: status=0x%x flags=0x%x pending=0x%x events=0x%x state=%d",
+ status, flags, pending, events, s->state);
+ }
+ if (events != s->poll_events) {
+ int uv_rc = uv_poll_start(s->poll, events, session_poll_cb);
+ if (uv_rc != 0) {
+ char errmsg[256] = {0};
+ uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg));
+ session_fail(s, errmsg);
+ return;
+ }
+ s->poll_events = events;
+ }
+}
+
+static void session_pump_io(ssh_server_session_ctx *s) {
+ if (s == NULL || s->session == NULL)
+ return;
+ int i;
+ for (i = 0; i < SSH_IO_PUMP_LIMIT; i++) {
+ int did = 0;
+ if (s->session == NULL)
+ return;
+ int has_data = fd_has_data(s->fd);
+ if (has_data) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server pump: readable");
+ }
+ ssh_set_fd_toread(s->session);
+ if (session_apply_poll_events(s->session, UV_READABLE) != 0) {
+ char errmsg[256] = {0};
+ format_session_error(s->session, "SSH poll callback error", errmsg, sizeof(errmsg));
+ session_fail(s, errmsg);
+ return;
+ }
+ session_drive(s);
+ did = 1;
+ }
+ if (s->session == NULL)
+ return;
+ if (!did) {
+ int status = ssh_get_status(s->session);
+ if (status & SSH_READ_PENDING) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server pump: buffered read pending=0x%x", status);
+ }
+ session_drive(s);
+ /* Avoid spinning when only buffered data remains. */
+ break;
+ }
+ }
+ if (!did)
+ break;
+ }
+ if (ssh_debug_enabled && i >= SSH_IO_PUMP_LIMIT) {
+ int status = ssh_get_status(s->session);
+ int flags = ssh_get_poll_flags(s->session);
+ ssh_debug_log("server pump: hit limit status=0x%x flags=0x%x", status, flags);
+ }
+}
+
+static int session_needs_write(ssh_server_session_ctx *s) {
+ if (s == NULL)
+ return 0;
+ ssh_server_channel_ctx *ch = s->channels;
+ while (ch != NULL) {
+ if (ch->pending_req != NULL || ch->write_head != NULL)
+ return 1;
+ if (ch->send_eof && !ch->eof_sent)
+ return 1;
+ if (ch->close_requested && !ch->close_sent)
+ return 1;
+ ch = ch->next;
+ }
+ return 0;
+}
+
+static void session_drive(ssh_server_session_ctx *s) {
+ if (s == NULL)
+ return;
+ if (s->state == SESSION_STATE_ERROR || s->state == SESSION_STATE_CLOSED)
+ return;
+ if (!s->attached)
+ return;
+ if (s->state == SESSION_STATE_CLOSING) {
+ session_drive_channels(s);
+ session_finish_close(s);
+ return;
+ }
+
+ int spin = 0;
+ while (1) {
+ if (s->state == SESSION_STATE_KEYEX) {
+ int rc = ssh_handle_key_exchange(s->session);
+ if (rc == SSH_OK) {
+ s->state = SESSION_STATE_AUTH;
+ stop_timer(&s->attach_timer, session_timer_close_cb);
+ session_start_auth_timer(s);
+ ssh_set_auth_methods(s->session, SSH_AUTH_METHOD_PASSWORD);
+ continue;
+ } else if (rc == SSH_AGAIN) {
+ int status = ssh_get_status(s->session);
+ if (status & SSH_WRITE_PENDING)
+ s->write_ready = 0;
+ if ((status & SSH_READ_PENDING) && spin++ < SSH_IO_PUMP_LIMIT) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server drive: buffered read pending during key exchange");
+ }
+ continue;
+ }
+ session_update_poll(s);
+ return;
+ } else {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH key exchange failed: %s", ssh_get_error(s->session));
+ session_fail(s, errmsg);
+ return;
+ }
+ }
+
+ if (s->state == SESSION_STATE_AUTH) {
+ if (s->pending_auth != NULL) {
+ session_update_poll(s);
+ return;
+ }
+ ssh_message msg = ssh_message_get(s->session);
+ if (msg == NULL) {
+ session_update_poll(s);
+ return;
+ }
+ int type = ssh_message_type(msg);
+ if (type == SSH_REQUEST_SERVICE) {
+ int rc;
+ const char *service = ssh_message_service_service(msg);
+ if (service && strcmp(service, "ssh-userauth") == 0) {
+ rc = ssh_message_service_reply_success(msg);
+ } else {
+ rc = ssh_message_reply_default(msg);
+ }
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH service reply failed") != 0)
+ return;
+ continue;
+ }
+ if (type == SSH_REQUEST_AUTH && ssh_message_subtype(msg) == SSH_AUTH_METHOD_PASSWORD) {
+ if (s->on_auth == NULL) {
+ int rc;
+ ssh_message_auth_set_methods(msg, SSH_AUTH_METHOD_PASSWORD);
+ rc = ssh_message_reply_default(msg);
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH auth reject failed") != 0)
+ return;
+ continue;
+ }
+ const char *user = ssh_message_auth_user(msg);
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#endif
+ const char *pass = ssh_message_auth_password(msg);
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
+ s->pending_auth = msg;
+ sshQ_AuthRequest req = sshQ_AuthRequestG_new(
+ to$str((char *)"password"),
+ to$str((char *)(user ? user : "")),
+ pass ? to$str((char *)(pass)) : B_None,
+ B_None);
+ $action2 f = ($action2)s->on_auth;
+ f->$class->__asyn__(f, session_actor_ref(s), req);
+ session_update_poll(s);
+ return;
+ }
+ int rc = ssh_message_reply_default(msg);
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH auth reply failed") != 0)
+ return;
+ continue;
+ }
+
+ if (s->state == SESSION_STATE_READY) {
+ if (s->pending_channel_open != NULL) {
+ session_drive_channels(s);
+ session_update_poll(s);
+ return;
+ }
+ while (1) {
+ ssh_message msg = ssh_message_get(s->session);
+ if (msg == NULL)
+ break;
+ int type = ssh_message_type(msg);
+ if (type == SSH_REQUEST_CHANNEL_OPEN) {
+ if (s->pending_channel_open != NULL) {
+ int rc = ssh_message_reply_default(msg);
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0)
+ return;
+ } else if (ssh_message_subtype(msg) != SSH_CHANNEL_SESSION) {
+ int rc = ssh_message_reply_default(msg);
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0)
+ return;
+ } else if (session_channel_limit_reached(s)) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server session: rejecting channel open limit=%d",
+ s->server ? s->server->max_channels_per_session : 0);
+ }
+ int rc = ssh_message_reply_default(msg);
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0)
+ return;
+ } else if (s->on_channel_open == NULL) {
+ int rc = ssh_message_reply_default(msg);
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0)
+ return;
+ } else {
+ s->pending_channel_open = msg;
+ $action f = ($action)s->on_channel_open;
+ f->$class->__asyn__(f, session_actor_ref(s));
+ break;
+ }
+ } else if (type == SSH_REQUEST_CHANNEL) {
+ ssh_channel chan = ssh_message_channel_request_channel(msg);
+ ssh_server_channel_ctx *ch = server_channel_from_ssh(s, chan);
+ if (ch == NULL || ch->pending_req != NULL) {
+ int rc = ssh_message_reply_default(msg);
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH channel request reject failed") != 0)
+ return;
+ } else if (ssh_message_subtype(msg) == SSH_CHANNEL_REQUEST_EXEC) {
+ if (s->on_exec == NULL) {
+ int rc = ssh_message_reply_default(msg);
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH exec reject failed") != 0)
+ return;
+ } else {
+ const char *cmd = ssh_message_channel_request_command(msg);
+ ch->pending_req = msg;
+ ch->pending_req_type = SCHAN_REQ_EXEC;
+ $action3 f = ($action3)s->on_exec;
+ f->$class->__asyn__(f, session_actor_ref(s), server_channel_actor_ref(ch),
+ to$str((char *)(cmd ? cmd : "")));
+ break;
+ }
+ } else if (ssh_message_subtype(msg) == SSH_CHANNEL_REQUEST_SUBSYSTEM) {
+ if (s->on_subsystem == NULL) {
+ int rc = ssh_message_reply_default(msg);
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH subsystem reject failed") != 0)
+ return;
+ } else {
+ const char *name = ssh_message_channel_request_subsystem(msg);
+ ch->pending_req = msg;
+ ch->pending_req_type = SCHAN_REQ_SUBSYSTEM;
+ $action3 f = ($action3)s->on_subsystem;
+ f->$class->__asyn__(f, session_actor_ref(s), server_channel_actor_ref(ch),
+ to$str((char *)(name ? name : "")));
+ break;
+ }
+ } else {
+ int rc = ssh_message_reply_default(msg);
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH channel request reject failed") != 0)
+ return;
+ }
+ } else if (type == SSH_REQUEST_SERVICE) {
+ int rc;
+ const char *service = ssh_message_service_service(msg);
+ if (service && strcmp(service, "ssh-connection") == 0) {
+ rc = ssh_message_service_reply_success(msg);
+ } else {
+ rc = ssh_message_reply_default(msg);
+ }
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH connection service reply failed") != 0)
+ return;
+ } else {
+ int rc = ssh_message_reply_default(msg);
+ ssh_message_free(msg);
+ if (session_check_reply_rc(s, rc, "SSH request reply failed") != 0)
+ return;
+ }
+ }
+ session_drive_channels(s);
+ session_update_poll(s);
+ return;
+ }
+ return;
+ }
+}
+
+static void session_poll_cb(uv_poll_t *handle, int status, int events) {
+ ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data;
+ if (s == NULL)
+ return;
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server poll: status=%d events=0x%x state=%d", status, events, s->state);
+ }
+ if (status < 0) {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH session poll error: %s", uv_strerror(status));
+ session_fail(s, errmsg);
+ return;
+ }
+ int libssh_events = 0;
+ if ((events & UV_READABLE) && fd_has_data(s->fd)) {
+ ssh_set_fd_toread(s->session);
+ libssh_events |= UV_READABLE;
+ }
+#ifdef UV_DISCONNECT
+ if (events & UV_DISCONNECT) {
+ ssh_set_fd_toread(s->session);
+ libssh_events |= UV_DISCONNECT;
+ }
+#endif
+ if ((events & UV_WRITABLE) && fd_can_write(s->fd)) {
+ s->write_ready = 1;
+ ssh_set_fd_towrite(s->session);
+ libssh_events |= UV_WRITABLE;
+ }
+ if (session_apply_poll_events(s->session, libssh_events) != 0) {
+ char errmsg[256] = {0};
+ format_session_error(s->session, "SSH poll callback error", errmsg, sizeof(errmsg));
+ session_fail(s, errmsg);
+ return;
+ }
+ session_drive(s);
+ session_pump_io(s);
+ s->write_ready = 0;
+}
+
+static void server_poll_cb(uv_poll_t *handle, int status, int events) {
+ ssh_server_ctx *s = (ssh_server_ctx *)handle->data;
+ if (s == NULL)
+ return;
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server listen poll: status=%d events=0x%x", status, events);
+ }
+ if (status < 0) {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH server poll error: %s", uv_strerror(status));
+ server_fail(s, errmsg);
+ return;
+ }
+ if (events & UV_READABLE)
+ server_accept(s);
+}
+
+static void server_accept(ssh_server_ctx *s) {
+ if (s == NULL || s->state != SERVER_STATE_LISTENING)
+ return;
+
+ int accepted = 0;
+ while (accepted < SSH_SERVER_ACCEPT_LIMIT) {
+ socket_t fd = accept(s->fd, NULL, NULL);
+ if (fd == SSH_INVALID_SOCKET) {
+ if (errno == EINTR) {
+ continue;
+ }
+ if (errno == EAGAIN || errno == EWOULDBLOCK) {
+ return;
+ }
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH accept failed: %s", strerror(errno));
+ server_fail(s, errmsg);
+ return;
+ }
+ accepted++;
+
+ if (server_session_limit_reached(s)) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server accept: rejecting fd=%d session limit=%d",
+ (int)fd, s->max_sessions);
+ }
+ close(fd);
+ continue;
+ }
+
+ if (fd_set_nonblocking(fd) != 0) {
+ log_warn("SSH accept: failed to set accepted fd nonblocking");
+ close(fd);
+ continue;
+ }
+
+ ssh_session session = ssh_new();
+ if (session == NULL) {
+ log_warn("SSH accept: failed to create SSH session");
+ close(fd);
+ continue;
+ }
+
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server accept: actor=%p", (void *)server_actor_ref(s));
+ }
+
+ int rc = ssh_bind_accept_fd(s->bind, session, fd);
+ if (rc != SSH_OK) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server accept: accept_fd failed: %s", ssh_get_error(s->bind));
+ }
+ close(fd);
+ ssh_free(session);
+ continue;
+ }
+
+ ssh_set_blocking(session, 0);
+ ssh_server_session_ctx *sess = acton_calloc(1, sizeof(ssh_server_session_ctx));
+ sess->server = s;
+ sess->session = session;
+ sess->state = SESSION_STATE_KEYEX;
+ sess->pending_id = alloc_pending_session_id();
+ sess->fd = ssh_get_fd(session);
+ sshQ_Server act = server_actor_ref(s);
+ sess->owner_wt = act ? (int)act->$affinity : 0;
+ sess->auth_timeout = act ? fromB_float(act->_auth_timeout) : 0.0;
+ if (sess->fd < 0) {
+ log_warn("SSH accept: failed to get accepted session fd");
+ ssh_disconnect(session);
+ ssh_free(session);
+ continue;
+ }
+ session_start_attach_timer(sess);
+
+ sess->next = s->sessions;
+ s->sessions = sess;
+
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server accept: session fd=%d", sess->fd);
+ }
+
+ if (act) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server accept: scheduling session pending act=%p session=%llu",
+ (void *)act, (unsigned long long)sess->pending_id);
+ }
+ act->$class->on_session_pending(act, toB_u64(sess->pending_id));
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server accept: on_session_pending call returned session=%llu",
+ (unsigned long long)sess->pending_id);
+ }
+ }
+ }
+}
+
+static void server_remove_session(ssh_server_ctx *s, ssh_server_session_ctx *sess) {
+ if (s == NULL || sess == NULL)
+ return;
+ ssh_server_session_ctx *prev = NULL;
+ ssh_server_session_ctx *cur = s->sessions;
+ while (cur != NULL) {
+ if (cur == sess) {
+ if (prev != NULL)
+ prev->next = cur->next;
+ else
+ s->sessions = cur->next;
+ break;
+ }
+ prev = cur;
+ cur = cur->next;
+ }
+ server_maybe_release(s);
+}
+
+static void server_finalize(ssh_server_ctx *s) {
+ if (s == NULL || s->close_finalized)
+ return;
+ s->close_finalized = 1;
+
+ if (s->bind != NULL) {
+ ssh_bind_free(s->bind);
+ s->bind = NULL;
+ }
+ if (s->hostkey != NULL) {
+ ssh_key_free(s->hostkey);
+ s->hostkey = NULL;
+ }
+
+ server_notify_close(s, s->close_reason ? s->close_reason : "closed");
+ s->state = SERVER_STATE_CLOSED;
+ sshQ_Server actor = server_actor_ref(s);
+ if (actor)
+ actor->_server = toB_u64(0);
+ STORE_HIDDEN_PTR(s->actor, NULL);
+ server_maybe_release(s);
+}
+
+static void session_finalize(ssh_server_session_ctx *s) {
+ if (s == NULL || s->close_finalized)
+ return;
+ s->close_finalized = 1;
+
+ if (s->session != NULL) {
+ ssh_disconnect(s->session);
+ ssh_free(s->session);
+ s->session = NULL;
+ }
+ session_free_retired_channels(s);
+
+ server_remove_session(s->server, s);
+ session_notify_close(s, s->close_reason ? s->close_reason : "closed");
+ s->state = SESSION_STATE_CLOSED;
+ s->pending_id = 0;
+ sshQ_ServerSession actor = session_actor_ref(s);
+ if (actor)
+ actor->_session_id = toB_u64(0);
+ s->actor = NULL;
+ session_maybe_release(s);
+}
+
+static void session_reject_pending_messages(ssh_server_session_ctx *s) {
+ if (s->pending_auth) {
+ ssh_message_reply_default(s->pending_auth);
+ ssh_message_free(s->pending_auth);
+ s->pending_auth = NULL;
+ }
+ if (s->pending_channel_open) {
+ ssh_message_reply_default(s->pending_channel_open);
+ ssh_message_free(s->pending_channel_open);
+ s->pending_channel_open = NULL;
+ }
+}
+
+static void session_abort_channels(ssh_server_session_ctx *s) {
+ ssh_server_channel_ctx *ch = s->channels;
+ while (ch != NULL) {
+ ssh_server_channel_ctx *next = ch->next;
+ server_channel_notify_close(ch, "Session closed");
+ server_channel_finalize(ch);
+ session_retire_channel(s, ch);
+ ch = next;
+ }
+ s->channels = NULL;
+}
+
+static void session_request_channel_close(ssh_server_session_ctx *s) {
+ ssh_server_channel_ctx *ch = s->channels;
+ while (ch != NULL) {
+ if (ch->pending_req) {
+ ssh_message_reply_default(ch->pending_req);
+ ssh_message_free(ch->pending_req);
+ ch->pending_req = NULL;
+ ch->pending_req_type = SCHAN_REQ_NONE;
+ }
+ if (ch->state != SCHAN_STATE_CLOSED && ch->state != SCHAN_STATE_ERROR) {
+ ch->send_eof = 1;
+ ch->close_requested = 1;
+ }
+ ch = ch->next;
+ }
+}
+
+static void session_finish_close(ssh_server_session_ctx *s) {
+ if (s == NULL || s->state != SESSION_STATE_CLOSING)
+ return;
+ if (s->close_force) {
+ if (s->poll != NULL) {
+ close_poll(&s->poll, session_poll_close_cb);
+ s->poll_events = 0;
+ return;
+ }
+ session_finalize(s);
+ return;
+ }
+ if (s->channels != NULL) {
+ session_update_poll(s);
+ return;
+ }
+ if (s->session != NULL && session_has_pending_write(s->session)) {
+ session_update_poll(s);
+ return;
+ }
+ if (s->poll != NULL) {
+ close_poll(&s->poll, session_poll_close_cb);
+ s->poll_events = 0;
+ return;
+ }
+ session_finalize(s);
+}
+
+static void server_close_internal(ssh_server_ctx *s, const char *reason) {
+ if (s == NULL || s->state == SERVER_STATE_CLOSED || s->state == SERVER_STATE_CLOSING)
+ return;
+ int force_sessions = (s->state == SERVER_STATE_ERROR);
+
+ if (!s->listen_ok && !s->listen_notified) {
+ server_notify_listen(s, reason ? reason : "closed");
+ }
+
+ s->state = SERVER_STATE_CLOSING;
+ if (reason != NULL && s->close_reason == NULL)
+ s->close_reason = acton_strdup(reason);
+
+ if (s->poll != NULL) {
+ close_poll(&s->poll, server_poll_close_cb);
+ }
+
+ ssh_server_session_ctx *sess = s->sessions;
+ while (sess != NULL) {
+ ssh_server_session_ctx *next = sess->next;
+ session_close_internal(sess, "Server closed", force_sessions);
+ sess = next;
+ }
+
+ if (s->poll != NULL)
+ return;
+
+ server_finalize(s);
+}
+
+static void session_close_internal(ssh_server_session_ctx *s, const char *reason, int force_close) {
+ if (s == NULL || s->state == SESSION_STATE_CLOSED)
+ return;
+ if (!force_close && s->state != SESSION_STATE_READY)
+ force_close = 1;
+ if (reason != NULL && s->close_reason == NULL)
+ s->close_reason = acton_strdup(reason);
+ if (s->state == SESSION_STATE_CLOSING) {
+ if (force_close && !s->close_force) {
+ s->close_force = 1;
+ session_reject_pending_messages(s);
+ session_abort_channels(s);
+ }
+ session_finish_close(s);
+ return;
+ }
+
+ stop_timer(&s->auth_timer, session_timer_close_cb);
+ stop_timer(&s->attach_timer, session_timer_close_cb);
+ stop_timer(&s->keepalive_timer, session_timer_close_cb);
+
+ s->state = SESSION_STATE_CLOSING;
+ s->close_force = force_close;
+ session_reject_pending_messages(s);
+ if (force_close) {
+ session_abort_channels(s);
+ session_finish_close(s);
+ return;
+ }
+
+ session_request_channel_close(s);
+ session_drive(s);
+}
+
+static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_out) {
+ if (type_str == NULL)
+ return SSH_KEYTYPE_UNKNOWN;
+ if (strcmp(type_str, "ed25519") == 0) {
+ if (param_out)
+ *param_out = 0;
+ return SSH_KEYTYPE_ED25519;
+ }
+ if (strcmp(type_str, "rsa") == 0) {
+ if (param_out)
+ *param_out = 2048;
+ return SSH_KEYTYPE_RSA;
+ }
+ if (strcmp(type_str, "ecdsa") == 0) {
+ if (param_out)
+ *param_out = 256;
+ return SSH_KEYTYPE_ECDSA;
+ }
+ return SSH_KEYTYPE_UNKNOWN;
+}
+
+$R sshQ_ServerD__pin_affinityG_local(sshQ_Server self, $Cont c$cont) {
+ pin_actor_affinity();
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerD__initG_local(sshQ_Server self, $Cont c$cont) {
+ ssh_configure_libssh_logging();
+ if (ssh_debug_enabled) {
+ log_info("ssh server init: self=%p", (void *)self);
+ }
+ ssh_server_ctx *s = acton_calloc(1, sizeof(ssh_server_ctx));
+ STORE_HIDDEN_PTR(s->actor, self);
+ s->on_listen = ($action2)self->_on_listen;
+ s->on_close = ($action2)self->_on_close;
+ s->state = SERVER_STATE_INIT;
+ s->max_sessions = fromB_int(self->_max_sessions);
+ s->max_channels_per_session = fromB_int(self->_max_channels_per_session);
+
+ self->_server = toB_u64((unsigned long)s);
+
+ s->bind = ssh_bind_new();
+ if (s->bind == NULL) {
+ server_fail(s, "Failed to create SSH bind");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ int rc;
+ bool process_config = false;
+
+ rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_PROCESS_CONFIG, &process_config);
if (rc != SSH_OK) {
- //log_error("Error connecting to SSH server: %s", ssh_get_error(session));
- $action2 f = ($action2) self->on_close;
- f->$class->__asyn__(f, self, to$str(ssh_get_error(session)));
+ server_fail(s, "Failed to disable SSH bind config processing");
return $R_CONT(c$cont, B_None);
}
- rc = ssh_userauth_password(session, NULL, fromB_str(self->password));
- if (rc == SSH_OK) {
- printf("Connected\n");
- show_remote_processes(session);
+ const char *host = (const char *)fromB_str(self->_host);
+ int port = (int)fromB_u16(self->_port);
+ rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_BINDADDR, host);
+ if (rc != SSH_OK) {
+ server_fail(s, "Failed to set bind address");
+ return $R_CONT(c$cont, B_None);
+ }
+ rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_BINDPORT, &port);
+ if (rc != SSH_OK) {
+ server_fail(s, "Failed to set bind port");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ if (self->_host_key_path != B_None) {
+ const char *path = (const char *)fromB_str(self->_host_key_path);
+ rc = ssh_pki_import_privkey_file(path, NULL, NULL, NULL, &s->hostkey);
+ if (rc != SSH_OK) {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "Failed to load host key: %s", ssh_get_error(s->bind));
+ server_fail(s, errmsg);
+ return $R_CONT(c$cont, B_None);
+ }
} else {
- printf("Error: %s\n", ssh_get_error(session));
+ const char *type_str = (const char *)fromB_str(self->_host_key_type);
+ int param = fromB_int(self->_host_key_bits);
+ int default_param = 0;
+ enum ssh_keytypes_e type = parse_hostkey_type(type_str, &default_param);
+ if (type == SSH_KEYTYPE_UNKNOWN) {
+ server_fail(s, "Unsupported host key type");
+ return $R_CONT(c$cont, B_None);
+ }
+ if (type == SSH_KEYTYPE_ED25519) {
+ param = 0;
+ } else if (param <= 0) {
+ param = default_param;
+ }
+ rc = ssh_pki_generate(type, param, &s->hostkey);
+ if (rc != SSH_OK) {
+ server_fail(s, "Failed to generate host key");
+ return $R_CONT(c$cont, B_None);
+ }
}
-// self->_connected = true;
-// $action f = ($action) self->on_connect;
-// f->$class->__asyn__(f, self);
+ rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_IMPORT_KEY, s->hostkey);
+ if (rc != SSH_OK) {
+ server_fail(s, "Failed to set host key");
+ return $R_CONT(c$cont, B_None);
+ }
+ /* ssh_bind takes ownership of IMPORT_KEY and frees it via ssh_bind_free(). */
+ s->hostkey = NULL;
+
+ ssh_bind_set_blocking(s->bind, 0);
+
+ rc = ssh_bind_listen(s->bind);
+ if (rc != SSH_OK) {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH listen failed: %s", ssh_get_error(s->bind));
+ server_fail(s, errmsg);
+ return $R_CONT(c$cont, B_None);
+ }
+
+ s->fd = ssh_bind_get_fd(s->bind);
+ if (s->fd < 0) {
+ server_fail(s, "Failed to get bind fd");
+ return $R_CONT(c$cont, B_None);
+ }
+ if (fd_set_nonblocking(s->fd) != 0) {
+ server_fail(s, "Failed to set bind fd nonblocking");
+ return $R_CONT(c$cont, B_None);
+ }
+
+ s->poll = acton_calloc(1, sizeof(uv_poll_t));
+ s->poll->data = s;
+ int uv_rc = uv_poll_init(get_uv_loop(), s->poll, s->fd);
+ if (uv_rc != 0) {
+ char errmsg[256] = {0};
+ uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg));
+ server_fail(s, errmsg);
+ return $R_CONT(c$cont, B_None);
+ }
+ uv_rc = uv_poll_start(s->poll, UV_READABLE, server_poll_cb);
+ if (uv_rc != 0) {
+ char errmsg[256] = {0};
+ uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg));
+ server_fail(s, errmsg);
+ return $R_CONT(c$cont, B_None);
+ }
+
+ s->state = SERVER_STATE_LISTENING;
+ server_notify_listen(s, NULL);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerD_closeG_local(sshQ_Server self, $Cont c$cont) {
+ ssh_server_ctx *s = server_from_actor(self);
+ if (s == NULL)
+ return $R_CONT(c$cont, B_None);
+ server_close_internal(s, "closed");
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerD__cleanup_nativeG_local(sshQ_Server self, $Cont c$cont) {
+ ssh_server_ctx *s = server_from_actor(self);
+ if (s != NULL)
+ server_close_internal(s, "collected");
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD__pin_affinityG_local(sshQ_ServerSession self, $Cont c$cont) {
+ ssh_server_ctx *server = server_from_actor(self->server);
+ ssh_server_session_ctx *s = session_from_pending_token(server, self->session_id);
+ if (s != NULL && s->owner_wt >= 0) {
+ set_actor_affinity(s->owner_wt);
+ } else {
+ pin_actor_affinity();
+ }
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD__attachG_local(sshQ_ServerSession self, $Cont c$cont, B_u64 session_id) {
+ ssh_server_ctx *server = server_from_actor(self->server);
+ ssh_server_session_ctx *s = session_from_pending_token(server, session_id);
+ if (s == NULL)
+ return $R_CONT(c$cont, B_None);
+ if (s->session == NULL || s->state == SESSION_STATE_CLOSED ||
+ s->state == SESSION_STATE_CLOSING || s->state == SESSION_STATE_ERROR) {
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server session attach skipped closed id=%llu",
+ (unsigned long long)fromB_u64(session_id));
+ }
+ return $R_CONT(c$cont, B_None);
+ }
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server session attach: id=%llu", (unsigned long long)fromB_u64(session_id));
+ }
+ char errmsg[256] = {0};
+ if (session_start_poll(s, errmsg, sizeof(errmsg)) != 0) {
+ session_close_internal(s, errmsg, 1);
+ return $R_CONT(c$cont, B_None);
+ }
+ s->actor = self;
+ self->_session_id = toB_u64((unsigned long)s);
+ s->attached = 1;
+ s->pending_id = 0;
+ s->on_auth = ($action2)self->_on_auth;
+ s->on_channel_open = ($action)self->_on_channel_open;
+ s->on_exec = (self->_on_exec == B_None) ? NULL : ($action3)self->_on_exec;
+ s->on_subsystem = (self->_on_subsystem == B_None) ? NULL : ($action3)self->_on_subsystem;
+ s->on_close = (self->_on_close == B_None) ? NULL : ($action2)self->_on_close;
+ s->auth_timeout = fromB_float(self->server->_auth_timeout);
+ s->keepalive_interval = fromB_float(self->server->_keepalive_interval);
+ s->keepalive_enabled = fromB_bool(self->server->_keepalive_enabled) ? 1 : 0;
+ if (s->state == SESSION_STATE_KEYEX) {
+ stop_timer(&s->attach_timer, session_timer_close_cb);
+ session_start_keyex_timer(s);
+ }
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server session attach: callbacks set");
+ }
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD__drive_attachedG_local(sshQ_ServerSession self, $Cont c$cont) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ if (s == NULL || !s->attached)
+ return $R_CONT(c$cont, B_None);
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server session drive attached");
+ }
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_accept_authG_local(sshQ_ServerSession self, $Cont c$cont) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ if (s == NULL || s->pending_auth == NULL)
+ return $R_CONT(c$cont, B_None);
+
+ int rc = ssh_message_auth_reply_success(s->pending_auth, 0);
+ ssh_message_free(s->pending_auth);
+ s->pending_auth = NULL;
+ if (session_check_reply_rc(s, rc, "SSH auth accept failed") != 0)
+ return $R_CONT(c$cont, B_None);
+ s->state = SESSION_STATE_READY;
+ stop_timer(&s->auth_timer, session_timer_close_cb);
+ session_start_keepalive(s);
+ session_drive(s);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_reject_authG_local(sshQ_ServerSession self, $Cont c$cont, B_str reason) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ if (s == NULL || s->pending_auth == NULL)
+ return $R_CONT(c$cont, B_None);
+
+ ssh_message_auth_set_methods(s->pending_auth, SSH_AUTH_METHOD_PASSWORD);
+ int rc = ssh_message_reply_default(s->pending_auth);
+ ssh_message_free(s->pending_auth);
+ s->pending_auth = NULL;
+ (void)reason;
+ if (session_check_reply_rc(s, rc, "SSH auth reject failed") != 0)
+ return $R_CONT(c$cont, B_None);
+ session_drive(s);
+
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_accept_channel_openG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel,
+ $action on_data,
+ $action on_stderr,
+ $action on_close) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ if (s == NULL || s->pending_channel_open == NULL)
+ return $R_CONT(c$cont, B_None);
+
+ ssh_channel chan = ssh_channel_new(s->session);
+ if (chan == NULL) {
+ int rc = ssh_message_reply_default(s->pending_channel_open);
+ ssh_message_free(s->pending_channel_open);
+ s->pending_channel_open = NULL;
+ if (on_close) {
+ $action2 f = ($action2)on_close;
+ f->$class->__asyn__(f, channel, to$str((char *)"Failed to accept channel open"));
+ }
+ if (session_check_reply_rc(s, rc, "SSH channel open accept failed") != 0)
+ return $R_CONT(c$cont, B_None);
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+ }
+ int rc = ssh_message_channel_request_open_reply_accept_channel(s->pending_channel_open, chan);
+ if (rc != SSH_OK && rc != SSH_AGAIN) {
+ ssh_channel_free(chan);
+ rc = ssh_message_reply_default(s->pending_channel_open);
+ ssh_message_free(s->pending_channel_open);
+ s->pending_channel_open = NULL;
+ if (on_close) {
+ $action2 f = ($action2)on_close;
+ f->$class->__asyn__(f, channel, to$str((char *)"Failed to accept channel open"));
+ }
+ if (session_check_reply_rc(s, rc, "SSH channel open accept failed") != 0)
+ return $R_CONT(c$cont, B_None);
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+ }
+ if (rc == SSH_AGAIN)
+ s->write_ready = 0;
+ ssh_channel_set_blocking(chan, 0);
+ ssh_message_free(s->pending_channel_open);
+ s->pending_channel_open = NULL;
+
+ ssh_server_channel_ctx *ch = acton_calloc(1, sizeof(ssh_server_channel_ctx));
+ ch->channel = chan;
+ ch->session = s;
+ ch->actor = channel;
+ ch->callbacks = NULL;
+ ch->state = SCHAN_STATE_OPEN;
+ ch->send_eof = 0;
+ ch->close_requested = 0;
+ ch->close_sent = 0;
+ ch->eof_sent = 0;
+ ch->stdout_eof = 0;
+ ch->stderr_eof = 0;
+ ch->pending_req = NULL;
+ ch->pending_req_type = SCHAN_REQ_NONE;
+ ch->on_data = ($action2)on_data;
+ ch->on_stderr = ($action2)on_stderr;
+ ch->on_close = ($action2)on_close;
+ if (server_channel_setup_callbacks(ch) != SSH_OK) {
+ server_channel_notify_close(ch, "Failed to set SSH channel callbacks");
+ if (ch->channel != NULL)
+ ssh_channel_close(ch->channel);
+ server_channel_finalize(ch);
+ acton_free(ch);
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+ }
+ if (ssh_debug_enabled) {
+ ssh_debug_log("server channel new ch=%p callbacks=%p", (void *)ch, (void *)ch->callbacks);
+ }
+
+ ch->next = s->channels;
+ s->channels = ch;
+ channel->_channel_id = toB_u64((unsigned long)ch);
+
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_reject_channelG_local(sshQ_ServerSession self, $Cont c$cont, B_str reason) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ if (s == NULL || s->pending_channel_open == NULL)
+ return $R_CONT(c$cont, B_None);
+
+ int rc = ssh_message_reply_default(s->pending_channel_open);
+ ssh_message_free(s->pending_channel_open);
+ s->pending_channel_open = NULL;
+ (void)reason;
+ if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0)
+ return $R_CONT(c$cont, B_None);
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_closeG_local(sshQ_ServerSession self, $Cont c$cont) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ if (s == NULL)
+ return $R_CONT(c$cont, B_None);
+ session_close_internal(s, "closed", 0);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD__cleanup_nativeG_local(sshQ_ServerSession self, $Cont c$cont) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ if (s != NULL)
+ session_close_internal(s, "collected", 1);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_channel_accept_requestG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ ssh_server_channel_ctx *ch = server_channel_from_actor(channel);
+ if (server_channel_validate(s, ch) != 0 || ch->pending_req == NULL)
+ return $R_CONT(c$cont, B_None);
+
+ int rc = ssh_message_channel_request_reply_success(ch->pending_req);
+ ssh_message_free(ch->pending_req);
+ ch->pending_req = NULL;
+ ch->pending_req_type = SCHAN_REQ_NONE;
+ if (session_check_reply_rc(s, rc, "SSH channel request accept failed") != 0)
+ return $R_CONT(c$cont, B_None);
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_channel_reject_requestG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_str reason) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ ssh_server_channel_ctx *ch = server_channel_from_actor(channel);
+ if (server_channel_validate(s, ch) != 0 || ch->pending_req == NULL)
+ return $R_CONT(c$cont, B_None);
+
+ int rc = ssh_message_reply_default(ch->pending_req);
+ ssh_message_free(ch->pending_req);
+ ch->pending_req = NULL;
+ ch->pending_req_type = SCHAN_REQ_NONE;
+ (void)reason;
+ if (session_check_reply_rc(s, rc, "SSH channel request reject failed") != 0)
+ return $R_CONT(c$cont, B_None);
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_channel_writeG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_bytes data) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ ssh_server_channel_ctx *ch = server_channel_from_actor(channel);
+ if (server_channel_validate(s, ch) != 0)
+ return $R_CONT(c$cont, B_None);
+ server_channel_queue_write(ch, data, 0);
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_channel_write_stderrG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_bytes data) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ ssh_server_channel_ctx *ch = server_channel_from_actor(channel);
+ if (server_channel_validate(s, ch) != 0)
+ return $R_CONT(c$cont, B_None);
+ server_channel_queue_write(ch, data, 1);
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_channel_send_eofG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ ssh_server_channel_ctx *ch = server_channel_from_actor(channel);
+ if (server_channel_validate(s, ch) != 0)
+ return $R_CONT(c$cont, B_None);
+ ch->send_eof = 1;
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_channel_send_exit_statusG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_int status) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ ssh_server_channel_ctx *ch = server_channel_from_actor(channel);
+ if (server_channel_validate(s, ch) != 0 || ch->channel == NULL)
+ return $R_CONT(c$cont, B_None);
+ int rc = ssh_channel_request_send_exit_status(ch->channel, fromB_int(status));
+ if (rc != SSH_OK && rc != SSH_AGAIN) {
+ char errmsg[256] = {0};
+ snprintf(errmsg, sizeof(errmsg), "SSH server send exit status failed: %s", ssh_get_error(s->session));
+ server_channel_notify_close(ch, errmsg);
+ ch->state = SCHAN_STATE_ERROR;
+ }
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerSessionD_channel_closeG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) {
+ ssh_server_session_ctx *s = session_from_actor(self);
+ ssh_server_channel_ctx *ch = server_channel_from_actor(channel);
+ if (server_channel_validate(s, ch) != 0)
+ return $R_CONT(c$cont, B_None);
+ ch->send_eof = 1;
+ ch->close_requested = 1;
+ session_drive(s);
+ return $R_CONT(c$cont, B_None);
+}
+
+$R sshQ_ServerChannelD__cleanup_nativeG_local(sshQ_ServerChannel self, $Cont c$cont) {
+ ssh_server_channel_ctx *ch = server_channel_from_actor(self);
+ if (ch == NULL || ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR)
+ return $R_CONT(c$cont, B_None);
+ ssh_server_session_ctx *s = ch->session;
+ if (s == NULL)
+ return $R_CONT(c$cont, B_None);
+ ch->send_eof = 1;
+ ch->close_requested = 1;
+ session_drive(s);
return $R_CONT(c$cont, B_None);
}
diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act
new file mode 100644
index 0000000..d59c60f
--- /dev/null
+++ b/src/test_ssh_server.act
@@ -0,0 +1,5881 @@
+import random
+import testing
+import logging
+
+import acton.rts
+import net
+import ssh
+
+LISTEN_RETRY_LIMIT = 16
+
+
+def _pick_test_port(base: int, span: int, attempt: int, salt: int) -> int:
+ return base + (random.randint(0, span - 1) + attempt * salt) % span
+
+
+def _is_retryable_listen_error(err: str) -> bool:
+ return err.find("Address already in use") >= 0
+
+
+actor ServerClientTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+ NETCONF_END = b"]]>]]>"
+ NETCONF_HELLO = b'\n' + \
+ b'\n' + \
+ b' \n' + \
+ b' urn:ietf:params:netconf:base:1.0\n' + \
+ b' \n' + \
+ b']]>]]>'
+ SERVER_HELLO = b'urn:ietf:params:netconf:base:1.0]]>]]>'
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var server_session: ?ssh.ServerSession = None
+ var server_buf = b""
+ var client_buf = b""
+ var client_channel: ?ssh.Channel = None
+ var shutdown_started = False
+ var client_closed = False
+ var server_closed = False
+ var session_closed = False
+ var client_channel_closed = False
+ var server_channel_closed = False
+ var client_close_requested = False
+ var server_close_requested = False
+ var client_channel_close_requested = False
+ var session_close_requested = False
+ var shutdown_timer_started = False
+
+ def maybe_finish():
+ if done:
+ return
+ if shutdown_started and client_closed and server_closed and session_closed and \
+ client_channel_closed and server_channel_closed:
+ log.info("test success", None)
+ done = True
+ t.success()
+
+ def request_client_channel_close():
+ if client_channel_close_requested:
+ return
+ client_channel_close_requested = True
+ if client_channel is not None:
+ client_channel.close()
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def request_session_close():
+ if session_close_requested or session_closed:
+ return
+ session_close_requested = True
+ if server_session is not None:
+ log.info("request session close", None)
+ server_session.close()
+ else:
+ log.info("request session close: no session", None)
+
+ def begin_shutdown(reason: str):
+ if shutdown_started:
+ return
+ shutdown_started = True
+ log.info("shutdown start: " + reason, None)
+ request_session_close()
+ request_client_channel_close()
+ request_client_close()
+ if not shutdown_timer_started:
+ shutdown_timer_started = True
+ after 10.0: on_timeout()
+ maybe_finish()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ log.info("test error: " + msg, None)
+ done = True
+ shutdown_started = True
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ log.info("timeout fired", None)
+ finish_error("timeout waiting for SSH test")
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if done:
+ return
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ log.info("server listening", None)
+ server = s
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ log.info("server session ready", None)
+ server_session = sess
+ return
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ log.info("auth request for " + req.user, None)
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ log.info("channel open requested", None)
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ server_channel = ch
+ sess.accept_channel(ch)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ log.info("subsystem request: " + name, None)
+ if name == "netconf":
+ ch.accept_request()
+ else:
+ ch.reject_request("unsupported subsystem")
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ log.info("exec request: " + cmd, None)
+ ch.reject_request("exec disabled")
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_closed = True
+ if shutdown_started:
+ request_server_close()
+ maybe_finish()
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ server_buf += data
+ log.info("server data len " + str(len(data)), None)
+ if server_buf.find(NETCONF_END) >= 0:
+ log.info("server got NETCONF hello", None)
+ ch.write(SERVER_HELLO)
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ log.info("server channel closed: " + reason, None)
+ server_channel_closed = True
+ maybe_finish()
+
+ def on_hostkey(client: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ log.info("hostkey state: " + state, None)
+ client.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ log.info("client connect error: " + err, None)
+ finish_error("client connect error: " + err)
+ return
+ log.info("client connected", None)
+ client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_closed = True
+ if shutdown_started:
+ request_session_close()
+ maybe_finish()
+
+ def ch_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ log.info("client channel open error: " + err, None)
+ finish_error("client channel open error: " + err)
+ return
+ log.info("client channel open", None)
+ ch.request_subsystem("netconf")
+ ch.write(NETCONF_HELLO)
+
+ def ch_out(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ client_buf += data
+ log.info("client data len " + str(len(data)), None)
+ if client_buf.find(NETCONF_END) >= 0:
+ log.info("client got NETCONF hello", None)
+ begin_shutdown("client got NETCONF hello")
+
+ def ch_err(ch: ssh.Channel, data: ?bytes):
+ return
+
+ def ch_exit(ch: ssh.Channel, code: int, sig: ?str):
+ log.info("client exit " + str(code), None)
+ return
+
+ def ch_close(ch: ssh.Channel, reason: str):
+ log.info("client channel closed: " + reason, None)
+ client_channel_closed = True
+ if shutdown_started and client is not None:
+ client.close()
+ maybe_finish()
+
+ def start_client():
+ log.info("starting client", None)
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 42000 + (random.randint(0, 2000) + attempts * 37) % 2000
+ log.info("starting server on " + str(port), None)
+ server_close_requested = False
+ server_closed = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+
+
+def _test_ssh_server_subsystem(t: testing.EnvT):
+ """Exercise server subsystem handling with a local client."""
+ ServerClientTester(t)
+
+
+actor ClientCloseFlushTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ var done = False
+ var attempts = 0
+ var payload_len = 1024 * 1024
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var server_buf = b""
+ var client_close_requested = False
+ var server_close_requested = False
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def finish_success():
+ if done:
+ return
+ done = True
+ log.info("test success", None)
+ request_client_close()
+ request_server_close()
+ t.success()
+
+ def on_timeout():
+ if done:
+ return
+ finish_error("timeout waiting for client close flush test")
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ log.info("server listening", None)
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+
+ def on_session(sess: ssh.ServerSession):
+ return
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ ch.reject_request("exec disabled")
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ if name == "netconf":
+ ch.accept_request()
+ else:
+ ch.reject_request("unsupported subsystem")
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ server_buf += data
+ return
+ if len(server_buf) != payload_len:
+ finish_error("server received " + str(len(server_buf)) + " bytes, expected " + str(payload_len))
+ return
+ finish_success()
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ return
+
+ def on_hostkey(client: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ client.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+
+ def ch_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("client channel open error: " + err)
+ return
+ ch.request_subsystem("netconf")
+ payload = b"A" * payload_len
+ ch.write(payload)
+ ch.close()
+
+ def ch_out(ch: ssh.Channel, data: ?bytes):
+ return
+
+ def ch_err(ch: ssh.Channel, data: ?bytes):
+ return
+
+ def ch_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch_close(ch: ssh.Channel, reason: str):
+ return
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 44000 + (random.randint(0, 2000) + attempts * 59) % 2000
+ server_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 15.0: on_timeout()
+
+
+def _test_ssh_client_close_flush(t: testing.EnvT):
+ """Client close should flush queued writes before EOF."""
+ ClientCloseFlushTester(t)
+
+
+actor ClientSessionCloseFlushTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ var done = False
+ var attempts = 0
+ var payload_len = 1024 * 1024
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var server_buf = b""
+ var client_close_requested = False
+ var server_close_requested = False
+ var saw_terminal_event = False
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def finish_success():
+ if done:
+ return
+ done = True
+ log.info("test success", None)
+ request_client_close()
+ request_server_close()
+ t.success()
+
+ def check_server_bytes(reason: str):
+ if done or not saw_terminal_event:
+ return
+ if len(server_buf) != payload_len:
+ finish_error(reason + ": server received " + str(len(server_buf)) + " bytes, expected " + str(payload_len))
+ return
+ finish_success()
+
+ def on_timeout():
+ if done:
+ return
+ finish_error("timeout waiting for client session close flush test")
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ log.info("server listening", None)
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+
+ def on_session(sess: ssh.ServerSession):
+ return
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ ch.reject_request("exec disabled")
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ if name == "netconf":
+ ch.accept_request()
+ else:
+ ch.reject_request("unsupported subsystem")
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ if done:
+ return
+ saw_terminal_event = True
+ check_server_bytes("session closed")
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ server_buf += data
+ if not client_close_requested:
+ request_client_close()
+ return
+ if done:
+ return
+ saw_terminal_event = True
+ check_server_bytes("server EOF")
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ if done:
+ return
+ saw_terminal_event = True
+ check_server_bytes("server channel close")
+
+ def on_hostkey(client: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ client.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+
+ def ch_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("client channel open error: " + err)
+ return
+ ch.request_subsystem("netconf")
+ payload = b"B" * payload_len
+ ch.write(payload)
+
+ def ch_out(ch: ssh.Channel, data: ?bytes):
+ return
+
+ def ch_err(ch: ssh.Channel, data: ?bytes):
+ return
+
+ def ch_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch_close(ch: ssh.Channel, reason: str):
+ return
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 50000 + (random.randint(0, 2000) + attempts * 101) % 2000
+ server_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_client_session_close_flush(t: testing.EnvT):
+ """Client.close should not drop queued channel writes."""
+ ClientSessionCloseFlushTester(t)
+
+
+actor AuthRejectTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var saw_auth = False
+ var got_connect_error = False
+ var client_close_requested = False
+ var server_close_requested = False
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_finish():
+ if done:
+ return
+ if saw_auth and got_connect_error:
+ done = True
+ request_client_close()
+ request_server_close()
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error("timeout waiting for auth reject test")
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+
+ def on_session(sess: ssh.ServerSession):
+ return
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ saw_auth = True
+ sess.reject_auth("invalid credentials")
+ maybe_finish()
+
+ def on_channel_open(sess: ssh.ServerSession):
+ finish_error("unexpected channel open after auth reject")
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request after auth reject")
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request after auth reject")
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is None:
+ finish_error("expected auth failure but connect succeeded")
+ return
+ got_connect_error = True
+ maybe_finish()
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="wrong-pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 46000 + (random.randint(0, 2000) + attempts * 73) % 2000
+ server_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 2.0: on_timeout()
+
+
+def _test_ssh_auth_reject(t: testing.EnvT):
+ """Server reject_auth should fail client connect without hanging."""
+ AuthRejectTester(t)
+
+
+actor ServerSessionCloseFlushTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ var done = False
+ var attempts = 0
+ var payload_len = 1024 * 1024
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var server_session: ?ssh.ServerSession = None
+ var client_buf = b""
+ var client_close_requested = False
+ var server_close_requested = False
+ var session_close_requested = False
+ var saw_terminal_event = False
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def request_session_close():
+ if session_close_requested:
+ return
+ session_close_requested = True
+ if server_session is not None:
+ server_session.close()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def finish_success():
+ if done:
+ return
+ done = True
+ log.info("test success", None)
+ request_client_close()
+ request_server_close()
+ t.success()
+
+ def check_client_bytes(reason: str):
+ if done or not saw_terminal_event:
+ return
+ if len(client_buf) != payload_len:
+ finish_error(reason + ": client received " + str(len(client_buf)) + " bytes, expected " + str(payload_len))
+ return
+ finish_success()
+
+ def on_timeout():
+ if done:
+ return
+ finish_error("timeout waiting for server session close flush test")
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+
+ def on_session(sess: ssh.ServerSession):
+ server_session = sess
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ ch.reject_request("exec disabled")
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ if name != "netconf":
+ ch.reject_request("unsupported subsystem")
+ return
+ ch.accept_request()
+ payload = b"C" * payload_len
+ ch.write(payload)
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ return
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ if done:
+ return
+ saw_terminal_event = True
+ check_client_bytes("client closed")
+
+ def ch_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("client channel open error: " + err)
+ return
+ ch.request_subsystem("netconf")
+
+ def ch_out(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ client_buf += data
+ if not session_close_requested:
+ request_session_close()
+ return
+ if done:
+ return
+ saw_terminal_event = True
+ check_client_bytes("client EOF")
+
+ def ch_err(ch: ssh.Channel, data: ?bytes):
+ return
+
+ def ch_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch_close(ch: ssh.Channel, reason: str):
+ if done:
+ return
+ saw_terminal_event = True
+ check_client_bytes("client channel close")
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 52000 + (random.randint(0, 2000) + attempts * 107) % 2000
+ server_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_server_session_close_flush(t: testing.EnvT):
+ """ServerSession.close should not drop queued channel writes."""
+ ServerSessionCloseFlushTester(t)
+
+
+actor ConcurrentChannelTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ PAYLOAD = b"channel-ok"
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var num_channels = 8
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var server_session: ?ssh.ServerSession = None
+ var client_close_requested = False
+ var server_close_requested = False
+ var client_closed = False
+ var server_closed = False
+ var session_closed = False
+ var client_opened = 0
+ var client_closed_channels = 0
+ var server_subsystems = 0
+ var server_closed_channels = 0
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_shutdown():
+ if done:
+ return
+ if client_closed_channels == num_channels and server_closed_channels == num_channels:
+ request_client_close()
+ if server_session is not None:
+ server_session.close()
+ request_server_close()
+
+ def maybe_finish():
+ if done:
+ return
+ if client_closed and server_closed and session_closed and \
+ client_opened == num_channels and server_subsystems == num_channels and \
+ client_closed_channels == num_channels and server_closed_channels == num_channels:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ if server_session is not None:
+ server_session.close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error("timeout waiting for concurrent channel test")
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ server_session = sess
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ if name != "netconf":
+ finish_error("unexpected subsystem: " + name)
+ return
+ server_subsystems += 1
+ ch.accept_request()
+ ch.write(PAYLOAD)
+ ch.close()
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_closed = True
+ maybe_finish()
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ finish_error("unexpected client payload on concurrent channel test")
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ finish_error("unexpected client stderr on concurrent channel test")
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ server_closed_channels += 1
+ maybe_shutdown()
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ start_channel(c, 0)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_closed = True
+ maybe_finish()
+
+ def start_channel(c: ssh.Client, idx: int):
+ if idx >= num_channels:
+ return
+
+ def ch_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("channel " + str(idx) + " open error: " + err)
+ return
+ client_opened += 1
+ ch.request_subsystem("netconf")
+
+ def ch_out(ch: ssh.Channel, data: ?bytes):
+ if data is None:
+ return
+ if data != PAYLOAD:
+ finish_error("channel " + str(idx) + " received unexpected payload")
+
+ def ch_err(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ finish_error("channel " + str(idx) + " received unexpected stderr")
+
+ def ch_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch_close(ch: ssh.Channel, reason: str):
+ client_closed_channels += 1
+ maybe_shutdown()
+ maybe_finish()
+
+ ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close)
+ start_channel(c, idx + 1)
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 54000 + (random.randint(0, 2000) + attempts * 109) % 2000
+ server_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_concurrent_channels(t: testing.EnvT):
+ """Multiple concurrent channels should share one session cleanly."""
+ ConcurrentChannelTester(t)
+
+
+actor ConcurrentWriteServerChannelHandler(payload_len: int,
+ ack: bytes,
+ on_complete: action() -> None,
+ on_error: action(str) -> None):
+ var recv_buf = b""
+
+ action def on_data(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ recv_buf += data
+ return
+ if len(recv_buf) != payload_len:
+ on_error("server received " + str(len(recv_buf)) + " bytes, expected " + str(payload_len))
+ return
+ ch.write(ack)
+ ch.close()
+
+ action def on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ on_error("unexpected client stderr on concurrent write test")
+
+ action def on_close(ch: ssh.ServerChannel, reason: str):
+ on_complete()
+
+
+actor ConcurrentChannelWriteTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ ACK = b"ack"
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var num_channels = 8
+ var payload_len = 256 * 1024
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var server_session: ?ssh.ServerSession = None
+ var client_close_requested = False
+ var server_close_requested = False
+ var client_closed = False
+ var server_closed = False
+ var session_closed = False
+ var client_opened = 0
+ var client_acks = 0
+ var client_closed_channels = 0
+ var server_subsystems = 0
+ var server_closed_channels = 0
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_shutdown():
+ if done:
+ return
+ if client_closed_channels == num_channels and server_closed_channels == num_channels:
+ request_client_close()
+ if server_session is not None:
+ server_session.close()
+ request_server_close()
+
+ def maybe_finish():
+ if done:
+ return
+ if client_closed and server_closed and session_closed and \
+ client_opened == num_channels and server_subsystems == num_channels and \
+ client_closed_channels == num_channels and server_closed_channels == num_channels:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ if server_session is not None:
+ server_session.close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error("timeout waiting for concurrent channel write test: client_opened=" + str(client_opened) + " client_acks=" + str(client_acks) + " server_subsystems=" + str(server_subsystems) + " client_closed_channels=" + str(client_closed_channels) + " server_closed_channels=" + str(server_closed_channels) + " client_closed=" + str(client_closed) + " session_closed=" + str(session_closed) + " server_closed=" + str(server_closed))
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ server_session = sess
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_server_channel_close():
+ server_closed_channels += 1
+ maybe_shutdown()
+ maybe_finish()
+
+ def on_channel_open(sess: ssh.ServerSession):
+ handler = ConcurrentWriteServerChannelHandler(
+ payload_len,
+ ACK,
+ on_server_channel_close,
+ finish_error,
+ )
+ ch = ssh.ServerChannel(sess, handler.on_data, handler.on_stderr, handler.on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ if name != "netconf":
+ finish_error("unexpected subsystem: " + name)
+ return
+ server_subsystems += 1
+ ch.accept_request()
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_closed = True
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ start_channel(c, 0)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_closed = True
+ maybe_finish()
+
+ def start_channel(c: ssh.Client, idx: int):
+ if idx >= num_channels:
+ return
+
+ def ch_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("channel " + str(idx) + " open error: " + err)
+ return
+ client_opened += 1
+ ch.request_subsystem("netconf")
+ payload = b"W" * payload_len
+ ch.write(payload)
+ ch.send_eof()
+
+ def ch_out(ch: ssh.Channel, data: ?bytes):
+ if data is None:
+ return
+ if data != ACK:
+ finish_error("channel " + str(idx) + " received unexpected payload")
+ return
+ client_acks += 1
+ ch.close()
+
+ def ch_err(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ finish_error("channel " + str(idx) + " received unexpected stderr")
+
+ def ch_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch_close(ch: ssh.Channel, reason: str):
+ client_closed_channels += 1
+ maybe_shutdown()
+ maybe_finish()
+
+ ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close)
+ start_channel(c, idx + 1)
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = _pick_test_port(35000, 25000, attempts, 113)
+ server_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 20.0: on_timeout()
+
+
+def _test_ssh_concurrent_channel_writes(t: testing.EnvT):
+ """Multiple channels should flush writes over one session."""
+ ConcurrentChannelWriteTester(t)
+
+
+actor KeepaliveTrafficTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ PING = b"keepalive-ping"
+ ACK = b"keepalive-ack"
+ ROUNDS = 10
+ KEEPALIVE_INTERVAL = 0.05
+ SEND_DELAY = 0.06
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var session: ?ssh.ServerSession = None
+ var client_channel: ?ssh.Channel = None
+ var client_close_requested = False
+ var server_close_requested = False
+ var client_closed = False
+ var server_closed = False
+ var session_closed = False
+ var server_channel_closed = False
+ var ack_received = False
+ var server_recv_bytes = 0
+
+ def total_bytes() -> int:
+ return len(PING) * ROUNDS
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_finish():
+ if done:
+ return
+ if ack_received and server_recv_bytes == total_bytes() and \
+ server_channel_closed and client_closed and session_closed and \
+ server_closed:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ if session is not None:
+ session.close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def send_ping(remaining: int):
+ if done or remaining <= 0:
+ return
+ if client_channel is None:
+ finish_error("client channel missing during keepalive write burst")
+ return
+ if client_channel is not None:
+ client_channel.write(PING)
+ if remaining == 1:
+ client_channel.send_eof()
+ return
+ after SEND_DELAY: send_ping(remaining - 1)
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for keepalive traffic test "
+ "(ack_received=" + str(ack_received) +
+ ", server_recv_bytes=" + str(server_recv_bytes) +
+ ", server_channel_closed=" + str(server_channel_closed) +
+ ", client_closed=" + str(client_closed) +
+ ", session_closed=" + str(session_closed) +
+ ", server_closed=" + str(server_closed) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < LISTEN_RETRY_LIMIT and _is_retryable_listen_error(err):
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ session = sess
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ session = sess
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ if name != "echo":
+ finish_error("unexpected subsystem request: " + name)
+ return
+ ch.accept_request()
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_closed = True
+ request_server_close()
+ maybe_finish()
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ server_recv_bytes += len(data)
+ if server_recv_bytes > total_bytes():
+ finish_error("server received too much data during keepalive test")
+ return
+ if server_recv_bytes == total_bytes():
+ ch.write(ACK)
+ ch.send_eof()
+ ch.close()
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ finish_error("server received unexpected stderr during keepalive test")
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ log.info("server channel closed: " + reason, None)
+ server_channel_closed = True
+ if session is not None:
+ session.close()
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ client = c
+ client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_closed = True
+ maybe_finish()
+
+ def ch_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("client channel open error: " + err)
+ return
+ client_channel = ch
+ ch.request_subsystem("echo")
+ after SEND_DELAY: send_ping(ROUNDS)
+
+ def ch_out(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ if data != ACK:
+ finish_error("client received unexpected keepalive reply")
+ return
+ if ack_received:
+ finish_error("client received duplicate keepalive reply")
+ return
+ ack_received = True
+
+ def ch_err(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ finish_error("client received unexpected stderr during keepalive test")
+
+ def ch_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch_close(ch: ssh.Channel, reason: str):
+ log.info("client channel closed: " + reason, None)
+ request_client_close()
+ maybe_finish()
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ keepalive_interval=KEEPALIVE_INTERVAL,
+ )
+
+ def start_server():
+ attempts += 1
+ port = _pick_test_port(36000, 24000, attempts, 271)
+ server_close_requested = False
+ client_close_requested = False
+ client_closed = False
+ server_closed = False
+ session_closed = False
+ server_channel_closed = False
+ ack_received = False
+ server_recv_bytes = 0
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ keepalive_interval=KEEPALIVE_INTERVAL,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_keepalive_traffic(t: testing.EnvT):
+ """Keepalives should not disrupt an active channel."""
+ KeepaliveTrafficTester(t)
+
+
+actor CrossHandleOwnershipTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ CLIENT1_PAYLOAD = b"client-one"
+ CLIENT2_PAYLOAD = b"client-two"
+ BAD_CLIENT_PAYLOAD = b"wrong-client"
+ SERVER1_REPLY = b"server-one"
+ SERVER2_REPLY = b"server-two"
+ BAD_SERVER_PAYLOAD = b"wrong-server"
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client1: ?ssh.Client = None
+ var client2: ?ssh.Client = None
+ var client_channel1: ?ssh.Channel = None
+ var client_channel2: ?ssh.Channel = None
+ var server_session1: ?ssh.ServerSession = None
+ var server_session2: ?ssh.ServerSession = None
+ var server_channel1: ?ssh.ServerChannel = None
+ var server_channel2: ?ssh.ServerChannel = None
+ var client1_buf = b""
+ var client2_buf = b""
+ var server1_buf = b""
+ var server2_buf = b""
+ var io_started = False
+ var replies_sent = False
+ var shutdown_started = False
+ var client1_opened = False
+ var client2_opened = False
+ var client1_ok = False
+ var client2_ok = False
+ var server1_ok = False
+ var server2_ok = False
+ var client_close_count = 0
+ var client_channel_close_count = 0
+ var server_channel_close_count = 0
+ var session_close_count = 0
+ var server_closed = False
+ var client1_close_requested = False
+ var client2_close_requested = False
+ var server_close_requested = False
+
+ def request_client1_close():
+ if client1_close_requested:
+ return
+ client1_close_requested = True
+ if client1 is not None:
+ client1.close()
+
+ def request_client2_close():
+ if client2_close_requested:
+ return
+ client2_close_requested = True
+ if client2 is not None:
+ client2.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_request_client_closes():
+ if not shutdown_started:
+ return
+ if client_channel_close_count >= 2:
+ request_client1_close()
+ request_client2_close()
+
+ def maybe_request_server_shutdown():
+ if not shutdown_started:
+ return
+ if session_close_count >= 2:
+ request_server_close()
+
+ def begin_shutdown():
+ if shutdown_started:
+ return
+ shutdown_started = True
+ if client_channel1 is not None:
+ client_channel1.close()
+ if client_channel2 is not None:
+ client_channel2.close()
+ if server_session1 is not None:
+ server_session1.close()
+ if server_session2 is not None:
+ server_session2.close()
+ maybe_request_client_closes()
+ maybe_request_server_shutdown()
+
+ def maybe_finish():
+ if done:
+ return
+ if client1_ok and client2_ok and server1_ok and server2_ok and \
+ client_close_count == 2 and client_channel_close_count == 2 and \
+ server_channel_close_count == 2 and session_close_count == 2 and \
+ server_closed:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ begin_shutdown()
+ t.error(Exception(msg))
+
+ def maybe_start_io():
+ if io_started or done:
+ return
+ if not client1_opened or not client2_opened:
+ return
+ if client1 is not None:
+ if client_channel1 is not None:
+ if client_channel2 is not None:
+ if server_channel1 is not None:
+ if server_channel2 is not None:
+ io_started = True
+ client1.channel_write(client_channel2, BAD_CLIENT_PAYLOAD)
+ client_channel1.write(CLIENT1_PAYLOAD)
+ client_channel2.write(CLIENT2_PAYLOAD)
+
+ def maybe_send_replies():
+ if replies_sent or done:
+ return
+ if not server1_ok or not server2_ok:
+ return
+ if server_session1 is not None:
+ if server_channel1 is not None:
+ if server_channel2 is not None:
+ replies_sent = True
+ server_session1.channel_write(server_channel2, BAD_SERVER_PAYLOAD)
+ server_channel1.write(SERVER1_REPLY)
+ server_channel2.write(SERVER2_REPLY)
+ return
+ finish_error("server channels not ready for reply")
+
+ def maybe_shutdown():
+ if client1_ok and client2_ok:
+ begin_shutdown()
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for cross-handle ownership test "
+ "(client_close_count=" + str(client_close_count) +
+ ", client_channel_close_count=" + str(client_channel_close_count) +
+ ", server_channel_close_count=" + str(server_channel_close_count) +
+ ", session_close_count=" + str(session_close_count) +
+ ", server_closed=" + str(server_closed) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client1()
+ after 0.01: start_client2()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ return
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ if server_channel1 is None:
+ server_session1 = sess
+ ch = ssh.ServerChannel(sess, srv1_on_data, srv1_on_stderr, srv1_on_close)
+ server_channel1 = ch
+ sess.accept_channel(ch)
+ else:
+ if server_channel2 is None:
+ server_session2 = sess
+ ch = ssh.ServerChannel(sess, srv2_on_data, srv2_on_stderr, srv2_on_close)
+ server_channel2 = ch
+ sess.accept_channel(ch)
+ else:
+ finish_error("unexpected extra server channel")
+ return
+ maybe_start_io()
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ if name != "echo":
+ finish_error("unexpected subsystem: " + name)
+ return
+ ch.accept_request()
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_close_count += 1
+ maybe_request_server_shutdown()
+ maybe_finish()
+
+ def srv1_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ server1_buf += data
+ if server1_buf.find(BAD_CLIENT_PAYLOAD) >= 0:
+ finish_error("cross-client payload leaked into server channel 1")
+ return
+ if len(server1_buf) > len(CLIENT1_PAYLOAD):
+ finish_error("server channel 1 received extra data")
+ return
+ if len(server1_buf) == len(CLIENT1_PAYLOAD):
+ if server1_buf == CLIENT1_PAYLOAD:
+ if server1_ok:
+ finish_error("client-one payload delivered twice")
+ return
+ server1_ok = True
+ else:
+ if server1_buf == CLIENT2_PAYLOAD:
+ if server2_ok:
+ finish_error("client-two payload delivered twice")
+ return
+ server2_ok = True
+ else:
+ finish_error("server channel 1 received wrong payload")
+ return
+ maybe_send_replies()
+
+ def srv1_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ finish_error("server channel 1 received unexpected stderr")
+
+ def srv1_on_close(ch: ssh.ServerChannel, reason: str):
+ server_channel_close_count += 1
+ maybe_finish()
+
+ def srv2_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ server2_buf += data
+ if server2_buf.find(BAD_CLIENT_PAYLOAD) >= 0:
+ finish_error("cross-client payload leaked into server channel 2")
+ return
+ if len(server2_buf) > len(CLIENT2_PAYLOAD):
+ finish_error("server channel 2 received extra data")
+ return
+ if len(server2_buf) == len(CLIENT2_PAYLOAD):
+ if server2_buf == CLIENT2_PAYLOAD:
+ if server2_ok:
+ finish_error("client-two payload delivered twice")
+ return
+ server2_ok = True
+ else:
+ if server2_buf == CLIENT1_PAYLOAD:
+ if server1_ok:
+ finish_error("client-one payload delivered twice")
+ return
+ server1_ok = True
+ else:
+ finish_error("server channel 2 received wrong payload")
+ return
+ maybe_send_replies()
+
+ def srv2_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ finish_error("server channel 2 received unexpected stderr")
+
+ def srv2_on_close(ch: ssh.ServerChannel, reason: str):
+ server_channel_close_count += 1
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client1_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client 1 connect error: " + err)
+ return
+ client_channel1 = ssh.Channel(c, ch1_open, ch1_out, ch1_err, ch1_exit, ch1_close)
+
+ def on_client2_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client 2 connect error: " + err)
+ return
+ client_channel2 = ssh.Channel(c, ch2_open, ch2_out, ch2_err, ch2_exit, ch2_close)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_close_count += 1
+ maybe_finish()
+
+ def ch1_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("client channel 1 open error: " + err)
+ return
+ client1_opened = True
+ client_channel1 = ch
+ ch.request_subsystem("echo")
+ maybe_start_io()
+
+ def ch1_out(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ client1_buf += data
+ if client1_buf.find(BAD_SERVER_PAYLOAD) >= 0:
+ finish_error("cross-session server payload leaked into client channel 1")
+ return
+ if len(client1_buf) > len(SERVER1_REPLY):
+ finish_error("client channel 1 received extra data")
+ return
+ if len(client1_buf) == len(SERVER1_REPLY):
+ if client1_buf == SERVER1_REPLY:
+ if client1_ok:
+ finish_error("server-one reply delivered twice")
+ return
+ client1_ok = True
+ else:
+ if client1_buf == SERVER2_REPLY:
+ if client2_ok:
+ finish_error("server-two reply delivered twice")
+ return
+ client2_ok = True
+ else:
+ finish_error("client channel 1 received wrong payload")
+ return
+ maybe_shutdown()
+
+ def ch1_err(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ finish_error("client channel 1 received unexpected stderr")
+
+ def ch1_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch1_close(ch: ssh.Channel, reason: str):
+ client_channel_close_count += 1
+ maybe_request_client_closes()
+ maybe_finish()
+
+ def ch2_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("client channel 2 open error: " + err)
+ return
+ client2_opened = True
+ client_channel2 = ch
+ ch.request_subsystem("echo")
+ maybe_start_io()
+
+ def ch2_out(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ client2_buf += data
+ if client2_buf.find(BAD_SERVER_PAYLOAD) >= 0:
+ finish_error("cross-session server payload leaked into client channel 2")
+ return
+ if len(client2_buf) > len(SERVER2_REPLY):
+ finish_error("client channel 2 received extra data")
+ return
+ if len(client2_buf) == len(SERVER2_REPLY):
+ if client2_buf == SERVER2_REPLY:
+ if client2_ok:
+ finish_error("server-two reply delivered twice")
+ return
+ client2_ok = True
+ else:
+ if client2_buf == SERVER1_REPLY:
+ if client1_ok:
+ finish_error("server-one reply delivered twice")
+ return
+ client1_ok = True
+ else:
+ finish_error("client channel 2 received wrong payload")
+ return
+ maybe_shutdown()
+
+ def ch2_err(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ finish_error("client channel 2 received unexpected stderr")
+
+ def ch2_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch2_close(ch: ssh.Channel, reason: str):
+ client_channel_close_count += 1
+ maybe_request_client_closes()
+ maybe_finish()
+
+ def start_client1():
+ client1 = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client1_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_client2():
+ client2 = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client2_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 60000 + (random.randint(0, 4000) + attempts * 131) % 4000
+ server_close_requested = False
+ client1_close_requested = False
+ client2_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_cross_handle_ownership(t: testing.EnvT):
+ """Wrong client or session should not drive another channel."""
+ CrossHandleOwnershipTester(t)
+
+
+actor UnsupportedRequestTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var saw_subsystem = False
+ var saw_exec = False
+ var client_close_requested = False
+ var server_close_requested = False
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_finish():
+ if done:
+ return
+ if saw_subsystem and saw_exec:
+ done = True
+ request_client_close()
+ request_server_close()
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for unsupported request test "
+ "(saw_subsystem=" + str(saw_subsystem) +
+ ", saw_exec=" + str(saw_exec) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+
+ def on_session(sess: ssh.ServerSession):
+ return
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ saw_exec = True
+ ch.reject_request("exec disabled")
+ maybe_finish()
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ saw_subsystem = True
+ ch.reject_request("unsupported subsystem")
+ maybe_finish()
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ return
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ ssh.Channel(c, ch_sub_open, ch_sub_out, ch_sub_err, ch_sub_exit, ch_sub_close)
+ after 0.05: start_exec_channel(c)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+
+ def ch_sub_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("subsystem channel open error: " + err)
+ return
+ ch.request_subsystem("unsupported-subsystem")
+
+ def ch_sub_out(ch: ssh.Channel, data: ?bytes):
+ return
+
+ def ch_sub_err(ch: ssh.Channel, data: ?bytes):
+ return
+
+ def ch_sub_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch_sub_close(ch: ssh.Channel, reason: str):
+ return
+
+ def start_exec_channel(c: ssh.Client):
+ if done:
+ return
+ ssh.Channel(c, ch_exec_open, ch_exec_out, ch_exec_err, ch_exec_exit, ch_exec_close)
+
+ def ch_exec_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("exec channel open error: " + err)
+ return
+ ch.request_exec("echo denied")
+
+ def ch_exec_out(ch: ssh.Channel, data: ?bytes):
+ return
+
+ def ch_exec_err(ch: ssh.Channel, data: ?bytes):
+ return
+
+ def ch_exec_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch_exec_close(ch: ssh.Channel, reason: str):
+ return
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 64000 + (random.randint(0, 1000) + attempts * 97) % 1000
+ server_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_unsupported_requests(t: testing.EnvT):
+ """Server reject_request should handle unsupported subsystem/exec."""
+ UnsupportedRequestTester(t)
+
+
+actor RunCommandTimeoutTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ PARTIAL_OUT = b"partial-out"
+ PARTIAL_ERR = b"partial-err"
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var run_exit_called = False
+ var client_closed = False
+ var server_closed = False
+ var session_closed = False
+ var client_close_requested = False
+ var server_close_requested = False
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_finish():
+ if done:
+ return
+ if run_exit_called and client_closed and server_closed and session_closed:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for RunCommand timeout test "
+ "(run_exit=" + str(run_exit_called) +
+ ", client_closed=" + str(client_closed) +
+ ", session_closed=" + str(session_closed) +
+ ", server_closed=" + str(server_closed) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ return
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ if cmd != "stall":
+ finish_error("unexpected exec command: " + cmd)
+ return
+ ch.accept_request()
+ ch.write(PARTIAL_OUT)
+ ch.write_stderr(PARTIAL_ERR)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request: " + name)
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_closed = True
+ request_server_close()
+ maybe_finish()
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ return
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ client = c
+ ssh.RunCommand(c, "stall", on_run_exit, timeout=0.1)
+
+ def on_run_exit(ch: ssh.Channel,
+ code: int,
+ sig: ?str,
+ stdout: bytes,
+ stderr: bytes,
+ err: ?str):
+ if run_exit_called:
+ finish_error("RunCommand on_exit called more than once")
+ return
+ run_exit_called = True
+ if err != "timeout":
+ finish_error("expected timeout error, got " + str(err))
+ return
+ if len(stdout) > len(PARTIAL_OUT) or PARTIAL_OUT.find(stdout) != 0:
+ finish_error("unexpected stdout on timeout")
+ return
+ if len(stderr) > len(PARTIAL_ERR) or PARTIAL_ERR.find(stderr) != 0:
+ finish_error("unexpected stderr on timeout")
+ return
+ if code != 0:
+ finish_error("timeout should not invent exit code")
+ return
+ if sig is not None:
+ finish_error("timeout should not invent exit signal")
+ return
+ request_client_close()
+ maybe_finish()
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_closed = True
+ if session_closed:
+ request_server_close()
+ maybe_finish()
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 52000 + (random.randint(0, 2000) + attempts * 149) % 2000
+ server_close_requested = False
+ client_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_runcommand_timeout(t: testing.EnvT):
+ """RunCommand timeout should complete exactly once with buffered output."""
+ RunCommandTimeoutTester(t)
+
+
+actor RunCommandExitStatusTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ OUT1 = b"alpha-"
+ OUT2 = b"beta"
+ ERR1 = b"warn-"
+ ERR2 = b"detail"
+ OUT = OUT1 + OUT2
+ ERR = ERR1 + ERR2
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var run_exit_called = False
+ var client_closed = False
+ var server_closed = False
+ var session_closed = False
+ var client_close_requested = False
+ var server_close_requested = False
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_finish():
+ if done:
+ return
+ if run_exit_called and client_closed and server_closed and session_closed:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for RunCommand exit status test "
+ "(run_exit=" + str(run_exit_called) +
+ ", client_closed=" + str(client_closed) +
+ ", session_closed=" + str(session_closed) +
+ ", server_closed=" + str(server_closed) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ return
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ if cmd != "report":
+ finish_error("unexpected exec command: " + cmd)
+ return
+ ch.accept_request()
+ ch.write(OUT1)
+ ch.write(OUT2)
+ ch.write_stderr(ERR1)
+ ch.write_stderr(ERR2)
+ ch.send_exit_status(23)
+ ch.send_eof()
+ ch.close()
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request: " + name)
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_closed = True
+ request_server_close()
+ maybe_finish()
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ return
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ client = c
+ ssh.RunCommand(c, "report", on_run_exit)
+
+ def on_run_exit(ch: ssh.Channel,
+ code: int,
+ sig: ?str,
+ stdout: bytes,
+ stderr: bytes,
+ err: ?str):
+ if run_exit_called:
+ finish_error("RunCommand on_exit called more than once")
+ return
+ run_exit_called = True
+ if err is not None:
+ finish_error("unexpected RunCommand error: " + str(err))
+ return
+ if stdout != OUT:
+ finish_error("unexpected stdout payload")
+ return
+ if stderr != ERR:
+ finish_error("unexpected stderr payload")
+ return
+ if code != 23:
+ finish_error("unexpected exit code: " + str(code))
+ return
+ if sig is not None:
+ finish_error("unexpected exit signal: " + str(sig))
+ return
+ request_client_close()
+ maybe_finish()
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_closed = True
+ if session_closed:
+ request_server_close()
+ maybe_finish()
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 58000 + (random.randint(0, 2000) + attempts * 151) % 2000
+ server_close_requested = False
+ client_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 5.0: on_timeout()
+
+
+def _test_ssh_runcommand_exit_status(t: testing.EnvT):
+ """RunCommand should collect stdout, stderr, and exit status."""
+ RunCommandExitStatusTester(t)
+
+
+actor ConcurrentRunCommandTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ NUM_CMDS = 4
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var server_channel_closes = 0
+ var run_exits = 0
+ var client_closed = False
+ var server_closed = False
+ var session_closed = False
+ var client_close_requested = False
+ var server_close_requested = False
+ var run0_done = False
+ var run1_done = False
+ var run2_done = False
+ var run3_done = False
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def out1(cmd: str) -> bytes:
+ if cmd == "report0":
+ return b"out0-a:" * 512
+ if cmd == "report1":
+ return b"out1-a:" * 512
+ if cmd == "report2":
+ return b"out2-a:" * 512
+ return b"out3-a:" * 512
+
+ def out2(cmd: str) -> bytes:
+ if cmd == "report0":
+ return b"out0-b:" * 512
+ if cmd == "report1":
+ return b"out1-b:" * 512
+ if cmd == "report2":
+ return b"out2-b:" * 512
+ return b"out3-b:" * 512
+
+ def err1(cmd: str) -> bytes:
+ if cmd == "report0":
+ return b"err0-a:" * 256
+ if cmd == "report1":
+ return b"err1-a:" * 256
+ if cmd == "report2":
+ return b"err2-a:" * 256
+ return b"err3-a:" * 256
+
+ def err2(cmd: str) -> bytes:
+ if cmd == "report0":
+ return b"err0-b:" * 256
+ if cmd == "report1":
+ return b"err1-b:" * 256
+ if cmd == "report2":
+ return b"err2-b:" * 256
+ return b"err3-b:" * 256
+
+ def full_out(cmd: str) -> bytes:
+ return out1(cmd) + out2(cmd)
+
+ def full_err(cmd: str) -> bytes:
+ return err1(cmd) + err2(cmd)
+
+ def exit_code(cmd: str) -> int:
+ if cmd == "report0":
+ return 30
+ if cmd == "report1":
+ return 31
+ if cmd == "report2":
+ return 32
+ return 33
+
+ def maybe_finish():
+ if done:
+ return
+ if run_exits == NUM_CMDS and server_channel_closes == NUM_CMDS and \
+ client_closed and server_closed and session_closed:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def mark_run_done(cmd: str):
+ if cmd == "report0":
+ if run0_done:
+ finish_error("RunCommand report0 completed more than once")
+ return
+ run0_done = True
+ return
+ if cmd == "report1":
+ if run1_done:
+ finish_error("RunCommand report1 completed more than once")
+ return
+ run1_done = True
+ return
+ if cmd == "report2":
+ if run2_done:
+ finish_error("RunCommand report2 completed more than once")
+ return
+ run2_done = True
+ return
+ if cmd == "report3":
+ if run3_done:
+ finish_error("RunCommand report3 completed more than once")
+ return
+ run3_done = True
+ return
+ finish_error("unexpected RunCommand completion for " + cmd)
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for concurrent RunCommand test "
+ "(run_exits=" + str(run_exits) +
+ ", server_channel_closes=" + str(server_channel_closes) +
+ ", client_closed=" + str(client_closed) +
+ ", session_closed=" + str(session_closed) +
+ ", server_closed=" + str(server_closed) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ return
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def send_report(ch: ssh.ServerChannel, cmd: str, exit_first: bool):
+ if exit_first:
+ ch.send_exit_status(exit_code(cmd))
+ ch.write(out1(cmd))
+ ch.write_stderr(err1(cmd))
+ ch.write(out2(cmd))
+ ch.write_stderr(err2(cmd))
+ if not exit_first:
+ ch.send_exit_status(exit_code(cmd))
+ ch.send_eof()
+ ch.close()
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ if cmd != "report0" and cmd != "report1" and cmd != "report2" and cmd != "report3":
+ finish_error("unexpected exec command: " + cmd)
+ return
+ ch.accept_request()
+ if cmd == "report0" or cmd == "report2":
+ send_report(ch, cmd, True)
+ else:
+ send_report(ch, cmd, False)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request: " + name)
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_closed = True
+ request_server_close()
+ maybe_finish()
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ server_channel_closes += 1
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ client = c
+ start_run(c, 0)
+ start_run(c, 1)
+ start_run(c, 2)
+ start_run(c, 3)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_closed = True
+ if session_closed:
+ request_server_close()
+ maybe_finish()
+
+ def start_run(c: ssh.Client, idx: int):
+ cmd = "report" + str(idx)
+
+ def on_run_exit(ch: ssh.Channel,
+ code: int,
+ sig: ?str,
+ stdout: bytes,
+ stderr: bytes,
+ err: ?str):
+ if err is not None:
+ finish_error("unexpected RunCommand error for " + cmd + ": " + str(err))
+ return
+ if stdout != full_out(cmd):
+ finish_error("unexpected stdout payload for " + cmd)
+ return
+ if stderr != full_err(cmd):
+ finish_error("unexpected stderr payload for " + cmd)
+ return
+ if code != exit_code(cmd):
+ finish_error("unexpected exit code for " + cmd + ": " + str(code))
+ return
+ if sig is not None:
+ finish_error("unexpected exit signal for " + cmd + ": " + str(sig))
+ return
+ mark_run_done(cmd)
+ if done:
+ return
+ run_exits += 1
+ if run_exits == NUM_CMDS:
+ request_client_close()
+ maybe_finish()
+
+ ssh.RunCommand(c, cmd, on_run_exit)
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 56000 + (random.randint(0, 2000) + attempts * 157) % 2000
+ server_close_requested = False
+ client_close_requested = False
+ server_channel_closes = 0
+ run_exits = 0
+ client_closed = False
+ server_closed = False
+ session_closed = False
+ run0_done = False
+ run1_done = False
+ run2_done = False
+ run3_done = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_concurrent_runcommand(t: testing.EnvT):
+ """RunCommand should isolate concurrent stderr and exit paths."""
+ ConcurrentRunCommandTester(t)
+
+
+actor ConcurrentRunCommandRejectTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ OK0_OUT = b"ok0-out"
+ OK0_ERR = b"ok0-err"
+ OK1_OUT = b"ok1-out"
+ OK1_ERR = b"ok1-err"
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var server_channel_closes = 0
+ var run_exits = 0
+ var client_closed = False
+ var server_closed = False
+ var session_closed = False
+ var client_close_requested = False
+ var server_close_requested = False
+ var ok0_done = False
+ var deny_done = False
+ var ok1_done = False
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_finish():
+ if done:
+ return
+ if run_exits == 3 and server_channel_closes == 3 and \
+ client_closed and server_closed and session_closed:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def mark_done(cmd: str):
+ if cmd == "ok0":
+ if ok0_done:
+ finish_error("ok0 completed more than once")
+ return
+ ok0_done = True
+ return
+ if cmd == "deny":
+ if deny_done:
+ finish_error("deny completed more than once")
+ return
+ deny_done = True
+ return
+ if cmd == "ok1":
+ if ok1_done:
+ finish_error("ok1 completed more than once")
+ return
+ ok1_done = True
+ return
+ finish_error("unexpected completion for " + cmd)
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for concurrent RunCommand reject test "
+ "(run_exits=" + str(run_exits) +
+ ", server_channel_closes=" + str(server_channel_closes) +
+ ", client_closed=" + str(client_closed) +
+ ", session_closed=" + str(session_closed) +
+ ", server_closed=" + str(server_closed) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ return
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ if cmd == "deny":
+ ch.reject_request("denied")
+ ch.close()
+ return
+ if cmd == "ok0":
+ ch.accept_request()
+ ch.write(OK0_OUT)
+ ch.write_stderr(OK0_ERR)
+ ch.send_exit_status(40)
+ ch.send_eof()
+ ch.close()
+ return
+ if cmd == "ok1":
+ ch.accept_request()
+ ch.send_exit_status(41)
+ ch.write(OK1_OUT)
+ ch.write_stderr(OK1_ERR)
+ ch.send_eof()
+ ch.close()
+ return
+ finish_error("unexpected exec command: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request: " + name)
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_closed = True
+ request_server_close()
+ maybe_finish()
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ return
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ server_channel_closes += 1
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ client = c
+ start_run(c, "ok0")
+ start_run(c, "deny")
+ start_run(c, "ok1")
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_closed = True
+ if session_closed:
+ request_server_close()
+ maybe_finish()
+
+ def start_run(c: ssh.Client, cmd: str):
+ def on_run_exit(ch: ssh.Channel,
+ code: int,
+ sig: ?str,
+ stdout: bytes,
+ stderr: bytes,
+ err: ?str):
+ if cmd == "deny":
+ if err is None:
+ finish_error("expected reject error for deny")
+ return
+ if sig is not None:
+ finish_error("unexpected signal on deny: " + str(sig))
+ return
+ elif cmd == "ok0":
+ if err is not None:
+ finish_error("unexpected error for ok0: " + str(err))
+ return
+ if stdout != OK0_OUT or stderr != OK0_ERR or code != 40:
+ finish_error("unexpected result for ok0")
+ return
+ if sig is not None:
+ finish_error("unexpected signal on ok0: " + str(sig))
+ return
+ elif cmd == "ok1":
+ if err is not None:
+ finish_error("unexpected error for ok1: " + str(err))
+ return
+ if stdout != OK1_OUT or stderr != OK1_ERR or code != 41:
+ finish_error("unexpected result for ok1")
+ return
+ if sig is not None:
+ finish_error("unexpected signal on ok1: " + str(sig))
+ return
+ else:
+ finish_error("unexpected command in on_run_exit: " + cmd)
+ return
+ mark_done(cmd)
+ if done:
+ return
+ run_exits += 1
+ if run_exits == 3:
+ request_client_close()
+ maybe_finish()
+
+ ssh.RunCommand(c, cmd, on_run_exit)
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 60000 + (random.randint(0, 2000) + attempts * 163) % 2000
+ server_close_requested = False
+ client_close_requested = False
+ server_channel_closes = 0
+ run_exits = 0
+ client_closed = False
+ server_closed = False
+ session_closed = False
+ ok0_done = False
+ deny_done = False
+ ok1_done = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_concurrent_runcommand_reject(t: testing.EnvT):
+ """A rejected RunCommand should not poison sibling channels."""
+ ConcurrentRunCommandRejectTester(t)
+
+
+actor AuthTimeoutTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var client_done = False
+ var hostkey_seen = False
+ var server_closed = False
+ var session_closed = False
+ var client_close_requested = False
+ var server_close_requested = False
+ var timeout_grace = False
+
+ def maybe_finish():
+ if done:
+ return
+ if session_closed and server_closed:
+ request_client_close()
+ done = True
+ t.success()
+ return
+ if timeout_grace and hostkey_seen and client_done and server_closed:
+ done = True
+ t.success()
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ request_client_close()
+ request_server_close()
+ done = True
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ if not timeout_grace:
+ timeout_grace = True
+ request_client_close()
+ request_server_close()
+ after 0.4: on_timeout()
+ return
+ maybe_finish()
+ if done:
+ return
+ finish_error("timeout waiting for auth timeout test")
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ return
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ finish_error("unexpected auth callback")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ finish_error("unexpected channel open")
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request: " + name)
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ if reason == "SSH authentication timeout":
+ session_closed = True
+ request_server_close()
+ maybe_finish()
+ return
+ if timeout_grace and reason == "Server closed":
+ session_closed = True
+ maybe_finish()
+ return
+ if reason != "SSH authentication timeout":
+ finish_error("unexpected session close reason: " + reason)
+ return
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ log.info("hostkey state: " + state, None)
+ hostkey_seen = True
+ return
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ if hostkey_seen:
+ log.info("client terminal connect error: " + err, None)
+ client_done = True
+ if session_closed:
+ request_server_close()
+ maybe_finish()
+ return
+ finish_error("client connect error: " + err)
+ return
+ client = c
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_done = True
+ if session_closed:
+ request_server_close()
+ maybe_finish()
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 56000 + (random.randint(0, 2000) + attempts * 157) % 2000
+ server_close_requested = False
+ client_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ auth_timeout=1.0,
+ )
+
+ after 0: start_server()
+ after 2.4: on_timeout()
+
+
+def _test_ssh_auth_timeout(t: testing.EnvT):
+ """Server auth timeout should still fire after attach."""
+ if len(t.env.argv) > 3 and t.env.argv[3] == "stress":
+ t.skip("AuthTimeoutTester is deterministic-only in stress mode")
+ AuthTimeoutTester(t)
+
+
+actor HostkeyWaitDisconnectTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var session: ?ssh.ServerSession = None
+ var hostkey_seen = False
+ var client_connect_count = 0
+ var session_close_count = 0
+ var server_close_count = 0
+ var session_close_requested = False
+ var server_close_requested = False
+
+ def maybe_finish():
+ if done:
+ return
+ if hostkey_seen and client_connect_count == 1 and session_close_count == 1 and \
+ server_close_count == 1:
+ done = True
+ t.success()
+
+ def request_session_close():
+ if session_close_requested:
+ return
+ if not hostkey_seen:
+ return
+ if session is None:
+ return
+ session_close_requested = True
+ sess = session
+ if sess is not None:
+ sess.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ request_server_close()
+ done = True
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for hostkey wait disconnect test "
+ "(hostkey_seen=" + str(hostkey_seen) +
+ ", client_connect_count=" + str(client_connect_count) +
+ ", session_close_count=" + str(session_close_count) +
+ ", server_close_count=" + str(server_close_count) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_close_count += 1
+ if server_close_count > 1:
+ finish_error("server close callback called more than once")
+ return
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ session = sess
+ request_session_close()
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ finish_error("unexpected auth callback")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ finish_error("unexpected channel open")
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request: " + name)
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_close_count += 1
+ if session_close_count > 1:
+ finish_error("session close callback called more than once")
+ return
+ request_server_close()
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ log.info("hostkey state: " + state, None)
+ if hostkey_seen:
+ finish_error("hostkey callback called more than once")
+ return
+ hostkey_seen = True
+ request_session_close()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ client_connect_count += 1
+ if client_connect_count > 1:
+ finish_error("client connect callback called more than once")
+ return
+ if err is None:
+ finish_error("client unexpectedly connected")
+ return
+ maybe_finish()
+
+ def on_client_close(c: ssh.Client, reason: str):
+ finish_error("unexpected client close callback: " + reason)
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 36000 + (random.randint(0, 20000) + attempts * 191) % 20000
+ server_close_requested = False
+ session_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ auth_timeout=5.0,
+ )
+
+ after 0: start_server()
+ after 8.0: on_timeout()
+
+
+def _test_ssh_hostkey_wait_disconnect(t: testing.EnvT):
+ """Client should fail promptly if the server closes in hostkey wait."""
+ HostkeyWaitDisconnectTester(t)
+
+
+actor CloseCallbackReentryTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ REPLY = b"close-ok"
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var session: ?ssh.ServerSession = None
+ var client_channel: ?ssh.Channel = None
+ var saw_reply = False
+ var client_stdout_eof_count = 0
+ var client_stderr_eof_count = 0
+ var client_exit_count = 0
+ var client_channel_close_count = 0
+ var client_close_count = 0
+ var server_channel_close_count = 0
+ var session_close_count = 0
+ var server_close_count = 0
+ var client_close_requested = False
+ var server_close_requested = False
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_finish():
+ if done:
+ return
+ if saw_reply and client_stdout_eof_count == 1 and client_stderr_eof_count == 1 and \
+ client_exit_count == 1 and client_channel_close_count == 1 and \
+ client_close_count == 1 and server_channel_close_count == 1 and \
+ session_close_count == 1 and server_close_count == 1:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for close callback reentry test "
+ "(reply=" + str(saw_reply) +
+ ", client_stdout_eof_count=" + str(client_stdout_eof_count) +
+ ", client_stderr_eof_count=" + str(client_stderr_eof_count) +
+ ", client_exit_count=" + str(client_exit_count) +
+ ", client_channel_close_count=" + str(client_channel_close_count) +
+ ", client_close_count=" + str(client_close_count) +
+ ", server_channel_close_count=" + str(server_channel_close_count) +
+ ", session_close_count=" + str(session_close_count) +
+ ", server_close_count=" + str(server_close_count) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_close_count += 1
+ if server_close_count > 1:
+ finish_error("server close callback called more than once")
+ return
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ session = sess
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ session = sess
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ if cmd != "reentry":
+ finish_error("unexpected exec command: " + cmd)
+ return
+ session = sess
+ ch.accept_request()
+ ch.write(REPLY)
+ ch.send_exit_status(7)
+ ch.send_eof()
+ ch.close()
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request: " + name)
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_close_count += 1
+ if session_close_count > 1:
+ finish_error("session close callback called more than once")
+ return
+ sess.close()
+ request_server_close()
+ maybe_finish()
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ finish_error("server received unexpected data")
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ finish_error("server received unexpected stderr")
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ log.info("server channel closed: " + reason, None)
+ server_channel_close_count += 1
+ if server_channel_close_count > 1:
+ finish_error("server channel close callback called more than once")
+ return
+ ch.write(b"after-server-close")
+ ch.send_eof()
+ ch.close()
+ if session is not None:
+ session.close()
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ client = c
+ client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_close_count += 1
+ if client_close_count > 1:
+ finish_error("client close callback called more than once")
+ return
+ c.close()
+ maybe_finish()
+
+ def ch_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("client channel open error: " + err)
+ return
+ client_channel = ch
+ ch.request_exec("reentry")
+
+ def ch_out(ch: ssh.Channel, data: ?bytes):
+ if data is None:
+ client_stdout_eof_count += 1
+ if client_stdout_eof_count > 1:
+ finish_error("client stdout EOF callback called more than once")
+ return
+ if client_channel_close_count > 0:
+ finish_error("client stdout EOF arrived after close callback")
+ return
+ maybe_finish()
+ return
+ if saw_reply:
+ finish_error("client received duplicate reply")
+ return
+ if data != REPLY:
+ finish_error("unexpected reply payload: " + str(data))
+ return
+ saw_reply = True
+
+ def ch_err(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ finish_error("client received unexpected stderr")
+ return
+ client_stderr_eof_count += 1
+ if client_stderr_eof_count > 1:
+ finish_error("client stderr EOF callback called more than once")
+ return
+ if client_channel_close_count > 0:
+ finish_error("client stderr EOF arrived after close callback")
+ return
+ maybe_finish()
+
+ def ch_exit(ch: ssh.Channel, code: int, sig: ?str):
+ client_exit_count += 1
+ if client_exit_count > 1:
+ finish_error("client exit callback called more than once")
+ return
+ if client_channel_close_count > 0:
+ finish_error("client exit callback arrived after close callback")
+ return
+ if code != 7:
+ finish_error("unexpected exit code: " + str(code))
+ return
+ if sig is not None:
+ finish_error("unexpected exit signal: " + str(sig))
+ return
+ maybe_finish()
+
+ def ch_close(ch: ssh.Channel, reason: str):
+ log.info("client channel closed: " + reason, None)
+ client_channel_close_count += 1
+ if client_channel_close_count > 1:
+ finish_error("client channel close callback called more than once")
+ return
+ if client_stdout_eof_count != 1:
+ finish_error("client close callback ran before stdout EOF")
+ return
+ if client_stderr_eof_count != 1:
+ finish_error("client close callback ran before stderr EOF")
+ return
+ if client_exit_count != 1:
+ finish_error("client close callback ran before exit callback")
+ return
+ ch.write(b"after-client-close")
+ ch.send_eof()
+ ch.close()
+ request_client_close()
+ maybe_finish()
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 54000 + (random.randint(0, 3000) + attempts * 163) % 3000
+ server_close_requested = False
+ client_close_requested = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 5.0: on_timeout()
+
+
+def _test_ssh_close_callback_reentry(t: testing.EnvT):
+ """Close callbacks should tolerate reentrant close calls."""
+ CloseCallbackReentryTester(t)
+
+
+actor KeyExchangeTimeoutTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var raw_client: ?net.TCPConnection = None
+ var raw_connected = False
+ var raw_bytes_in = 0
+ var session_ready = False
+ var session_closed = False
+ var server_closed = False
+ var raw_close_requested = False
+ var server_close_requested = False
+ var timeout_grace = False
+
+ def request_raw_close():
+ if raw_close_requested:
+ return
+ raw_close_requested = True
+ if raw_client is not None:
+ raw_client.close(on_raw_close)
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_finish():
+ if done:
+ return
+ if raw_connected and session_ready and session_closed and server_closed:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_raw_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ if not timeout_grace:
+ timeout_grace = True
+ after 5.0: on_timeout()
+ return
+ finish_error(
+ "timeout waiting for key exchange timeout test "
+ "(raw_connected=" + str(raw_connected) +
+ ", session_ready=" + str(session_ready) +
+ ", session_closed=" + str(session_closed) +
+ ", server_closed=" + str(server_closed) +
+ ", raw_bytes_in=" + str(raw_bytes_in) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_raw_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ session_ready = True
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ finish_error("unexpected auth request on raw TCP idle client")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ finish_error("unexpected channel open on raw TCP idle client")
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request on raw TCP idle client")
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request on raw TCP idle client")
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ if reason != "SSH key exchange timeout":
+ finish_error("unexpected session close reason: " + reason)
+ return
+ session_closed = True
+ request_raw_close()
+ request_server_close()
+ maybe_finish()
+
+ def on_raw_connect(c: net.TCPConnection):
+ raw_client = c
+ raw_connected = True
+
+ def on_raw_receive(c: net.TCPConnection, payload: bytes):
+ raw_bytes_in += len(payload)
+
+ def on_raw_error(c: net.TCPConnection, errmsg: str):
+ finish_error("unexpected raw client error: " + errmsg)
+
+ def on_raw_remote_close(c: net.TCPConnection):
+ return
+
+ def on_raw_close(c: net.TCPConnection):
+ return
+
+ def start_raw_client():
+ raw_close_requested = False
+ raw_connected = False
+ raw_bytes_in = 0
+ raw_client = net.TCPConnection(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ port,
+ on_raw_connect,
+ on_raw_receive,
+ on_raw_error,
+ on_raw_remote_close,
+ )
+
+ def start_server():
+ attempts += 1
+ port = _pick_test_port(35000, 25000, attempts, 179)
+ server_close_requested = False
+ server_closed = False
+ session_ready = False
+ session_closed = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ auth_timeout=0.2,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_key_exchange_timeout(t: testing.EnvT):
+ """Attached sessions should time out during key exchange."""
+ KeyExchangeTimeoutTester(t)
+
+
+actor ServerSurvivesGarbageTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ GARBAGE = b"not-ssh\r\nstill-not-ssh\r\n"
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var raw_client: ?net.TCPConnection = None
+ var good_client: ?ssh.Client = None
+ var raw_connected = False
+ var raw_bytes_in = 0
+ var raw_close_requested = False
+ var raw_session_seen = False
+ var raw_session_closed = False
+ var good_client_started = False
+ var good_client_connected = False
+ var good_client_closed = False
+ var good_session_seen = False
+ var good_session_closed = False
+ var server_close_requested = False
+ var server_closed = False
+
+ def maybe_finish():
+ if done:
+ return
+ if raw_connected and raw_session_seen and raw_session_closed and \
+ good_client_connected and good_client_closed and good_session_seen and \
+ good_session_closed and server_closed:
+ done = True
+ t.success()
+
+ def request_raw_close():
+ if raw_close_requested:
+ return
+ raw_close_requested = True
+ if raw_client is not None:
+ raw_client.close(on_raw_close)
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ request_raw_close()
+ if good_client is not None:
+ good_client.close()
+ request_server_close()
+ done = True
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for bad-peer isolation test "
+ "(raw_connected=" + str(raw_connected) +
+ ", raw_bytes_in=" + str(raw_bytes_in) +
+ ", raw_session_seen=" + str(raw_session_seen) +
+ ", raw_session_closed=" + str(raw_session_closed) +
+ ", good_client_started=" + str(good_client_started) +
+ ", good_client_connected=" + str(good_client_connected) +
+ ", good_client_closed=" + str(good_client_closed) +
+ ", good_session_seen=" + str(good_session_seen) +
+ ", good_session_closed=" + str(good_session_closed) +
+ ", server_closed=" + str(server_closed) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_raw_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ if not raw_session_seen:
+ raw_session_seen = True
+ return
+ if not good_session_seen:
+ good_session_seen = True
+ return
+ finish_error("unexpected extra session")
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if not good_client_started:
+ finish_error("raw peer unexpectedly reached auth")
+ return
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ finish_error("unexpected auth request for good client")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ finish_error("unexpected channel open")
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request: " + name)
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ if not raw_session_closed:
+ raw_session_closed = True
+ if good_client_started:
+ finish_error("raw session closed after good client started")
+ return
+ start_good_client()
+ maybe_finish()
+ return
+ if not good_session_closed:
+ good_session_closed = True
+ request_server_close()
+ maybe_finish()
+ return
+ finish_error("unexpected extra session close")
+
+ def on_raw_connect(c: net.TCPConnection):
+ raw_client = c
+ raw_connected = True
+ c.write(GARBAGE)
+ after 0.05: request_raw_close()
+
+ def on_raw_receive(c: net.TCPConnection, payload: bytes):
+ raw_bytes_in += len(payload)
+
+ def on_raw_error(c: net.TCPConnection, errmsg: str):
+ finish_error("unexpected raw client error: " + errmsg)
+
+ def on_raw_remote_close(c: net.TCPConnection):
+ return
+
+ def on_raw_close(c: net.TCPConnection):
+ return
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_good_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("good client connect error: " + err)
+ return
+ good_client_connected = True
+ c.close()
+
+ def on_good_client_close(c: ssh.Client, reason: str):
+ good_client_closed = True
+ if good_session_closed:
+ request_server_close()
+ maybe_finish()
+
+ def start_raw_client():
+ raw_close_requested = False
+ raw_connected = False
+ raw_bytes_in = 0
+ raw_client = net.TCPConnection(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ port,
+ on_raw_connect,
+ on_raw_receive,
+ on_raw_error,
+ on_raw_remote_close,
+ )
+
+ def start_good_client():
+ if good_client_started:
+ return
+ good_client_started = True
+ good_client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_good_client_connect,
+ on_good_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 39000 + (random.randint(0, 12000) + attempts * 211) % 12000
+ server_close_requested = False
+ server_closed = False
+ raw_session_seen = False
+ raw_session_closed = False
+ good_client_started = False
+ good_client_connected = False
+ good_client_closed = False
+ good_session_seen = False
+ good_session_closed = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_server_survives_garbage(t: testing.EnvT):
+ """One bad TCP peer should not poison the next SSH session."""
+ ServerSurvivesGarbageTester(t)
+
+
+actor RapidDisconnectChurnTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ ROUNDS = 25
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var raw_client: ?net.TCPConnection = None
+ var good_client: ?ssh.Client = None
+ var rounds_started = 0
+ var raw_connect_count = 0
+ var raw_session_seen_count = 0
+ var raw_session_closed_count = 0
+ var raw_close_requested = False
+ var good_client_started = False
+ var good_client_connected = False
+ var good_client_closed = False
+ var good_session_seen = False
+ var good_session_closed = False
+ var server_close_requested = False
+ var server_closed = False
+
+ def maybe_finish():
+ if done:
+ return
+ if raw_connect_count == ROUNDS and raw_session_seen_count == ROUNDS and \
+ raw_session_closed_count == ROUNDS and good_client_connected and \
+ good_client_closed and good_session_seen and good_session_closed and \
+ server_closed:
+ done = True
+ t.success()
+
+ def request_raw_close():
+ if raw_close_requested:
+ return
+ raw_close_requested = True
+ if raw_client is not None:
+ raw_client.close(on_raw_close)
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ request_raw_close()
+ if good_client is not None:
+ good_client.close()
+ request_server_close()
+ done = True
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for rapid disconnect churn test "
+ "(rounds_started=" + str(rounds_started) +
+ ", raw_connect_count=" + str(raw_connect_count) +
+ ", raw_session_seen_count=" + str(raw_session_seen_count) +
+ ", raw_session_closed_count=" + str(raw_session_closed_count) +
+ ", good_client_started=" + str(good_client_started) +
+ ", good_client_connected=" + str(good_client_connected) +
+ ", good_client_closed=" + str(good_client_closed) +
+ ", good_session_seen=" + str(good_session_seen) +
+ ", good_session_closed=" + str(good_session_closed) +
+ ", server_closed=" + str(server_closed) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_next_raw()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ if good_client_started:
+ if good_session_seen:
+ finish_error("good session callback called more than once")
+ return
+ good_session_seen = True
+ return
+ raw_session_seen_count += 1
+ if raw_session_seen_count > ROUNDS:
+ finish_error("too many raw sessions")
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if not good_client_started:
+ finish_error("raw churn unexpectedly reached auth")
+ return
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ finish_error("unexpected auth request for good client")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ finish_error("unexpected channel open")
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request: " + name)
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ if good_client_started and raw_session_closed_count == ROUNDS:
+ if good_session_closed:
+ finish_error("good session close callback called more than once")
+ return
+ good_session_closed = True
+ request_server_close()
+ maybe_finish()
+ return
+ raw_session_closed_count += 1
+ if raw_session_closed_count > ROUNDS:
+ finish_error("too many raw session closes")
+ return
+ if raw_session_closed_count < ROUNDS:
+ start_next_raw()
+ return
+ start_good_client()
+
+ def on_raw_connect(c: net.TCPConnection):
+ raw_client = c
+ raw_connect_count += 1
+ after 0.01: request_raw_close()
+
+ def on_raw_receive(c: net.TCPConnection, payload: bytes):
+ return
+
+ def on_raw_error(c: net.TCPConnection, errmsg: str):
+ finish_error("unexpected raw client error: " + errmsg)
+
+ def on_raw_remote_close(c: net.TCPConnection):
+ return
+
+ def on_raw_close(c: net.TCPConnection):
+ return
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_good_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("good client connect error: " + err)
+ return
+ good_client_connected = True
+ c.close()
+
+ def on_good_client_close(c: ssh.Client, reason: str):
+ good_client_closed = True
+ if good_session_closed:
+ request_server_close()
+ maybe_finish()
+
+ def start_next_raw():
+ if good_client_started:
+ finish_error("started raw churn after good client")
+ return
+ if rounds_started >= ROUNDS:
+ finish_error("started too many raw churn rounds")
+ return
+ rounds_started += 1
+ raw_close_requested = False
+ raw_client = net.TCPConnection(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ port,
+ on_raw_connect,
+ on_raw_receive,
+ on_raw_error,
+ on_raw_remote_close,
+ )
+
+ def start_good_client():
+ if good_client_started:
+ return
+ good_client_started = True
+ good_client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_good_client_connect,
+ on_good_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 39000 + (random.randint(0, 12000) + attempts * 223) % 12000
+ rounds_started = 0
+ raw_connect_count = 0
+ raw_session_seen_count = 0
+ raw_session_closed_count = 0
+ raw_close_requested = False
+ good_client_started = False
+ good_client_connected = False
+ good_client_closed = False
+ good_session_seen = False
+ good_session_closed = False
+ server_close_requested = False
+ server_closed = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_rapid_disconnect_churn(t: testing.EnvT):
+ """Rapid raw disconnect churn should not poison the next SSH client."""
+ RapidDisconnectChurnTester(t)
+
+
+actor CloseRejectsLateWritesTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ START = b"start"
+ ACK = b"ack"
+ LATE = b"late-data"
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var session: ?ssh.ServerSession = None
+ var client_channel: ?ssh.Channel = None
+ var server_buf = b""
+ var ack_received = False
+ var shutdown_started = False
+ var client_closed = False
+ var server_channel_closed = False
+ var session_closed = False
+ var server_closed = False
+ var client_close_requested = False
+ var server_close_requested = False
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_finish():
+ if done:
+ return
+ if ack_received and client_closed and server_channel_closed and \
+ session_closed and server_closed:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def late_write_burst(remaining: int):
+ if done or remaining <= 0:
+ return
+ if client_channel is not None:
+ client_channel.write(LATE)
+ after 0: late_write_burst(remaining - 1)
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for late-write close test "
+ "(ack_received=" + str(ack_received) +
+ ", client_closed=" + str(client_closed) +
+ ", server_channel_closed=" + str(server_channel_closed) +
+ ", session_closed=" + str(session_closed) +
+ ", server_closed=" + str(server_closed) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ session = sess
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ session = sess
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ if name != "echo":
+ finish_error("unexpected subsystem request: " + name)
+ return
+ ch.accept_request()
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_closed = True
+ request_server_close()
+ maybe_finish()
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ server_buf += data
+ if server_buf.find(LATE) >= 0:
+ finish_error("late write reached server after close started")
+ return
+ if not ack_received and server_buf.find(START) >= 0:
+ ch.write(ACK)
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ finish_error("server received unexpected stderr")
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ log.info("server channel closed: " + reason, None)
+ server_channel_closed = True
+ if session is not None:
+ session.close()
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ client = c
+ client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_closed = True
+ if session_closed:
+ request_server_close()
+ maybe_finish()
+
+ def ch_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("client channel open error: " + err)
+ return
+ client_channel = ch
+ ch.request_subsystem("echo")
+ ch.write(START)
+
+ def ch_out(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ if data != ACK:
+ finish_error("unexpected client payload: " + str(data))
+ return
+ if ack_received:
+ finish_error("client received duplicate ACK")
+ return
+ ack_received = True
+ if not shutdown_started:
+ shutdown_started = True
+ request_client_close()
+ after 0: late_write_burst(32)
+
+ def ch_err(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ finish_error("client received unexpected stderr")
+
+ def ch_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch_close(ch: ssh.Channel, reason: str):
+ log.info("client channel closed: " + reason, None)
+ request_client_close()
+ maybe_finish()
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 50000 + (random.randint(0, 3000) + attempts * 181) % 3000
+ server_close_requested = False
+ client_close_requested = False
+ ack_received = False
+ shutdown_started = False
+ client_closed = False
+ server_channel_closed = False
+ session_closed = False
+ server_closed = False
+ server_buf = b""
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 5.0: on_timeout()
+
+
+def _test_ssh_close_rejects_late_writes(t: testing.EnvT):
+ """Channel writes should be ignored once parent close starts."""
+ CloseRejectsLateWritesTester(t)
+
+
+actor SessionLimitTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client1: ?ssh.Client = None
+ var client2: ?ssh.Client = None
+ var client3: ?ssh.Client = None
+ var client1_close_requested = False
+ var client3_close_requested = False
+ var server_close_requested = False
+ var client1_ready = False
+ var client1_closed = False
+ var client2_failed = False
+ var client3_started = False
+ var client3_ready = False
+ var client3_closed = False
+ var server_closed = False
+ var session_ready_count = 0
+ var session_close_count = 0
+
+ def request_client1_close():
+ if client1_close_requested:
+ return
+ client1_close_requested = True
+ if client1 is not None:
+ client1.close()
+
+ def request_client3_close():
+ if client3_close_requested:
+ return
+ client3_close_requested = True
+ if client3 is not None:
+ client3.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_start_client3():
+ if done or client3_started:
+ return
+ if client2_failed and client1_closed and session_close_count >= 1:
+ client3_started = True
+ start_client3()
+
+ def maybe_finish():
+ if done:
+ return
+ if client1_ready and client1_closed and client2_failed and client3_ready and \
+ client3_closed and server_closed and session_ready_count == 2 and \
+ session_close_count == 2:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client1_close()
+ request_client3_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for session limit test "
+ "(client1_ready=" + str(client1_ready) +
+ ", client1_closed=" + str(client1_closed) +
+ ", client2_failed=" + str(client2_failed) +
+ ", client3_ready=" + str(client3_ready) +
+ ", client3_closed=" + str(client3_closed) +
+ ", session_ready_count=" + str(session_ready_count) +
+ ", session_close_count=" + str(session_close_count) +
+ ", server_closed=" + str(server_closed) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client1()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ session_ready_count += 1
+ if session_ready_count > 2:
+ finish_error("unexpected extra session became ready")
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ finish_error("unexpected channel open during session limit test")
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request: " + name)
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_close_count += 1
+ maybe_start_client3()
+ if client3_closed and session_close_count >= 2:
+ request_server_close()
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client1_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client 1 connect error: " + err)
+ return
+ client1 = c
+ client1_ready = True
+ start_client2()
+
+ def on_client2_connect(c: ssh.Client, err: ?str):
+ if err is None:
+ finish_error("client 2 unexpectedly connected over session limit")
+ return
+ client2_failed = True
+ request_client1_close()
+ maybe_start_client3()
+
+ def on_client3_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client 3 connect error after freeing session slot: " + err)
+ return
+ client3 = c
+ client3_ready = True
+ request_client3_close()
+
+ def on_client1_close(c: ssh.Client, reason: str):
+ log.info("client 1 closed: " + reason, None)
+ client1_closed = True
+ maybe_start_client3()
+ maybe_finish()
+
+ def on_client3_close(c: ssh.Client, reason: str):
+ log.info("client 3 closed: " + reason, None)
+ client3_closed = True
+ if session_close_count >= 2:
+ request_server_close()
+ maybe_finish()
+
+ def on_client2_close(c: ssh.Client, reason: str):
+ log.info("client 2 closed: " + reason, None)
+
+ def start_client1():
+ client1_close_requested = False
+ client1 = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client1_connect,
+ on_client1_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_client2():
+ client2 = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client2_connect,
+ on_client2_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_client3():
+ client3_close_requested = False
+ client3 = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client3_connect,
+ on_client3_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 50000 + (random.randint(0, 3000) + attempts * 461) % 3000
+ client1 = None
+ client2 = None
+ client3 = None
+ client1_close_requested = False
+ client3_close_requested = False
+ server_close_requested = False
+ client1_ready = False
+ client1_closed = False
+ client2_failed = False
+ client3_started = False
+ client3_ready = False
+ client3_closed = False
+ server_closed = False
+ session_ready_count = 0
+ session_close_count = 0
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ max_sessions=1,
+ )
+
+ after 0: start_server()
+ after 15.0: on_timeout()
+
+
+def _test_ssh_session_limit(t: testing.EnvT):
+ """A rejected extra session should not poison the server."""
+ SessionLimitTester(t)
+
+
+actor ChannelLimitClientFactory(client: ssh.Client,
+ ch2_open_cb: action(ssh.Channel, ?str) -> None,
+ ch2_out_cb: action(ssh.Channel, ?bytes) -> None,
+ ch2_err_cb: action(ssh.Channel, ?bytes) -> None,
+ ch2_exit_cb: action(ssh.Channel, int, ?str) -> None,
+ ch2_close_cb: action(ssh.Channel, str) -> None,
+ ch3_open_cb: action(ssh.Channel, ?str) -> None,
+ ch3_out_cb: action(ssh.Channel, ?bytes) -> None,
+ ch3_err_cb: action(ssh.Channel, ?bytes) -> None,
+ ch3_exit_cb: action(ssh.Channel, int, ?str) -> None,
+ ch3_close_cb: action(ssh.Channel, str) -> None):
+ var _ch2: ?ssh.Channel = None
+ var _ch3: ?ssh.Channel = None
+
+ action def open_ch2() -> None:
+ _ch2 = ssh.Channel(client, ch2_open_cb, ch2_out_cb, ch2_err_cb, ch2_exit_cb, ch2_close_cb)
+
+ action def open_ch3() -> None:
+ _ch3 = ssh.Channel(client, ch3_open_cb, ch3_out_cb, ch3_err_cb, ch3_exit_cb, ch3_close_cb)
+
+
+actor ChannelLimitTester(t: testing.EnvT):
+ log = logging.Logger(t.log_handler)
+
+ MSG1 = b"one"
+ ACK1 = b"ack-one"
+ MSG3 = b"three"
+ ACK3 = b"ack-three"
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var session: ?ssh.ServerSession = None
+ var client: ?ssh.Client = None
+ var factory: ?ChannelLimitClientFactory = None
+ var ch1: ?ssh.Channel = None
+ var ch2: ?ssh.Channel = None
+ var ch3: ?ssh.Channel = None
+ var client_close_requested = False
+ var server_close_requested = False
+ var ch3_started = False
+ var ch1_acked = False
+ var ch2_failed = False
+ var ch1_closed = False
+ var ch3_acked = False
+ var ch3_closed = False
+ var client_closed = False
+ var session_closed = False
+ var server_closed = False
+ var server_channel_open_count = 0
+ var server_channel_close_count = 0
+
+ def request_client_close():
+ if client_close_requested:
+ return
+ client_close_requested = True
+ if client is not None:
+ client.close()
+
+ def request_server_close():
+ if server_close_requested:
+ return
+ server_close_requested = True
+ if server is not None:
+ server.close()
+
+ def maybe_start_ch3():
+ if done or ch3_started:
+ return
+ if not (ch2_failed and ch1_closed and server_channel_close_count >= 1):
+ return
+ if factory is not None:
+ ch3_started = True
+ factory.open_ch3()
+ else:
+ finish_error("client factory missing when reopening channel after limit rejection")
+
+ def maybe_finish():
+ if done:
+ return
+ if ch1_acked and ch2_failed and ch1_closed and ch3_acked and ch3_closed and \
+ client_closed and session_closed and server_closed and \
+ server_channel_open_count == 2 and server_channel_close_count == 2:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ request_client_close()
+ request_server_close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for channel limit test "
+ "(ch1_acked=" + str(ch1_acked) +
+ ", ch2_failed=" + str(ch2_failed) +
+ ", ch1_closed=" + str(ch1_closed) +
+ ", ch3_acked=" + str(ch3_acked) +
+ ", ch3_closed=" + str(ch3_closed) +
+ ", client_closed=" + str(client_closed) +
+ ", session_closed=" + str(session_closed) +
+ ", server_closed=" + str(server_closed) +
+ ", server_channel_open_count=" + str(server_channel_open_count) +
+ ", server_channel_close_count=" + str(server_channel_close_count) + ")"
+ )
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if attempts < 5:
+ request_server_close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ session = sess
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ session = sess
+ server_channel_open_count += 1
+ if server_channel_open_count > 2:
+ finish_error("unexpected extra server channel open")
+ return
+ ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close)
+ sess.accept_channel(ch)
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request: " + cmd)
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ if name != "echo":
+ finish_error("unexpected subsystem request: " + name)
+ return
+ ch.accept_request()
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_closed = True
+ request_server_close()
+ maybe_finish()
+
+ def srv_on_data(ch: ssh.ServerChannel, data: ?bytes):
+ if data is None:
+ return
+ if data == MSG1:
+ ch.write(ACK1)
+ return
+ if data == MSG3:
+ ch.write(ACK3)
+ return
+ finish_error("unexpected server payload: " + str(data))
+
+ def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes):
+ if data is not None:
+ finish_error("server received unexpected stderr")
+
+ def srv_on_close(ch: ssh.ServerChannel, reason: str):
+ log.info("server channel closed: " + reason, None)
+ server_channel_close_count += 1
+ maybe_start_ch3()
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ client = c
+ factory = ChannelLimitClientFactory(c, ch2_open, ch2_out, ch2_err, ch2_exit, ch2_close,
+ ch3_open, ch3_out, ch3_err, ch3_exit, ch3_close)
+ ch1 = ssh.Channel(c, ch1_open, ch1_out, ch1_err, ch1_exit, ch1_close)
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ client_closed = True
+ if session_closed:
+ request_server_close()
+ maybe_finish()
+
+ def ch1_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("channel 1 open error: " + err)
+ return
+ ch1 = ch
+ ch.request_subsystem("echo")
+ ch.write(MSG1)
+
+ def ch1_out(ch: ssh.Channel, data: ?bytes):
+ if data is None:
+ return
+ if data != ACK1:
+ finish_error("channel 1 received unexpected payload: " + str(data))
+ return
+ if ch1_acked:
+ finish_error("channel 1 received duplicate ACK")
+ return
+ ch1_acked = True
+ if factory is not None:
+ factory.open_ch2()
+ else:
+ finish_error("client factory missing for second channel open")
+
+ def ch1_err(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ finish_error("channel 1 received unexpected stderr")
+
+ def ch1_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch1_close(ch: ssh.Channel, reason: str):
+ log.info("channel 1 closed: " + reason, None)
+ ch1_closed = True
+ maybe_start_ch3()
+ maybe_finish()
+
+ def ch2_open(ch: ssh.Channel, err: ?str):
+ if err is None:
+ finish_error("channel 2 unexpectedly opened over channel limit")
+ return
+ ch2_failed = True
+ if ch1 is not None:
+ ch1.close()
+ maybe_start_ch3()
+
+ def ch2_out(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ finish_error("channel 2 received unexpected data")
+
+ def ch2_err(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ finish_error("channel 2 received unexpected stderr")
+
+ def ch2_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch2_close(ch: ssh.Channel, reason: str):
+ return
+
+ def ch3_open(ch: ssh.Channel, err: ?str):
+ if err is not None:
+ finish_error("channel 3 open error after freeing slot: " + err)
+ return
+ ch3 = ch
+ ch.request_subsystem("echo")
+ ch.write(MSG3)
+
+ def ch3_out(ch: ssh.Channel, data: ?bytes):
+ if data is None:
+ return
+ if data != ACK3:
+ finish_error("channel 3 received unexpected payload: " + str(data))
+ return
+ if ch3_acked:
+ finish_error("channel 3 received duplicate ACK")
+ return
+ ch3_acked = True
+ ch.close()
+
+ def ch3_err(ch: ssh.Channel, data: ?bytes):
+ if data is not None:
+ finish_error("channel 3 received unexpected stderr")
+
+ def ch3_exit(ch: ssh.Channel, code: int, sig: ?str):
+ return
+
+ def ch3_close(ch: ssh.Channel, reason: str):
+ log.info("channel 3 closed: " + reason, None)
+ ch3_closed = True
+ request_client_close()
+ maybe_finish()
+
+ def start_client():
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = 50000 + (random.randint(0, 3000) + attempts * 503) % 3000
+ session = None
+ client = None
+ factory = None
+ ch1 = None
+ ch2 = None
+ ch3 = None
+ client_close_requested = False
+ server_close_requested = False
+ ch3_started = False
+ ch1_acked = False
+ ch2_failed = False
+ ch1_closed = False
+ ch3_acked = False
+ ch3_closed = False
+ client_closed = False
+ session_closed = False
+ server_closed = False
+ server_channel_open_count = 0
+ server_channel_close_count = 0
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ max_channels_per_session=1,
+ )
+
+ after 0: start_server()
+ after 10.0: on_timeout()
+
+
+def _test_ssh_channel_limit(t: testing.EnvT):
+ """A rejected extra channel should not poison the session."""
+ ChannelLimitTester(t)
+
+
+actor ServerCleanupTester(t: testing.EnvT):
+ t.require("gc_cleanup")
+ log = logging.Logger(t.log_handler)
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var probe: ?ssh.Server = None
+ var server_listened = False
+ var server_collected = False
+ var probe_listened = False
+ var probe_closed = False
+ var gc_started = False
+
+ def maybe_finish():
+ if done:
+ return
+ if server_listened and server_collected and probe_listened and probe_closed:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for GC-cleaned server "
+ "(server_listened=" + str(server_listened) +
+ ", server_collected=" + str(server_collected) +
+ ", probe_listened=" + str(probe_listened) +
+ ", probe_closed=" + str(probe_closed) + ")"
+ )
+
+ def drive_gc():
+ if done:
+ return
+ blob = b"G" * (4 * 1024 * 1024)
+ work = 0
+ for i in range(32768):
+ work += i
+ if len(blob) == work:
+ finish_error("unreachable GC drive path")
+ return
+ acton.rts.gc(t.env.syscap)
+ acton.rts.gc(t.env.syscap)
+ acton.rts.gc(t.env.syscap)
+ acton.rts.gc(t.env.syscap)
+ after 0.02: drive_gc()
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT:
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ server_listened = True
+ server = None
+ if not gc_started:
+ gc_started = True
+ after 0: drive_gc()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ if reason == "collected":
+ server_collected = True
+ if not probe_listened and probe is None:
+ start_probe()
+ maybe_finish()
+
+ def on_probe_listen(s: ssh.Server, err: ?str):
+ probe = s
+ if err is not None:
+ log.info("probe listen error: " + err, None)
+ probe = None
+ if _is_retryable_listen_error(err):
+ after 0.1: start_probe()
+ return
+ finish_error("probe listen error: " + err)
+ return
+ probe_listened = True
+ s.close()
+
+ def on_probe_close(s: ssh.Server, reason: str):
+ log.info("probe closed: " + reason, None)
+ probe_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ finish_error("unexpected session on GC-cleaned server test")
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ finish_error("unexpected auth request on GC-cleaned server test")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ finish_error("unexpected channel open on GC-cleaned server test")
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request on GC-cleaned server test")
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request on GC-cleaned server test")
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ finish_error("unexpected session close on GC-cleaned server test")
+
+ def start_probe():
+ if done or probe_listened or probe is not None:
+ return
+ probe = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_probe_listen,
+ on_probe_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ def start_server():
+ attempts += 1
+ port = _pick_test_port(35000, 25000, attempts, 541)
+ server = None
+ probe = None
+ server_listened = False
+ server_collected = False
+ probe_listened = False
+ probe_closed = False
+ gc_started = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 60.0: on_timeout()
+
+
+def _test_ssh_server_gc_cleanup(t: testing.EnvT):
+ """An unreachable server actor should eventually close itself."""
+ ServerCleanupTester(t)
+
+
+actor ClientCleanupTester(t: testing.EnvT):
+ t.require("gc_cleanup")
+ log = logging.Logger(t.log_handler)
+
+ var done = False
+ var attempts = 0
+ var port: int = 0
+ var server: ?ssh.Server = None
+ var client: ?ssh.Client = None
+ var server_close_requested = False
+ var server_listened = False
+ var client_connected = False
+ var session_opened = False
+ var session_closed = False
+ var server_closed = False
+ var gc_started = False
+
+ def maybe_finish():
+ if done:
+ return
+ if server_listened and client_connected and session_opened and session_closed and server_closed:
+ done = True
+ t.success()
+
+ def finish_error(msg: str):
+ if done:
+ return
+ done = True
+ log.info("test error: " + msg, None)
+ if not server_close_requested and server is not None:
+ server_close_requested = True
+ server.close()
+ t.error(Exception(msg))
+
+ def on_timeout():
+ if done:
+ return
+ finish_error(
+ "timeout waiting for GC-cleaned client "
+ "(server_listened=" + str(server_listened) +
+ ", client_connected=" + str(client_connected) +
+ ", session_opened=" + str(session_opened) +
+ ", session_closed=" + str(session_closed) +
+ ", server_closed=" + str(server_closed) + ")"
+ )
+
+ def drive_gc():
+ if done:
+ return
+ blob = b"G" * (4 * 1024 * 1024)
+ work = 0
+ for i in range(32768):
+ work += i
+ if len(blob) == work:
+ finish_error("unreachable GC drive path")
+ return
+ acton.rts.gc(t.env.syscap)
+ acton.rts.gc(t.env.syscap)
+ acton.rts.gc(t.env.syscap)
+ acton.rts.gc(t.env.syscap)
+ after 0.02: drive_gc()
+
+ def on_server_listen(s: ssh.Server, err: ?str):
+ server = s
+ if err is not None:
+ log.info("server listen error: " + err, None)
+ if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT:
+ if not server_close_requested and server is not None:
+ server_close_requested = True
+ server.close()
+ start_server()
+ return
+ finish_error("server listen error: " + err)
+ return
+ server_listened = True
+ start_client()
+
+ def on_server_close(s: ssh.Server, reason: str):
+ log.info("server closed: " + reason, None)
+ server_closed = True
+ maybe_finish()
+
+ def on_session(sess: ssh.ServerSession):
+ session_opened = True
+
+ def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest):
+ if req.method == "password" and req.user == "user" and req.password == "pass":
+ sess.accept_auth()
+ else:
+ sess.reject_auth("invalid credentials")
+
+ def on_channel_open(sess: ssh.ServerSession):
+ finish_error("unexpected channel open on GC-cleaned client test")
+
+ def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str):
+ finish_error("unexpected exec request on GC-cleaned client test")
+
+ def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str):
+ finish_error("unexpected subsystem request on GC-cleaned client test")
+
+ def on_session_close(sess: ssh.ServerSession, reason: str):
+ log.info("session closed: " + reason, None)
+ session_closed = True
+ if not server_close_requested and server is not None:
+ server_close_requested = True
+ server.close()
+ maybe_finish()
+
+ def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo):
+ c.accept_hostkey()
+
+ def on_client_connect(c: ssh.Client, err: ?str):
+ if err is not None:
+ finish_error("client connect error: " + err)
+ return
+ client_connected = True
+ client = None
+ if not gc_started:
+ gc_started = True
+ after 0: drive_gc()
+
+ def on_client_close(c: ssh.Client, reason: str):
+ log.info("client closed: " + reason, None)
+ return
+
+ def start_client():
+ client = None
+ client_connected = False
+ session_opened = False
+ session_closed = False
+ gc_started = False
+ client = ssh.Client(
+ net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ "user",
+ on_client_connect,
+ on_client_close,
+ on_hostkey,
+ password="pass",
+ port=u16(port),
+ known_hosts="/tmp/acton_ssh_test_known_hosts",
+ )
+
+ def start_server():
+ attempts += 1
+ port = _pick_test_port(35000, 25000, attempts, 557)
+ client = None
+ server_close_requested = False
+ server_listened = False
+ client_connected = False
+ session_opened = False
+ session_closed = False
+ server_closed = False
+ server = ssh.Server(
+ net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))),
+ "127.0.0.1",
+ u16(port),
+ on_server_listen,
+ on_server_close,
+ on_session,
+ on_auth,
+ on_channel_open,
+ on_exec,
+ on_subsystem,
+ on_session_close,
+ )
+
+ after 0: start_server()
+ after 20.0: on_timeout()
+
+
+def _test_ssh_client_gc_cleanup(t: testing.EnvT):
+ """An unreachable client actor should eventually close itself."""
+ ClientCleanupTester(t)
+
+
+__modname__ = "test_ssh_server"
+
+__unit_tests: dict[str, testing.UnitTest] = {
+}
+
+__simple_sync_tests: dict[str, testing.SimpleSyncTest] = {
+}
+
+__sync_tests: dict[str, testing.SyncTest] = {
+}
+
+__async_tests: dict[str, testing.AsyncTest] = {
+}
+
+__env_tests: dict[str, testing.EnvTest] = {
+ "_test_ServerClientTester": testing.EnvTest(_test_ssh_server_subsystem, "_test_ServerClientTester", "", __modname__),
+ "_test_ClientCloseFlushTester": testing.EnvTest(_test_ssh_client_close_flush, "_test_ClientCloseFlushTester", "", __modname__),
+ "_test_ClientSessionCloseFlushTester": testing.EnvTest(_test_ssh_client_session_close_flush, "_test_ClientSessionCloseFlushTester", "", __modname__),
+ "_test_AuthRejectTester": testing.EnvTest(_test_ssh_auth_reject, "_test_AuthRejectTester", "", __modname__),
+ "_test_ServerSessionCloseFlushTester": testing.EnvTest(_test_ssh_server_session_close_flush, "_test_ServerSessionCloseFlushTester", "", __modname__),
+ "_test_ConcurrentChannelTester": testing.EnvTest(_test_ssh_concurrent_channels, "_test_ConcurrentChannelTester", "", __modname__),
+ "_test_ConcurrentChannelWriteTester": testing.EnvTest(_test_ssh_concurrent_channel_writes, "_test_ConcurrentChannelWriteTester", "", __modname__),
+ "_test_KeepaliveTrafficTester": testing.EnvTest(_test_ssh_keepalive_traffic, "_test_KeepaliveTrafficTester", "", __modname__),
+ "_test_CrossHandleOwnershipTester": testing.EnvTest(_test_ssh_cross_handle_ownership, "_test_CrossHandleOwnershipTester", "", __modname__),
+ "_test_UnsupportedRequestTester": testing.EnvTest(_test_ssh_unsupported_requests, "_test_UnsupportedRequestTester", "", __modname__),
+ "_test_RunCommandTimeoutTester": testing.EnvTest(_test_ssh_runcommand_timeout, "_test_RunCommandTimeoutTester", "", __modname__),
+ "_test_RunCommandExitStatusTester": testing.EnvTest(_test_ssh_runcommand_exit_status, "_test_RunCommandExitStatusTester", "", __modname__),
+ "_test_ConcurrentRunCommandTester": testing.EnvTest(_test_ssh_concurrent_runcommand, "_test_ConcurrentRunCommandTester", "", __modname__),
+ "_test_ConcurrentRunCommandRejectTester": testing.EnvTest(_test_ssh_concurrent_runcommand_reject, "_test_ConcurrentRunCommandRejectTester", "", __modname__),
+ "_test_AuthTimeoutTester": testing.EnvTest(_test_ssh_auth_timeout, "_test_AuthTimeoutTester", "", __modname__),
+ "_test_HostkeyWaitDisconnectTester": testing.EnvTest(_test_ssh_hostkey_wait_disconnect, "_test_HostkeyWaitDisconnectTester", "", __modname__),
+ "_test_CloseCallbackReentryTester": testing.EnvTest(_test_ssh_close_callback_reentry, "_test_CloseCallbackReentryTester", "", __modname__),
+ "_test_KeyExchangeTimeoutTester": testing.EnvTest(_test_ssh_key_exchange_timeout, "_test_KeyExchangeTimeoutTester", "", __modname__),
+ "_test_ServerSurvivesGarbageTester": testing.EnvTest(_test_ssh_server_survives_garbage, "_test_ServerSurvivesGarbageTester", "", __modname__),
+ "_test_RapidDisconnectChurnTester": testing.EnvTest(_test_ssh_rapid_disconnect_churn, "_test_RapidDisconnectChurnTester", "", __modname__),
+ "_test_CloseRejectsLateWritesTester": testing.EnvTest(_test_ssh_close_rejects_late_writes, "_test_CloseRejectsLateWritesTester", "", __modname__),
+ "_test_SessionLimitTester": testing.EnvTest(_test_ssh_session_limit, "_test_SessionLimitTester", "", __modname__),
+ "_test_ChannelLimitTester": testing.EnvTest(_test_ssh_channel_limit, "_test_ChannelLimitTester", "", __modname__),
+ "_test_ServerCleanupTester": testing.EnvTest(_test_ssh_server_gc_cleanup, "_test_ServerCleanupTester", "", __modname__),
+ "_test_ClientCleanupTester": testing.EnvTest(_test_ssh_client_gc_cleanup, "_test_ClientCleanupTester", "", __modname__),
+}
+
+
+actor __test_main(env):
+ testing.test_runner(env, __unit_tests, __simple_sync_tests, __sync_tests, __async_tests, __env_tests)