From f6ea055e6a34757622b5cd203e871cbb32c0c4e3 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Tue, 6 Jan 2026 16:41:13 +0100 Subject: [PATCH 01/38] Fix channel read pumping to avoid hangs --- build.act.json | 6 +- src/ssh.act | 493 +++++++- src/ssh.ext.c | 3069 ++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 3425 insertions(+), 143 deletions(-) diff --git a/build.act.json b/build.act.json index ec3c9b7..04c7953 100644 --- a/build.act.json +++ b/build.act.json @@ -2,8 +2,10 @@ "dependencies": {}, "zig_dependencies": { "libssh": { - "url": "https://github.com/actonlang/libssh/archive/refs/heads/zig-build.tar.gz", - "hash": "122076b4676ec7d777e5efc24ae4b7f1ab4fa3df407ab9649fdd55f9fc594ca053ec", + "path": "../libssh", + "options": { + "WITH_SERVER": "true" + }, "artifacts": [ "ssh" ] diff --git a/src/ssh.act b/src/ssh.act index a8b0268..df1a6c4 100644 --- a/src/ssh.act +++ b/src/ssh.act @@ -1,69 +1,494 @@ import net -import testing + def version() -> str: """Get the libssh version""" NotImplemented -def _test_version(): - testing.assertEqual("0.11.0", version()) + +def _debug(msg: str) -> None: + """Internal debug logging (gated by ACTON_SSH_DEBUG).""" + NotImplemented + + +HOSTKEY_OK = "ok" +HOSTKEY_UNKNOWN = "unknown" +HOSTKEY_NOT_FOUND = "not_found" +HOSTKEY_CHANGED = "changed" +HOSTKEY_OTHER = "other" +HOSTKEY_ERROR = "error" + +class HostKeyInfo(): + """Server host key details""" + key_type: str + fingerprint: str + + def __init__(self, key_type: str, fingerprint: str): + self.key_type = key_type + self.fingerprint = fingerprint + + +class AuthRequest(): + """Authentication request details""" + method: str + user: str + password: ?str + pubkey: ?bytes + + def __init__(self, method: str, user: str, password: ?str=None, pubkey: ?bytes=None): + self.method = method + self.user = user + self.password = password + self.pubkey = pubkey + actor Client(cap: net.TCPConnectCap, host: str, username: str, - on_connect: action(Client) -> None, + on_connect: action(Client, ?str) -> None, on_close: action(Client, str) -> None, - key: ?str=None, + on_hostkey: ?action(Client, str, HostKeyInfo) -> None=None, password: ?str=None, port: u16=22, + known_hosts: ?str=None, + connect_timeout: float=10.0, + auth_timeout: float=10.0, + keepalive_interval: float=30.0, + keepalive_enabled: bool=True, ): """SSH Client""" # haha, this is really a pointer :P - var _ssh_session: u64 = 0 + var _client: u64 = 0 + + proc def _pin_affinity() -> None: + NotImplemented + _pin_affinity() proc def _init() -> None: """Initialize the SSH client""" NotImplemented _init() - print("SSH Client connected") -# action def close(on_close: action(TLSConnection) -> None) -> None: -# """Close the connection""" -# NotImplemented -# -# def reconnect(): -# close(_connect) -# -# def _connect(c): -# NotImplemented -# + action def accept_hostkey() -> None: + """Accept the currently pending host key (not persisted)""" + NotImplemented + + action def reject_hostkey(reason: str) -> None: + """Reject the currently pending host key""" + NotImplemented + + action def close() -> None: + """Close the SSH client""" + NotImplemented + + # Internal channel operations (used by Channel actor) + action def channel_create(channel: Channel, + on_open: action(Channel, ?str) -> None, + on_stdout: action(Channel, ?bytes) -> None, + on_stderr: action(Channel, ?bytes) -> None, + on_exit: action(Channel, int, ?str) -> None, + on_close: action(Channel, str) -> None) -> None: + NotImplemented + + action def channel_request_exec(channel: Channel, cmd: str) -> None: + NotImplemented + + action def channel_request_shell(channel: Channel, + term: str, + cols: int, + rows: int, + width_px: int, + height_px: int, + with_pty: bool) -> None: + NotImplemented + + action def channel_request_subsystem(channel: Channel, name: str) -> None: + NotImplemented + + action def channel_write(channel: Channel, data: bytes) -> None: + NotImplemented + + action def channel_send_eof(channel: Channel) -> None: + NotImplemented + + action def channel_close(channel: Channel) -> None: + NotImplemented + + +actor Channel(client: Client, + on_open: action(Channel, ?str) -> None, + on_stdout: action(Channel, ?bytes) -> None, + on_stderr: action(Channel, ?bytes) -> None, + on_exit: action(Channel, int, ?str) -> None, + on_close: action(Channel, str) -> None): + """SSH Channel""" + + # pointer to internal channel state + var _channel_id: u64 = 0 + var _on_open: action(Channel, ?str) -> None = on_open + var _on_stdout: action(Channel, ?bytes) -> None = on_stdout + var _on_stderr: action(Channel, ?bytes) -> None = on_stderr + var _on_exit: action(Channel, int, ?str) -> None = on_exit + var _on_close: action(Channel, str) -> None = on_close + + proc def _init() -> None: + client.channel_create(self, _on_open, _on_stdout, _on_stderr, _on_exit, _on_close) + _init() + + action def request_exec(cmd: str) -> None: + """Request to execute a command""" + client.channel_request_exec(self, cmd) + + action def request_shell(term: str="xterm-256color", + cols: int=80, + rows: int=24, + width_px: int=0, + height_px: int=0, + with_pty: bool=True) -> None: + """Request an interactive shell""" + client.channel_request_shell(self, term, cols, rows, width_px, height_px, with_pty) + + action def request_subsystem(name: str) -> None: + """Request a subsystem""" + client.channel_request_subsystem(self, name) + + action def write(data: bytes) -> None: + """Write data to the channel""" + client.channel_write(self, data) + + action def send_eof() -> None: + """Send EOF to the channel""" + client.channel_send_eof(self) + + action def close() -> None: + """Close the channel""" + client.channel_close(self) + + +actor RunCommand(client: Client, + cmd: str, + on_exit: action(Channel, int, ?str, bytes, bytes, ?str) -> None, + timeout: ?float=None): + """Run a command and collect output""" + + var out_buf = b"" + var err_buf = b"" + var _out_done = False + var _err_done = False + var _exited = False + var _exit_code = 0 + var _exit_signal: ?str = None + var _error: ?str = None + var _done = False + + def _finish(ch: Channel): + if _done: + return + _done = True + on_exit(ch, _exit_code, _exit_signal, out_buf, err_buf, _error) + + def _on_open(ch: Channel, err: ?str): + if err is not None: + _error = err + _finish(ch) + return + ch.request_exec(cmd) + + def _on_stdout(ch: Channel, data: ?bytes): + if data is not None: + out_buf += data + else: + _out_done = True + _check_done(ch) + + def _on_stderr(ch: Channel, data: ?bytes): + if data is not None: + err_buf += data + else: + _err_done = True + _check_done(ch) + + def _on_exit(ch: Channel, code: int, sig: ?str): + _exited = True + _exit_code = code + _exit_signal = sig + _check_done(ch) + + def _on_close(ch: Channel, reason: str): + if _error is None and reason != "closed": + _error = reason + if _error is not None: + _finish(ch) + return + + def _check_done(ch: Channel): + if _out_done and _err_done and _exited and _error is None: + _finish(ch) + + var _channel = Channel(client, _on_open, _on_stdout, _on_stderr, _on_exit, _on_close) + + if timeout is not None: + def _on_timeout(): + _channel.close() + after timeout: _on_timeout() + + +actor Server(cap: net.TCPListenCap, + host: str, + port: u16, + on_listen: action(Server, ?str) -> None, + on_close: action(Server, str) -> None, + on_session: action(ServerSession) -> None, + on_auth: action(ServerSession, AuthRequest) -> None, + on_channel_open: action(ServerSession) -> None, + on_exec: ?action(ServerSession, ServerChannel, str) -> None=None, + on_subsystem: ?action(ServerSession, ServerChannel, str) -> None=None, + on_session_close: ?action(ServerSession, str) -> None=None, + host_key_path: ?str=None, + host_key_type: str="ed25519", + host_key_bits: int=2048, + auth_timeout: float=10.0, + keepalive_interval: float=30.0, + keepalive_enabled: bool=True, + ): + """SSH Server""" + + var _server: u64 = 0 + var _host: str = host + var _port: u16 = port + var _host_key_path: ?str = host_key_path + var _host_key_type: str = host_key_type + var _host_key_bits: int = host_key_bits + var _auth_timeout: float = auth_timeout + var _keepalive_interval: float = keepalive_interval + var _keepalive_enabled: bool = keepalive_enabled + var _on_listen: action(Server, ?str) -> None = on_listen + var _on_close: action(Server, str) -> None = on_close + + proc def _pin_affinity() -> None: + NotImplemented + _pin_affinity() + + proc def _init() -> None: + """Initialize the SSH server""" + NotImplemented + _init() + + action def close() -> None: + """Close the SSH server""" + NotImplemented + + action def on_session_pending(session_id: u64) -> None: + _debug("server on_session_pending: " + str(session_id)) + ServerSession(self, + session_id, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close) + _debug("server on_session_pending: created session actor for " + str(session_id)) + + action def on_session_ready(session: ServerSession) -> None: + on_session(session) + + +actor ServerSession(server: Server, + session_id: u64, + on_auth: action(ServerSession, AuthRequest) -> None, + on_channel_open: action(ServerSession) -> None, + on_exec: ?action(ServerSession, ServerChannel, str) -> None=None, + on_subsystem: ?action(ServerSession, ServerChannel, str) -> None=None, + on_close: ?action(ServerSession, str) -> None=None): + """SSH Server Session""" -# TODO: implement support for channels -# AFAIK, all things over ssh are done via channels, so need some channel -# primitive, maybe an actor per channel that then multiplexes into the Client -# session? Prolly need some higher level wrappers for common things like -# starting a shell or running a single command. SFTP / SCP would be nice too, -# but for sometime in the future. Custom subsystems need to be supported too. + var _session_id: u64 = 0 + var _on_auth: action(ServerSession, AuthRequest) -> None = on_auth + var _on_channel_open: action(ServerSession) -> None = on_channel_open + var _on_exec: ?action(ServerSession, ServerChannel, str) -> None = on_exec + var _on_subsystem: ?action(ServerSession, ServerChannel, str) -> None = on_subsystem + var _on_close: ?action(ServerSession, str) -> None = on_close + proc def _pin_affinity() -> None: + NotImplemented + _pin_affinity() + + proc def _attach(session_id: u64) -> None: + NotImplemented + action def _attach_ready() -> None: + _debug("server session attach ready: " + str(session_id)) + _attach(session_id) + _debug("server session attach: _attach returned for " + str(session_id)) + server.on_session_ready(self) + _debug("server session attach: on_session_ready sent for " + str(session_id)) + + after 0: _attach_ready() + + action def accept_auth() -> None: + """Accept the pending authentication request""" + NotImplemented + + action def reject_auth(reason: str) -> None: + """Reject the pending authentication request""" + NotImplemented + + action def accept_channel(channel: ServerChannel) -> None: + """Accept the pending channel open request""" + channel.accept_open() + + action def accept_channel_open(channel: ServerChannel, + on_data: action(ServerChannel, ?bytes) -> None, + on_stderr: action(ServerChannel, ?bytes) -> None, + on_close: action(ServerChannel, str) -> None) -> None: + NotImplemented + + action def reject_channel(reason: str) -> None: + """Reject the pending channel open request""" + NotImplemented + + action def close() -> None: + """Close the SSH session""" + NotImplemented + + # Internal channel operations (used by ServerChannel actor) + action def channel_accept_request(channel: ServerChannel) -> None: + NotImplemented + + action def channel_reject_request(channel: ServerChannel, reason: str) -> None: + NotImplemented + + action def channel_write(channel: ServerChannel, data: bytes) -> None: + NotImplemented + + action def channel_write_stderr(channel: ServerChannel, data: bytes) -> None: + NotImplemented + + action def channel_send_eof(channel: ServerChannel) -> None: + NotImplemented + + action def channel_send_exit_status(channel: ServerChannel, status: int) -> None: + NotImplemented + + action def channel_close(channel: ServerChannel) -> None: + NotImplemented + + +actor ServerChannel(session: ServerSession, + on_data: action(ServerChannel, ?bytes) -> None, + on_stderr: action(ServerChannel, ?bytes) -> None, + on_close: action(ServerChannel, str) -> None): + """SSH Server Channel""" + + var _channel_id: u64 = 0 + var _on_data: action(ServerChannel, ?bytes) -> None = on_data + var _on_stderr: action(ServerChannel, ?bytes) -> None = on_stderr + var _on_close: action(ServerChannel, str) -> None = on_close + + action def accept_request() -> None: + """Accept the pending channel request""" + session.channel_accept_request(self) + + action def accept_open() -> None: + """Accept the pending channel open request""" + session.accept_channel_open(self, _on_data, _on_stderr, _on_close) + + action def reject_request(reason: str) -> None: + """Reject the pending channel request""" + session.channel_reject_request(self, reason) + + action def write(data: bytes) -> None: + """Write data to the channel (stdout)""" + session.channel_write(self, data) + + action def write_stderr(data: bytes) -> None: + """Write data to the channel (stderr)""" + session.channel_write_stderr(self, data) + + action def send_eof() -> None: + """Send EOF to the channel""" + session.channel_send_eof(self) + + action def send_exit_status(status: int) -> None: + """Send an exit status for the channel""" + session.channel_send_exit_status(self, status) + + action def close() -> None: + """Close the channel""" + session.channel_close(self) actor main(env): - def on_connect(client: Client): - print("Connected") + NETCONF_HELLO = b'\n' + \ + b'\n' + \ + b' \n' + \ + b' urn:ietf:params:netconf:base:1.0\n' + \ + b' \n' + \ + b']]>]]>' + var out_buf = b"" + var c: ?Client = None + + def on_connect(client: Client, err: ?str): + if err is not None: + print("SSH error", err) + env.exit(1) + return + print("SSH connected") + + def on_close(client: Client, reason: str): + print("SSH closed", reason) + env.exit(0) - def on_close(client: Client, error: str): - print("Error", error) + def on_hostkey(client: Client, state: str, info: HostKeyInfo): + print("Host key", state, info.key_type, info.fingerprint) + # WARNING: accepting any host key is insecure; do this only for testing + client.accept_hostkey() - print(version()) - c = Client( + def ch_open(ch: Channel, err: ?str): + if err is not None: + print("Channel open error", err) + if c is not None: + c.close() + return + ch.request_subsystem("netconf") + ch.write(NETCONF_HELLO) + + def ch_out(ch: Channel, data: ?bytes): + if data is not None: + out_buf += data + if out_buf.find(b"]]>]]>") >= 0: + print(out_buf) + ch.close() + + def ch_err(ch: Channel, data: ?bytes): + if data is not None: + print(data) + + def ch_exit(ch: Channel, code: int, sig: ?str): + print("exit", code, sig) + + def ch_close(ch: Channel, reason: str): + print("channel closed", reason) + if c is not None: + c.close() + + client = Client( net.TCPConnectCap(net.TCPCap(net.NetCap(env.cap))), - "localhost", - "foo", + "127.0.0.1", + "admin", on_connect, on_close, - password="bar", - port=2223, + on_hostkey, + password="admin", + port=42830, + known_hosts="/tmp/acton_ssh_known_hosts", ) - env.exit(0) + + c = client + Channel(client, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def _on_close_timeout(): + if c is not None: + c.close() + after 10.0: _on_close_timeout() diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 2112598..f92a728 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -1,120 +1,2975 @@ +#include +#include #include -#include -// TODO: figure out how to include rts/log so we get access to log_error etc +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include -void noop_free(void *ptr) { +#include "rts/log.h" + +uv_loop_t *get_uv_loop(void); +extern struct $Cont $Done$instance; + +#define SSH_READ_BUFSIZE 4096 +#define SSH_IO_PUMP_LIMIT 32 +static int ssh_debug_enabled = 0; +static int ssh_libssh_log_level = SSH_LOG_NOLOG; + +typedef enum { + CLIENT_STATE_INIT = 0, + CLIENT_STATE_CONNECTING, + CLIENT_STATE_HOSTKEY, + CLIENT_STATE_HOSTKEY_WAIT, + CLIENT_STATE_AUTH, + CLIENT_STATE_READY, + CLIENT_STATE_ERROR, + CLIENT_STATE_CLOSING, + CLIENT_STATE_CLOSED, +} client_state_t; + +typedef enum { + CHAN_STATE_INIT = 0, + CHAN_STATE_OPENING, + CHAN_STATE_OPEN, + CHAN_STATE_RUNNING, + CHAN_STATE_CLOSING, + CHAN_STATE_CLOSED, + CHAN_STATE_ERROR, +} channel_state_t; + +typedef enum { + CHAN_REQ_NONE = 0, + CHAN_REQ_SHELL, + CHAN_REQ_EXEC, + CHAN_REQ_SUBSYSTEM, +} channel_request_t; + +typedef struct write_chunk { + B_bytes data; + size_t offset; + struct write_chunk *next; +} write_chunk_t; + +typedef struct ssh_channel_ctx { + struct ssh_channel_ctx *next; + ssh_channel channel; + sshQ_Channel actor; + channel_state_t state; + channel_request_t pending_req; + int pty_pending; + int pty_done; + B_str exec_cmd; + B_str subsystem; + B_str term; + int cols; + int rows; + int width_px; + int height_px; + int send_eof; + int eof_sent; + int close_requested; + int close_sent; + int stdout_eof; + int stderr_eof; + int exit_sent; + int open_notified; + int open_succeeded; + int close_notified; + $action2 on_open; + $action2 on_close; + $action2 on_stdout; + $action2 on_stderr; + $action3 on_exit; + write_chunk_t *write_head; + write_chunk_t *write_tail; +} ssh_channel_ctx; + +typedef struct ssh_client_ctx { + sshQ_Client actor; + ssh_session session; + uv_poll_t *poll; + int poll_events; + uv_timer_t *connect_timer; + uv_timer_t *auth_timer; + uv_timer_t *keepalive_timer; + double connect_timeout; + double auth_timeout; + double keepalive_interval; + int keepalive_enabled; + client_state_t state; + int fd; + int connect_notified; + int connected_ok; + int close_notified; + int close_finalized; + int write_ready; + char *close_reason; + enum ssh_known_hosts_e hostkey_state; + ssh_channel_ctx *channels; + $action2 on_connect; + $action2 on_close; + $action3 on_hostkey; +} ssh_client_ctx; + +typedef enum { + SERVER_STATE_INIT = 0, + SERVER_STATE_LISTENING, + SERVER_STATE_ERROR, + SERVER_STATE_CLOSING, + SERVER_STATE_CLOSED, +} server_state_t; + +typedef enum { + SESSION_STATE_PENDING = 0, + SESSION_STATE_KEYEX, + SESSION_STATE_AUTH, + SESSION_STATE_READY, + SESSION_STATE_ERROR, + SESSION_STATE_CLOSING, + SESSION_STATE_CLOSED, +} session_state_t; + +typedef enum { + SCHAN_STATE_OPEN = 0, + SCHAN_STATE_CLOSING, + SCHAN_STATE_CLOSED, + SCHAN_STATE_ERROR, +} schan_state_t; + +typedef enum { + SCHAN_REQ_NONE = 0, + SCHAN_REQ_EXEC, + SCHAN_REQ_SUBSYSTEM, +} schan_req_t; + +typedef struct server_write_chunk { + B_bytes data; + size_t offset; + int is_stderr; + struct server_write_chunk *next; +} server_write_chunk_t; + +typedef struct ssh_server_channel_ctx { + struct ssh_server_channel_ctx *next; + ssh_channel channel; + struct ssh_server_session_ctx *session; + sshQ_ServerChannel actor; + schan_state_t state; + int send_eof; + int close_requested; + int close_sent; + int eof_sent; + int stdout_eof; + int stderr_eof; + int close_notified; + ssh_message pending_req; + schan_req_t pending_req_type; + $action2 on_data; + $action2 on_stderr; + $action2 on_close; + server_write_chunk_t *write_head; + server_write_chunk_t *write_tail; +} ssh_server_channel_ctx; + +typedef struct ssh_server_session_ctx { + struct ssh_server_session_ctx *next; + sshQ_ServerSession actor; + struct ssh_server_ctx *server; + ssh_session session; + uv_poll_t *poll; + int poll_events; + uv_timer_t *auth_timer; + uv_timer_t *keepalive_timer; + double auth_timeout; + double keepalive_interval; + int keepalive_enabled; + session_state_t state; + int attached; + int fd; + int owner_wt; + int write_ready; + int close_notified; + int close_finalized; + char *close_reason; + ssh_message pending_auth; + ssh_message pending_channel_open; + ssh_server_channel_ctx *channels; + $action2 on_auth; + $action on_channel_open; + $action3 on_exec; + $action3 on_subsystem; + $action2 on_close; +} ssh_server_session_ctx; + +typedef struct ssh_server_ctx { + sshQ_Server actor; + ssh_bind bind; + ssh_key hostkey; + uv_poll_t *poll; + int fd; + server_state_t state; + int listen_notified; + int listen_ok; + int close_notified; + int close_finalized; + char *close_reason; + ssh_server_session_ctx *sessions; + $action2 on_listen; + $action2 on_close; +} ssh_server_ctx; + +static void client_drive(ssh_client_ctx *c); +static void client_update_poll(ssh_client_ctx *c); +static void client_close_internal(ssh_client_ctx *c, const char *reason); +static void client_finalize(ssh_client_ctx *c); +static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch); +static int client_needs_write(ssh_client_ctx *c); +static void client_pump_io(ssh_client_ctx *c); + +static void server_accept(ssh_server_ctx *s); +static void server_close_internal(ssh_server_ctx *s, const char *reason); +static void server_finalize(ssh_server_ctx *s); +static void server_remove_session(ssh_server_ctx *s, ssh_server_session_ctx *sess); +static void session_drive(ssh_server_session_ctx *s); +static void session_update_poll(ssh_server_session_ctx *s); +static void session_pump_io(ssh_server_session_ctx *s); +static void session_close_internal(ssh_server_session_ctx *s, const char *reason); +static void session_finalize(ssh_server_session_ctx *s); +static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch); +static int session_needs_write(ssh_server_session_ctx *s); +static void session_poll_cb(uv_poll_t *handle, int status, int events); +static void server_poll_cb(uv_poll_t *handle, int status, int events); +static void session_auth_timeout_cb(uv_timer_t *timer); +static void session_keepalive_cb(uv_timer_t *timer); +static void client_poll_close_cb(uv_handle_t *handle); +static void client_timer_close_cb(uv_handle_t *handle); +static void server_poll_close_cb(uv_handle_t *handle); +static void session_poll_close_cb(uv_handle_t *handle); +static void session_timer_close_cb(uv_handle_t *handle); + +static void ssh_debug_log(const char *fmt, ...) { + if (!ssh_debug_enabled) + return; + va_list ap; + va_start(ap, fmt); + vfprintf(stderr, fmt, ap); + va_end(ap); + fprintf(stderr, "\n"); + fflush(stderr); +} + +static int parse_libssh_log_level(const char *value) { + if (value == NULL || value[0] == '\0') + return SSH_LOG_NOLOG; + char *endptr = NULL; + long lvl = strtol(value, &endptr, 10); + if (endptr != value && endptr && *endptr == '\0') { + if (lvl < 0) + return SSH_LOG_NOLOG; + if (lvl > SSH_LOG_TRACE) + lvl = SSH_LOG_TRACE; + return (int)lvl; + } + if (strcasecmp(value, "warn") == 0 || strcasecmp(value, "warning") == 0) + return SSH_LOG_WARN; + if (strcasecmp(value, "info") == 0 || strcasecmp(value, "protocol") == 0) + return SSH_LOG_INFO; + if (strcasecmp(value, "debug") == 0 || strcasecmp(value, "packet") == 0) + return SSH_LOG_DEBUG; + if (strcasecmp(value, "trace") == 0 || strcasecmp(value, "functions") == 0) + return SSH_LOG_TRACE; + if (strcasecmp(value, "none") == 0 || strcasecmp(value, "off") == 0) + return SSH_LOG_NOLOG; + return SSH_LOG_NOLOG; +} + +static void ssh_log_cb(int priority, const char *function, const char *buffer, void *userdata) { + (void)userdata; + if (buffer == NULL) + return; + fprintf(stderr, "libssh[%d] %s: %s\n", priority, function ? function : "", buffer); + fflush(stderr); +} + +static void ssh_configure_libssh_logging(void) { + if (ssh_libssh_log_level <= SSH_LOG_NOLOG) + return; + ssh_set_log_callback(ssh_log_cb); + ssh_set_log_level(ssh_libssh_log_level); +} + +static int session_apply_poll_events(ssh_session session, int events) { + if (session == NULL) + return -1; + int revents = 0; + if (events & UV_READABLE) + revents |= POLLIN; + if (events & UV_WRITABLE) + revents |= POLLOUT; +#ifdef UV_DISCONNECT + if (events & UV_DISCONNECT) + revents |= POLLHUP; +#endif +#ifdef UV_PRIORITIZED + if (events & UV_PRIORITIZED) + revents |= POLLPRI; +#endif + if (revents == 0) + return 0; + if (ssh_session_handle_poll(session, revents) != SSH_OK) + return -1; + return 0; +} + +static const char *hostkey_state_str(enum ssh_known_hosts_e state) { + switch (state) { + case SSH_KNOWN_HOSTS_OK: + return "ok"; + case SSH_KNOWN_HOSTS_UNKNOWN: + return "unknown"; + case SSH_KNOWN_HOSTS_NOT_FOUND: + return "not_found"; + case SSH_KNOWN_HOSTS_CHANGED: + return "changed"; + case SSH_KNOWN_HOSTS_OTHER: + return "other"; + case SSH_KNOWN_HOSTS_ERROR: + default: + return "error"; + } +} + +static ssh_client_ctx *client_from_actor(sshQ_Client self) { + if (self == NULL) + return NULL; + unsigned long ptr = fromB_u64(self->_client); + if (ptr == 0) + return NULL; + return (ssh_client_ctx *)ptr; +} + +static ssh_channel_ctx *channel_from_actor(sshQ_Channel channel) { + if (channel == NULL) + return NULL; + unsigned long ptr = fromB_u64(channel->_channel_id); + if (ptr == 0) + return NULL; + return (ssh_channel_ctx *)ptr; +} + +static ssh_server_ctx *server_from_actor(sshQ_Server self) { + if (self == NULL) + return NULL; + unsigned long ptr = fromB_u64(self->_server); + if (ptr == 0) + return NULL; + return (ssh_server_ctx *)ptr; +} + +static ssh_server_session_ctx *session_from_actor(sshQ_ServerSession self) { + if (self == NULL) + return NULL; + unsigned long ptr = fromB_u64(self->_session_id); + if (ptr == 0) + return NULL; + return (ssh_server_session_ctx *)ptr; +} + +static ssh_server_channel_ctx *server_channel_from_actor(sshQ_ServerChannel channel) { + if (channel == NULL) + return NULL; + unsigned long ptr = fromB_u64(channel->_channel_id); + if (ptr == 0) + return NULL; + return (ssh_server_channel_ctx *)ptr; +} + +static ssh_server_channel_ctx *server_channel_from_ssh(ssh_server_session_ctx *s, ssh_channel chan) { + if (s == NULL || chan == NULL) + return NULL; + ssh_server_channel_ctx *cur = s->channels; + while (cur != NULL) { + if (cur->channel == chan) + return cur; + cur = cur->next; + } + return NULL; +} + +static void close_poll(uv_poll_t **poll, uv_close_cb close_cb) { + if (*poll != NULL) { + if (uv_is_closing((uv_handle_t *)*poll)) + return; + uv_poll_stop(*poll); + uv_close((uv_handle_t *)*poll, close_cb); + } +} + +static void stop_timer(uv_timer_t **timer, uv_close_cb close_cb) { + if (*timer != NULL) { + if (uv_is_closing((uv_handle_t *)*timer)) + return; + uv_timer_stop(*timer); + uv_close((uv_handle_t *)*timer, close_cb); + } +} + +static void client_mark_writable(ssh_client_ctx *c) { + if (c != NULL && c->session != NULL && c->write_ready) { + ssh_set_fd_towrite(c->session); + } +} + +static void session_mark_writable(ssh_server_session_ctx *s) { + if (s != NULL && s->session != NULL && s->write_ready) { + ssh_set_fd_towrite(s->session); + } +} + +static int fd_has_data(int fd) { + if (fd < 0) + return 0; + struct pollfd pfd; + pfd.fd = fd; + pfd.events = POLLIN; + pfd.revents = 0; + int rc; + do { + rc = poll(&pfd, 1, 0); + } while (rc < 0 && errno == EINTR); + if (rc <= 0) + return 0; + if (pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) + return 0; + return (pfd.revents & POLLIN) != 0; +} + +static int fd_set_nonblocking(int fd) { + if (fd < 0) + return -1; + int flags = fcntl(fd, F_GETFL, 0); + if (flags < 0) + return -1; + if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0) + return -1; + return 0; +} + +static int fd_is_writable(int fd) { + if (fd < 0) + return 0; + struct pollfd pfd; + pfd.fd = fd; + pfd.events = POLLOUT; + pfd.revents = 0; + int rc; + do { + rc = poll(&pfd, 1, 0); + } while (rc < 0 && errno == EINTR); + if (rc <= 0) + return 0; + if (pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) + return 0; + return (pfd.revents & POLLOUT) != 0; +} + +static void client_poll_close_cb(uv_handle_t *handle) { + ssh_client_ctx *c = (ssh_client_ctx *)handle->data; + if (c == NULL) + return; + if (c->poll == (uv_poll_t *)handle) { + c->poll = NULL; + c->poll_events = 0; + } + if (c->state == CLIENT_STATE_CLOSING) { + client_finalize(c); + } +} + +static void client_timer_close_cb(uv_handle_t *handle) { + ssh_client_ctx *c = (ssh_client_ctx *)handle->data; + if (c == NULL) + return; + if ((uv_timer_t *)handle == c->connect_timer) + c->connect_timer = NULL; + if ((uv_timer_t *)handle == c->auth_timer) + c->auth_timer = NULL; + if ((uv_timer_t *)handle == c->keepalive_timer) + c->keepalive_timer = NULL; +} + +static void server_poll_close_cb(uv_handle_t *handle) { + ssh_server_ctx *s = (ssh_server_ctx *)handle->data; + if (s == NULL) + return; + if (s->poll == (uv_poll_t *)handle) + s->poll = NULL; + if (s->state == SERVER_STATE_CLOSING) { + server_finalize(s); + } +} + +static void session_poll_close_cb(uv_handle_t *handle) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data; + if (s == NULL) + return; + if (s->poll == (uv_poll_t *)handle) { + s->poll = NULL; + s->poll_events = 0; + } + if (s->state == SESSION_STATE_CLOSING) { + session_finalize(s); + } +} + +static void session_timer_close_cb(uv_handle_t *handle) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data; + if (s == NULL) + return; + if ((uv_timer_t *)handle == s->auth_timer) + s->auth_timer = NULL; + if ((uv_timer_t *)handle == s->keepalive_timer) + s->keepalive_timer = NULL; +} + +static void client_notify_connect(ssh_client_ctx *c, const char *err) { + if (c == NULL) + return; + if (c->connect_notified) + return; + if (c->on_connect) { + $action2 f = ($action2)c->on_connect; + f->$class->__asyn__(f, c->actor, err ? to$str((char *)err) : B_None); + } + c->connect_notified = 1; + if (err == NULL) + c->connected_ok = 1; +} + +static void client_notify_close(ssh_client_ctx *c, const char *reason) { + if (c->close_notified) + return; + if (!c->connected_ok) + return; + if (c->on_close) { + $action2 f = ($action2)c->on_close; + f->$class->__asyn__(f, c->actor, to$str((char *)reason)); + } + c->close_notified = 1; +} + +static void client_fail(ssh_client_ctx *c, const char *msg) { + if (c == NULL || c->state == CLIENT_STATE_CLOSED || c->state == CLIENT_STATE_ERROR) + return; + c->state = CLIENT_STATE_ERROR; + if (!c->connected_ok) + client_notify_connect(c, msg); + client_close_internal(c, msg); +} + +static void channel_notify_open(ssh_channel_ctx *ch, const char *err) { + if (ch->open_notified) + return; + if (ch->on_open) { + $action2 f = ($action2)ch->on_open; + f->$class->__asyn__(f, ch->actor, err ? to$str((char *)err) : B_None); + } + ch->open_notified = 1; + if (err == NULL) + ch->open_succeeded = 1; +} + +static void channel_notify_close(ssh_channel_ctx *ch, const char *reason) { + if (ch->close_notified) + return; + if (!ch->open_succeeded) { + ch->close_notified = 1; + return; + } + if (ch->on_close) { + $action2 f = ($action2)ch->on_close; + f->$class->__asyn__(f, ch->actor, to$str((char *)reason)); + } + ch->close_notified = 1; +} + +static void channel_notify_error(ssh_channel_ctx *ch, const char *msg) { + if (ch->open_succeeded) + channel_notify_close(ch, msg); + else + channel_notify_open(ch, msg); +} + +static void channel_notify_exit(ssh_channel_ctx *ch, int exit_status, B_str signal) { + if (ch->exit_sent) + return; + if (ch->on_exit) { + $action3 f = ($action3)ch->on_exit; + f->$class->__asyn__(f, ch->actor, toB_int(exit_status), signal); + } + ch->exit_sent = 1; +} + +static void channel_notify_eof(ssh_channel_ctx *ch) { + if (ch->channel == NULL) + return; + if (ssh_channel_is_eof(ch->channel)) { + if (!ch->stdout_eof && ch->on_stdout) { + $action2 f = ($action2)ch->on_stdout; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stdout_eof = 1; + } + if (!ch->stderr_eof && ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stderr_eof = 1; + } + } +} + +static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { + if (ch->channel != NULL) { + int exit_status = -1; + B_str exit_signal = B_None; + if (ssh_channel_is_closed(ch->channel)) { + uint32_t exit_code = 0; + char *signal = NULL; + int core_dumped = 0; + int rc = ssh_channel_get_exit_state(ch->channel, &exit_code, &signal, &core_dumped); + if (rc == SSH_OK) { + exit_status = (int)exit_code; + if (signal != NULL) + exit_signal = to$str(signal); + } + if (signal) + ssh_string_free_char(signal); + (void)core_dumped; + } + channel_notify_exit(ch, exit_status, exit_signal); + ssh_channel_free(ch->channel); + ch->channel = NULL; + } else { + channel_notify_exit(ch, -1, B_None); + } + if (!ch->stdout_eof && ch->on_stdout) { + $action2 f = ($action2)ch->on_stdout; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stdout_eof = 1; + } + if (!ch->stderr_eof && ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stderr_eof = 1; + } + channel_notify_close(ch, "closed"); + if (ch->actor) + ch->actor->_channel_id = toB_u64(0); + ch->state = CHAN_STATE_CLOSED; + (void)c; +} + +static void channel_fail(ssh_client_ctx *c, ssh_channel_ctx *ch, const char *msg) { + if (ch->state == CHAN_STATE_ERROR || ch->state == CHAN_STATE_CLOSED) + return; + ch->state = CHAN_STATE_ERROR; + channel_notify_error(ch, msg); + if (ch->channel) { + ssh_channel_close(ch->channel); + } + channel_finalize(c, ch); +} + +static void channel_queue_write(ssh_channel_ctx *ch, B_bytes data) { + write_chunk_t *chunk = acton_calloc(1, sizeof(write_chunk_t)); + chunk->data = data; + chunk->offset = 0; + chunk->next = NULL; + if (ch->write_tail) { + ch->write_tail->next = chunk; + } else { + ch->write_head = chunk; + } + ch->write_tail = chunk; +} + +static void channel_try_write(ssh_client_ctx *c, ssh_channel_ctx *ch) { + while (ch->write_head != NULL) { + write_chunk_t *chunk = ch->write_head; + size_t remaining = chunk->data->nbytes - chunk->offset; + if (remaining == 0) { + ch->write_head = chunk->next; + if (ch->write_head == NULL) + ch->write_tail = NULL; + continue; + } + client_mark_writable(c); + int rc = ssh_channel_write(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); + if (rc > 0) { + chunk->offset += (size_t)rc; + if (chunk->offset >= chunk->data->nbytes) { + ch->write_head = chunk->next; + if (ch->write_head == NULL) + ch->write_tail = NULL; + } + } else if (rc == 0 || rc == SSH_AGAIN) { + c->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH channel write error: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } + } +} + +static int channel_read_stream(ssh_client_ctx *c, ssh_channel_ctx *ch, int is_stderr) { + char buf[SSH_READ_BUFSIZE]; + int read_any = 0; + for (;;) { + int n = ssh_channel_read_nonblocking(ch->channel, buf, sizeof(buf), is_stderr); + if (ssh_debug_enabled) { + ssh_debug_log("client channel read: rc=%d stderr=%d", n, is_stderr); + } + if (n > 0) { + read_any = 1; + B_bytes out = to$bytesD_len(buf, n); + if (is_stderr) { + if (ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, ch->actor, out); + } + } else { + if (ch->on_stdout) { + $action2 f = ($action2)ch->on_stdout; + f->$class->__asyn__(f, ch->actor, out); + } + } + continue; + } + if (n == 0 || n == SSH_AGAIN) { + break; + } + if (n == SSH_EOF) { + break; + } + if (n == SSH_ERROR) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH channel read error: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + break; + } + } + return read_any; +} + +static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { + if (ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR) + return; + + if (ch->state == CHAN_STATE_INIT) { + ch->channel = ssh_channel_new(c->session); + if (ch->channel == NULL) { + channel_fail(c, ch, "Failed to create SSH channel"); + return; + } + ch->state = CHAN_STATE_OPENING; + } + + if (ch->state == CHAN_STATE_OPENING) { + client_mark_writable(c); + int rc = ssh_channel_open_session(ch->channel); + if (rc == SSH_OK) { + ch->state = CHAN_STATE_OPEN; + channel_notify_open(ch, NULL); + } else if (rc == SSH_AGAIN) { + c->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Failed to open SSH channel: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } + } + + if (ch->state == CHAN_STATE_OPEN || ch->state == CHAN_STATE_RUNNING) { + if (ch->pty_pending && !ch->pty_done) { + const char *term = ch->term ? (const char *)fromB_str(ch->term) : "xterm-256color"; + client_mark_writable(c); + int rc = ssh_channel_request_pty_size(ch->channel, term, ch->cols, ch->rows); + if (rc == SSH_OK) { + ch->pty_done = 1; + } else if (rc == SSH_AGAIN) { + c->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Failed to request PTY: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } + } + + if (ch->pending_req != CHAN_REQ_NONE) { + int rc = SSH_ERROR; + if (ch->pending_req == CHAN_REQ_SHELL) { + client_mark_writable(c); + rc = ssh_channel_request_shell(ch->channel); + } else if (ch->pending_req == CHAN_REQ_EXEC) { + client_mark_writable(c); + rc = ssh_channel_request_exec(ch->channel, (const char *)fromB_str(ch->exec_cmd)); + } else if (ch->pending_req == CHAN_REQ_SUBSYSTEM) { + client_mark_writable(c); + rc = ssh_channel_request_subsystem(ch->channel, (const char *)fromB_str(ch->subsystem)); + } + + if (rc == SSH_OK) { + if (ssh_debug_enabled) { + ssh_debug_log("client channel request ok"); + } + ch->pending_req = CHAN_REQ_NONE; + ch->state = CHAN_STATE_RUNNING; + } else if (rc == SSH_AGAIN) { + c->write_ready = 0; + if (ssh_debug_enabled) { + ssh_debug_log("client channel request again"); + } + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH channel request failed: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } + } + } + + if (ch->state == CHAN_STATE_OPEN || ch->state == CHAN_STATE_RUNNING) { + if (ch->write_head != NULL) { + channel_try_write(c, ch); + if (ch->state == CHAN_STATE_ERROR) + return; + } + + if (ch->send_eof && !ch->eof_sent) { + client_mark_writable(c); + int rc = ssh_channel_send_eof(ch->channel); + if (rc == SSH_OK) { + ch->eof_sent = 1; + } else if (rc == SSH_AGAIN) { + c->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Failed to send EOF: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } + } + + if (ch->close_requested && !ch->close_sent) { + client_mark_writable(c); + int rc = ssh_channel_close(ch->channel); + if (rc == SSH_OK) { + ch->close_sent = 1; + ch->state = CHAN_STATE_CLOSING; + } else if (rc == SSH_AGAIN) { + c->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Failed to close channel: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } + } + + for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int did = 0; + did |= channel_read_stream(c, ch, 0); + did |= channel_read_stream(c, ch, 1); + if (!did) + break; + } + channel_notify_eof(ch); + } + + if (ch->channel != NULL && ssh_channel_is_closed(ch->channel)) { + channel_finalize(c, ch); + } +} + +static void client_drive_channels(ssh_client_ctx *c) { + ssh_channel_ctx *prev = NULL; + ssh_channel_ctx *ch = c->channels; + while (ch != NULL) { + ssh_channel_ctx *next = ch->next; + channel_drive(c, ch); + if (ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR) { + if (prev != NULL) { + prev->next = next; + } else { + c->channels = next; + } + ch->next = NULL; + } else { + prev = ch; + } + ch = next; + } +} + +static int client_get_hostkey_info(ssh_client_ctx *c, B_str *key_type_out, B_str *fingerprint_out) { + ssh_key key = NULL; + unsigned char *hash = NULL; + size_t hash_len = 0; + char *fingerprint = NULL; + + int rc = ssh_get_server_publickey(c->session, &key); + if (rc != SSH_OK || key == NULL) + return -1; + + enum ssh_keytypes_e key_type = ssh_key_type(key); + const char *key_type_str = ssh_key_type_to_char(key_type); + + rc = ssh_get_publickey_hash(key, SSH_PUBLICKEY_HASH_SHA256, &hash, &hash_len); + if (rc != SSH_OK) { + ssh_key_free(key); + return -1; + } + + fingerprint = ssh_get_fingerprint_hash(SSH_PUBLICKEY_HASH_SHA256, hash, hash_len); + + if (key_type_out) + *key_type_out = to$str((char *)(key_type_str ? key_type_str : "")); + if (fingerprint_out) + *fingerprint_out = to$str((char *)(fingerprint ? fingerprint : "")); + + ssh_clean_pubkey_hash(&hash); + ssh_key_free(key); + if (fingerprint) + ssh_string_free_char(fingerprint); + + return 0; +} + +static int client_check_hostkey(ssh_client_ctx *c) { + enum ssh_known_hosts_e state = SSH_KNOWN_HOSTS_UNKNOWN; + int use_known_hosts = 0; + + if (c->actor != NULL && c->actor->known_hosts != B_None) + use_known_hosts = 1; + + if (use_known_hosts) { + state = ssh_session_is_known_server(c->session); + if (state == SSH_KNOWN_HOSTS_OK) + return 0; + + if (state == SSH_KNOWN_HOSTS_ERROR) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Host key check error: %s", ssh_get_error(c->session)); + client_fail(c, errmsg); + return -1; + } + } + + c->hostkey_state = state; + + if (!c->on_hostkey) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Host key not accepted (%s)", hostkey_state_str(state)); + client_fail(c, errmsg); + return -1; + } + + B_str key_type = to$str((char *)""); + B_str fingerprint = to$str((char *)""); + if (client_get_hostkey_info(c, &key_type, &fingerprint) != 0) { + key_type = to$str((char *)""); + fingerprint = to$str((char *)""); + } + + sshQ_HostKeyInfo info = sshQ_HostKeyInfoG_new(key_type, fingerprint); + $action3 f = ($action3)c->on_hostkey; + f->$class->__asyn__(f, c->actor, to$str((char *)hostkey_state_str(state)), info); + return 1; +} + +static void connect_timeout_cb(uv_timer_t *timer) { + ssh_client_ctx *c = (ssh_client_ctx *)timer->data; + if (c == NULL) + return; + if (c->state == CLIENT_STATE_CONNECTING || c->state == CLIENT_STATE_HOSTKEY || c->state == CLIENT_STATE_HOSTKEY_WAIT) { + client_fail(c, "SSH connect timeout"); + } +} + +static void auth_timeout_cb(uv_timer_t *timer) { + ssh_client_ctx *c = (ssh_client_ctx *)timer->data; + if (c == NULL) + return; + if (c->state == CLIENT_STATE_AUTH) { + client_fail(c, "SSH authentication timeout"); + } +} + +static void keepalive_cb(uv_timer_t *timer) { + ssh_client_ctx *c = (ssh_client_ctx *)timer->data; + if (c == NULL) + return; + if (c->state != CLIENT_STATE_READY || c->session == NULL) + return; + int rc = ssh_send_ignore(c->session, "keepalive"); + if (rc == SSH_AGAIN) { + client_update_poll(c); + return; + } + if (rc != SSH_OK) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH keepalive failed: %s", ssh_get_error(c->session)); + client_fail(c, errmsg); + return; + } + client_update_poll(c); +} + +static void client_start_connect_timer(ssh_client_ctx *c) { + if (c->connect_timeout <= 0.0 || c->connect_timer != NULL) + return; + c->connect_timer = acton_calloc(1, sizeof(uv_timer_t)); + c->connect_timer->data = c; + uv_timer_init(get_uv_loop(), c->connect_timer); + uv_timer_start(c->connect_timer, connect_timeout_cb, (uint64_t)(c->connect_timeout * 1000), 0); +} + +static void client_start_auth_timer(ssh_client_ctx *c) { + if (c->auth_timeout <= 0.0 || c->auth_timer != NULL) + return; + c->auth_timer = acton_calloc(1, sizeof(uv_timer_t)); + c->auth_timer->data = c; + uv_timer_init(get_uv_loop(), c->auth_timer); + uv_timer_start(c->auth_timer, auth_timeout_cb, (uint64_t)(c->auth_timeout * 1000), 0); +} + +static void client_start_keepalive(ssh_client_ctx *c) { + if (!c->keepalive_enabled || c->keepalive_interval <= 0.0 || c->keepalive_timer != NULL) + return; + c->keepalive_timer = acton_calloc(1, sizeof(uv_timer_t)); + c->keepalive_timer->data = c; + uv_timer_init(get_uv_loop(), c->keepalive_timer); + uint64_t interval_ms = (uint64_t)(c->keepalive_interval * 1000); + uv_timer_start(c->keepalive_timer, keepalive_cb, interval_ms, interval_ms); +} + +static void poll_cb(uv_poll_t *handle, int status, int events) { + ssh_client_ctx *c = (ssh_client_ctx *)handle->data; + if (c == NULL) + return; + if (ssh_debug_enabled) { + ssh_debug_log("client poll: status=%d events=0x%x state=%d", status, events, c->state); + } + if (status < 0) { + char errmsg[256] = {0}; + uv_strerror_r(status, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + client_fail(c, errmsg); + return; + } + if (events & UV_READABLE) + ssh_set_fd_toread(c->session); + if (events & UV_WRITABLE) { + c->write_ready = 1; + ssh_set_fd_towrite(c->session); + } + if (session_apply_poll_events(c->session, events) != 0) { + client_fail(c, "SSH poll callback error"); + return; + } + client_drive(c); + client_pump_io(c); +} + +static void client_update_poll(ssh_client_ctx *c) { + if (c->poll == NULL || c->session == NULL) + return; + if (c->state == CLIENT_STATE_CLOSING || c->state == CLIENT_STATE_CLOSED) + return; + if (uv_is_closing((uv_handle_t *)c->poll)) + return; + int status = ssh_get_status(c->session); + if (status & SSH_CLOSED_ERROR) { + client_fail(c, "SSH session closed with error"); + return; + } + if (status & SSH_CLOSED) { + client_close_internal(c, "SSH session closed"); + return; + } + int flags = ssh_get_poll_flags(c->session); + int pending = flags | status; + int events = UV_READABLE; + if (pending & SSH_WRITE_PENDING) + events |= UV_WRITABLE; + if ((events & UV_WRITABLE) == 0 && client_needs_write(c)) + events |= UV_WRITABLE; +#ifdef UV_DISCONNECT + events |= UV_DISCONNECT; +#endif + if (ssh_debug_enabled && (events != c->poll_events || (pending & SSH_WRITE_PENDING))) { + ssh_debug_log("client update poll: status=0x%x flags=0x%x pending=0x%x events=0x%x state=%d", + status, flags, pending, events, c->state); + } + if (events != c->poll_events) { + int uv_rc = uv_poll_start(c->poll, events, poll_cb); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + client_fail(c, errmsg); + return; + } + c->poll_events = events; + } +} + +static void client_pump_io(ssh_client_ctx *c) { + if (c == NULL || c->session == NULL) + return; + for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int did = 0; + if (c->session == NULL) + return; + int has_data = fd_has_data(c->fd); + if (has_data) { + if (ssh_debug_enabled) { + ssh_debug_log("client pump: readable"); + } + ssh_set_fd_toread(c->session); + if (session_apply_poll_events(c->session, UV_READABLE) != 0) { + client_fail(c, "SSH poll callback error"); + return; + } + client_drive(c); + did = 1; + } + if (c->session == NULL) + return; + int pending = ssh_get_status(c->session) | ssh_get_poll_flags(c->session); + int can_write = (pending & SSH_WRITE_PENDING) && fd_is_writable(c->fd); + if (can_write) { + if (ssh_debug_enabled) { + ssh_debug_log("client pump: writable pending=0x%x", pending); + } + c->write_ready = 1; + ssh_set_fd_towrite(c->session); + if (session_apply_poll_events(c->session, UV_WRITABLE) != 0) { + client_fail(c, "SSH poll callback error"); + return; + } + client_drive(c); + did = 1; + } + if (!did) { + int status = ssh_get_status(c->session); + if (status & SSH_READ_PENDING) { + if (ssh_debug_enabled) { + ssh_debug_log("client pump: buffered read pending=0x%x", status); + } + client_drive(c); + /* Avoid spinning when only buffered data remains. */ + break; + } + } + if (!did) + break; + } +} + +static int client_needs_write(ssh_client_ctx *c) { + if (c == NULL) + return 0; + ssh_channel_ctx *ch = c->channels; + while (ch != NULL) { + if (ch->pending_req != CHAN_REQ_NONE || ch->write_head != NULL) + return 1; + if (ch->send_eof && !ch->eof_sent) + return 1; + if (ch->close_requested && !ch->close_sent) + return 1; + ch = ch->next; + } + return 0; +} + +static void client_on_ready(ssh_client_ctx *c) { + c->state = CLIENT_STATE_READY; + stop_timer(&c->connect_timer, client_timer_close_cb); + stop_timer(&c->auth_timer, client_timer_close_cb); + client_notify_connect(c, NULL); + client_start_keepalive(c); + client_drive_channels(c); + client_update_poll(c); +} + +static void client_drive(ssh_client_ctx *c) { + if (c == NULL) + return; + if (c->state == CLIENT_STATE_ERROR || c->state == CLIENT_STATE_CLOSED || c->state == CLIENT_STATE_CLOSING) + return; + + while (1) { + if (c->state == CLIENT_STATE_CONNECTING) { + client_mark_writable(c); + int rc = ssh_connect(c->session); + if (ssh_debug_enabled) { + log_debug("ssh_connect rc=%d state=%d", rc, c->state); + } + if (rc == SSH_OK) { + stop_timer(&c->connect_timer, client_timer_close_cb); + c->state = CLIENT_STATE_HOSTKEY; + continue; + } else if (rc == SSH_AGAIN) { + if (ssh_get_status(c->session) & SSH_WRITE_PENDING) + c->write_ready = 0; + client_update_poll(c); + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH connect failed: %s", ssh_get_error(c->session)); + client_fail(c, errmsg); + return; + } + } + + if (c->state == CLIENT_STATE_HOSTKEY) { + int rc = client_check_hostkey(c); + if (rc == 0) { + c->state = CLIENT_STATE_AUTH; + client_start_auth_timer(c); + continue; + } else if (rc == 1) { + c->state = CLIENT_STATE_HOSTKEY_WAIT; + return; + } else { + return; + } + } + + if (c->state == CLIENT_STATE_HOSTKEY_WAIT) { + client_update_poll(c); + return; + } + + if (c->state == CLIENT_STATE_AUTH) { + if (c->actor->password == B_None) { + client_fail(c, "Password auth requested but no password provided"); + return; + } + client_mark_writable(c); + int rc = ssh_userauth_password(c->session, NULL, (const char *)fromB_str(c->actor->password)); + if (rc == SSH_AUTH_SUCCESS) { + client_on_ready(c); + return; + } else if (rc == SSH_AUTH_AGAIN) { + if (ssh_get_status(c->session) & SSH_WRITE_PENDING) + c->write_ready = 0; + client_update_poll(c); + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH auth failed: %s", ssh_get_error(c->session)); + client_fail(c, errmsg); + return; + } + } + + if (c->state == CLIENT_STATE_READY) { + client_drive_channels(c); + client_update_poll(c); + return; + } + + return; + } +} + +static void client_finalize(ssh_client_ctx *c) { + if (c == NULL || c->close_finalized) + return; + c->close_finalized = 1; + + if (c->session != NULL) { + ssh_disconnect(c->session); + ssh_free(c->session); + c->session = NULL; + } + + client_notify_close(c, c->close_reason ? c->close_reason : "closed"); + c->state = CLIENT_STATE_CLOSED; + if (c->actor) + c->actor->_client = toB_u64(0); +} + +static void client_close_internal(ssh_client_ctx *c, const char *reason) { + if (c == NULL || c->state == CLIENT_STATE_CLOSED || c->state == CLIENT_STATE_CLOSING) + return; + + if (!c->connected_ok && !c->connect_notified) { + client_notify_connect(c, reason ? reason : "closed"); + } + + int notify_channel_error = (c->state == CLIENT_STATE_ERROR); + c->state = CLIENT_STATE_CLOSING; + if (reason != NULL && c->close_reason == NULL) + c->close_reason = acton_strdup(reason); + + stop_timer(&c->connect_timer, client_timer_close_cb); + stop_timer(&c->auth_timer, client_timer_close_cb); + stop_timer(&c->keepalive_timer, client_timer_close_cb); + + if (c->poll != NULL) { + close_poll(&c->poll, client_poll_close_cb); + c->poll_events = 0; + } + + ssh_channel_ctx *ch = c->channels; + while (ch != NULL) { + ssh_channel_ctx *next = ch->next; + if (notify_channel_error) + channel_notify_error(ch, "Session closed"); + channel_notify_eof(ch); + channel_finalize(c, ch); + ch->next = NULL; + ch = next; + } + c->channels = NULL; + + if (c->poll != NULL) + return; + + client_finalize(c); +} + +void sshQ___ext_init__() { + const char *dbg_env = getenv("ACTON_SSH_DEBUG"); + const char *log_env = getenv("ACTON_SSH_LIBSSH_LOG"); + if (dbg_env != NULL && dbg_env[0] != '\0') + ssh_debug_enabled = 1; + if (log_env != NULL && log_env[0] != '\0') { + ssh_libssh_log_level = parse_libssh_log_level(log_env); + } + int r = libssh_replace_allocator(acton_malloc, + acton_realloc, + acton_calloc, + acton_free, + acton_strdup, + acton_strndup); + if (r != SSH_OK) { + log_warn("SSH allocator replacement failed"); + } + r = ssh_threads_set_callbacks(ssh_threads_get_default()); + if (r != SSH_OK) { + log_warn("SSH thread callbacks setup failed"); + } + r = ssh_init(); + if (r != SSH_OK) { + log_warn("SSH init failed"); + } +} + +B_str sshQ_version() { + return to$str((char *)ssh_version(0)); +} + +B_NoneType sshQ__debug(B_str msg) { + if (ssh_debug_enabled) { + log_info("%s", fromB_str(msg)); + } + return B_None; +} + +$R sshQ_ClientD__pin_affinityG_local(sshQ_Client self, $Cont c$cont) { + pin_actor_affinity(); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD__initG_local(sshQ_Client self, $Cont c$cont) { + ssh_configure_libssh_logging(); + ssh_client_ctx *c = acton_calloc(1, sizeof(ssh_client_ctx)); + c->actor = self; + c->on_connect = ($action2)self->on_connect; + c->on_close = ($action2)self->on_close; + c->on_hostkey = ($action3)self->on_hostkey; + c->connect_timeout = fromB_float(self->connect_timeout); + c->auth_timeout = fromB_float(self->auth_timeout); + c->keepalive_interval = fromB_float(self->keepalive_interval); + c->keepalive_enabled = fromB_bool(self->keepalive_enabled) ? 1 : 0; + + self->_client = toB_u64((unsigned long)c); + + c->session = ssh_new(); + if (c->session == NULL) { + client_fail(c, "Failed to create SSH session"); + return $R_CONT(c$cont, B_None); + } + + int rc; + int strict = 1; + int port = (int)fromB_u16(self->port); + bool process_config = false; + + rc = ssh_options_set(c->session, SSH_OPTIONS_PROCESS_CONFIG, &process_config); + if (rc != SSH_OK) { + client_fail(c, "Failed to disable SSH config processing"); + return $R_CONT(c$cont, B_None); + } + + rc = ssh_options_set(c->session, SSH_OPTIONS_HOST, fromB_str(self->host)); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH host"); + return $R_CONT(c$cont, B_None); + } + rc = ssh_options_set(c->session, SSH_OPTIONS_PORT, &port); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH port"); + return $R_CONT(c$cont, B_None); + } + rc = ssh_options_set(c->session, SSH_OPTIONS_USER, fromB_str(self->username)); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH username"); + return $R_CONT(c$cont, B_None); + } + if (self->known_hosts != B_None) { + const char *known_hosts = (const char *)fromB_str(self->known_hosts); + rc = ssh_options_set(c->session, SSH_OPTIONS_KNOWNHOSTS, known_hosts); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH known_hosts path"); + return $R_CONT(c$cont, B_None); + } + rc = ssh_options_set(c->session, SSH_OPTIONS_GLOBAL_KNOWNHOSTS, known_hosts); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH global known_hosts path"); + return $R_CONT(c$cont, B_None); + } + } + rc = ssh_options_set(c->session, SSH_OPTIONS_STRICTHOSTKEYCHECK, &strict); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH strict host key checking"); + return $R_CONT(c$cont, B_None); + } + + ssh_set_blocking(c->session, 0); + + c->state = CLIENT_STATE_CONNECTING; + rc = ssh_connect(c->session); + if (rc == SSH_OK) { + c->state = CLIENT_STATE_HOSTKEY; + } else if (rc == SSH_AGAIN) { + c->state = CLIENT_STATE_CONNECTING; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH connect failed: %s", ssh_get_error(c->session)); + client_fail(c, errmsg); + return $R_CONT(c$cont, B_None); + } + + c->fd = ssh_get_fd(c->session); + if (c->fd < 0) { + client_fail(c, "Failed to get SSH session fd"); + return $R_CONT(c$cont, B_None); + } + + c->poll = acton_calloc(1, sizeof(uv_poll_t)); + c->poll->data = c; + int uv_rc = uv_poll_init(get_uv_loop(), c->poll, c->fd); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + client_fail(c, errmsg); + return $R_CONT(c$cont, B_None); + } + c->poll_events = UV_READABLE | UV_WRITABLE; + uv_rc = uv_poll_start(c->poll, c->poll_events, poll_cb); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + client_fail(c, errmsg); + return $R_CONT(c$cont, B_None); + } + + client_start_connect_timer(c); + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_accept_hostkeyG_local(sshQ_Client self, $Cont c$cont) { + ssh_client_ctx *c = client_from_actor(self); + if (c == NULL) + return $R_CONT(c$cont, B_None); + if (c->state != CLIENT_STATE_HOSTKEY_WAIT) + return $R_CONT(c$cont, B_None); + + c->state = CLIENT_STATE_AUTH; + client_start_auth_timer(c); + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_reject_hostkeyG_local(sshQ_Client self, $Cont c$cont, B_str reason) { + ssh_client_ctx *c = client_from_actor(self); + if (c == NULL) + return $R_CONT(c$cont, B_None); + if (c->state != CLIENT_STATE_HOSTKEY_WAIT) + return $R_CONT(c$cont, B_None); + + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Host key rejected: %s", fromB_str(reason)); + client_fail(c, errmsg); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_closeG_local(sshQ_Client self, $Cont c$cont) { + ssh_client_ctx *c = client_from_actor(self); + if (c == NULL) + return $R_CONT(c$cont, B_None); + client_close_internal(c, "closed"); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_createG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, + $action on_open, + $action on_stdout, + $action on_stderr, + $action on_exit, + $action on_close) { + ssh_client_ctx *c = client_from_actor(self); + if (c == NULL) { + if (on_open) { + $action2 f = ($action2)on_open; + f->$class->__asyn__(f, channel, to$str((char *)"Client not initialized")); + } + return $R_CONT(c$cont, B_None); + } + + ssh_channel_ctx *ch = acton_calloc(1, sizeof(ssh_channel_ctx)); + ch->actor = channel; + ch->state = CHAN_STATE_INIT; + ch->pending_req = CHAN_REQ_NONE; + ch->pty_pending = 0; + ch->pty_done = 0; + ch->term = to$str((char *)"xterm-256color"); + ch->cols = 80; + ch->rows = 24; + ch->width_px = 0; + ch->height_px = 0; + ch->on_open = ($action2)on_open; + ch->on_close = ($action2)on_close; + ch->on_stdout = ($action2)on_stdout; + ch->on_stderr = ($action2)on_stderr; + ch->on_exit = ($action3)on_exit; + + ch->next = c->channels; + c->channels = ch; + + channel->_channel_id = toB_u64((unsigned long)ch); + + if (c->state == CLIENT_STATE_READY) + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +static int channel_validate(ssh_client_ctx *c, ssh_channel_ctx *ch) { + if (c == NULL || ch == NULL) { + return -1; + } + if (ch->state == CHAN_STATE_ERROR || ch->state == CHAN_STATE_CLOSED) { + return -1; + } + return 0; +} + +$R sshQ_ClientD_channel_request_execG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str cmd) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + if (channel_validate(c, ch) != 0) { + if (ch != NULL) + channel_notify_error(ch, "Channel not ready"); + return $R_CONT(c$cont, B_None); + } + + if (ch->pending_req != CHAN_REQ_NONE) { + channel_notify_error(ch, "Channel already has a pending request"); + return $R_CONT(c$cont, B_None); + } + if (ch->state == CHAN_STATE_RUNNING || ch->state == CHAN_STATE_CLOSING) { + channel_notify_error(ch, "Channel already running"); + return $R_CONT(c$cont, B_None); + } + + ch->exec_cmd = cmd; + ch->pending_req = CHAN_REQ_EXEC; + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_request_shellG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str term, B_int cols, B_int rows, B_int width_px, B_int height_px, B_bool with_pty) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + if (channel_validate(c, ch) != 0) { + if (ch != NULL) + channel_notify_error(ch, "Channel not ready"); + return $R_CONT(c$cont, B_None); + } + + if (ch->pending_req != CHAN_REQ_NONE) { + channel_notify_error(ch, "Channel already has a pending request"); + return $R_CONT(c$cont, B_None); + } + if (ch->state == CHAN_STATE_RUNNING || ch->state == CHAN_STATE_CLOSING) { + channel_notify_error(ch, "Channel already running"); + return $R_CONT(c$cont, B_None); + } + + ch->term = term; + ch->cols = fromB_int(cols); + ch->rows = fromB_int(rows); + ch->width_px = fromB_int(width_px); + ch->height_px = fromB_int(height_px); + ch->pty_pending = fromB_bool(with_pty) ? 1 : 0; + + ch->pending_req = CHAN_REQ_SHELL; + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_request_subsystemG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str name) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + if (channel_validate(c, ch) != 0) { + if (ch != NULL) + channel_notify_error(ch, "Channel not ready"); + return $R_CONT(c$cont, B_None); + } + + if (ch->pending_req != CHAN_REQ_NONE) { + channel_notify_error(ch, "Channel already has a pending request"); + return $R_CONT(c$cont, B_None); + } + if (ch->state == CHAN_STATE_RUNNING || ch->state == CHAN_STATE_CLOSING) { + channel_notify_error(ch, "Channel already running"); + return $R_CONT(c$cont, B_None); + } + + ch->subsystem = name; + ch->pending_req = CHAN_REQ_SUBSYSTEM; + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_writeG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_bytes data) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + if (channel_validate(c, ch) != 0) { + if (ch != NULL) + channel_notify_error(ch, "Channel not ready"); + return $R_CONT(c$cont, B_None); + } + + channel_queue_write(ch, data); + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_send_eofG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + if (channel_validate(c, ch) != 0) + return $R_CONT(c$cont, B_None); + + ch->send_eof = 1; + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_closeG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + if (channel_validate(c, ch) != 0) + return $R_CONT(c$cont, B_None); + + ch->send_eof = 1; + ch->close_requested = 1; + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +// --- Server implementation + +static void server_notify_listen(ssh_server_ctx *s, const char *err) { + if (s == NULL) + return; + if (s->listen_notified) + return; + if (s->on_listen) { + $action2 f = ($action2)s->on_listen; + f->$class->__asyn__(f, s->actor, err ? to$str((char *)err) : B_None); + } + s->listen_notified = 1; + if (err == NULL) + s->listen_ok = 1; +} + +static void server_notify_close(ssh_server_ctx *s, const char *reason) { + if (s->close_notified) + return; + if (!s->listen_ok) + return; + if (s->on_close) { + $action2 f = ($action2)s->on_close; + f->$class->__asyn__(f, s->actor, to$str((char *)reason)); + } + s->close_notified = 1; +} + +static void session_notify_close(ssh_server_session_ctx *s, const char *reason) { + if (s->close_notified) + return; + if (s->on_close) { + $action2 f = ($action2)s->on_close; + f->$class->__asyn__(f, s->actor, to$str((char *)reason)); + } + s->close_notified = 1; } -void sshQ___ext_init__() { - // TODO: can we avoid custom malloc in libssh? like let libssh use stock - // malloc and instead we would explicitly call free() from a finalizer() - // All things related to buffers for receiving data and similarly would have - // to be allocated on the GC-heap though since that data is passed outside - // of the SSH actor - libssh_replace_allocator( - acton_gc_malloc, - acton_gc_realloc, - acton_gc_calloc, - noop_free, - acton_gc_strdup, - acton_gc_strndup); - int r = ssh_init(); - printf("SSH extension initialized %d\n", r); -} - -B_str sshQ_version () { - if (LIBSSH_VERSION_MINOR != 11) - return to$str("invalid"); - return to$str("0.11.0"); -} - -// TODO: crap function for test, to be replaced with something -int show_remote_processes(ssh_session session) -{ - ssh_channel channel; - int rc; - char buffer[256]; - int nbytes; - - channel = ssh_channel_new(session); - if (channel == NULL) - return SSH_ERROR; - - rc = ssh_channel_open_session(channel); - if (rc != SSH_OK) - { - ssh_channel_free(channel); - return rc; - } - - rc = ssh_channel_request_exec(channel, "ps aux"); - if (rc != SSH_OK) - { - ssh_channel_close(channel); - ssh_channel_free(channel); - return rc; - } - - nbytes = ssh_channel_read(channel, buffer, sizeof(buffer), 0); - while (nbytes > 0) - { - if (write(1, buffer, nbytes) != (unsigned int) nbytes) - { - ssh_channel_close(channel); - ssh_channel_free(channel); - return SSH_ERROR; - } - nbytes = ssh_channel_read(channel, buffer, sizeof(buffer), 0); - } - - if (nbytes < 0) - { - ssh_channel_close(channel); - ssh_channel_free(channel); - return SSH_ERROR; - } - - ssh_channel_send_eof(channel); - ssh_channel_close(channel); - ssh_channel_free(channel); - - return SSH_OK; -} - -$R sshQ_ClientD__initG_local (sshQ_Client self, $Cont c$cont) { - ssh_session session = ssh_new(); - if (session == NULL) { - //log_error("Failed to create SSH session"); - return $R_CONT(c$cont, B_None); - } - printf("session: %p\n", session); - self->_ssh_session = toB_u64((unsigned long)session); - printf("init self->session: %p\n", self->_ssh_session); - - ssh_options_set(session, SSH_OPTIONS_HOST, fromB_str(self->host)); - ssh_options_set(session, SSH_OPTIONS_PORT, &self->port->val); - ssh_options_set(session, SSH_OPTIONS_USER, fromB_str(self->username)); - - ssh_set_blocking(session, 1); - printf("Connecting to \n"); - int rc = ssh_connect(session); +static void server_channel_notify_close(ssh_server_channel_ctx *ch, const char *reason) { + if (ch->close_notified) + return; + if (ch->on_close) { + $action2 f = ($action2)ch->on_close; + f->$class->__asyn__(f, ch->actor, to$str((char *)reason)); + } + ch->close_notified = 1; +} + +static void server_channel_notify_eof(ssh_server_channel_ctx *ch) { + if (ch->channel == NULL) + return; + if (ssh_channel_is_eof(ch->channel)) { + if (!ch->stdout_eof && ch->on_data) { + $action2 f = ($action2)ch->on_data; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stdout_eof = 1; + } + if (!ch->stderr_eof && ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stderr_eof = 1; + } + } +} + +static void server_channel_finalize(ssh_server_channel_ctx *ch) { + if (ch->channel != NULL) { + ssh_channel_free(ch->channel); + ch->channel = NULL; + } + if (!ch->stdout_eof && ch->on_data) { + $action2 f = ($action2)ch->on_data; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stdout_eof = 1; + } + if (!ch->stderr_eof && ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stderr_eof = 1; + } + server_channel_notify_close(ch, "closed"); + if (ch->actor) + ch->actor->_channel_id = toB_u64(0); + ch->state = SCHAN_STATE_CLOSED; +} + +static void server_channel_queue_write(ssh_server_channel_ctx *ch, B_bytes data, int is_stderr) { + server_write_chunk_t *chunk = acton_calloc(1, sizeof(server_write_chunk_t)); + chunk->data = data; + chunk->offset = 0; + chunk->is_stderr = is_stderr; + chunk->next = NULL; + if (ch->write_tail) { + ch->write_tail->next = chunk; + } else { + ch->write_head = chunk; + } + ch->write_tail = chunk; + if (ssh_debug_enabled) { + ssh_debug_log("server channel queue write: %zu bytes", data ? data->nbytes : 0); + } +} + +static void server_channel_try_write(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) { + while (ch->write_head != NULL) { + server_write_chunk_t *chunk = ch->write_head; + size_t remaining = chunk->data->nbytes - chunk->offset; + if (remaining == 0) { + ch->write_head = chunk->next; + if (ch->write_head == NULL) + ch->write_tail = NULL; + continue; + } + session_mark_writable(s); + int rc; + if (chunk->is_stderr) { + rc = ssh_channel_write_stderr(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); + } else { + rc = ssh_channel_write(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); + } + if (rc > 0) { + chunk->offset += (size_t)rc; + if (ssh_debug_enabled) { + ssh_debug_log("server channel wrote %d bytes", rc); + } + if (chunk->offset >= chunk->data->nbytes) { + ch->write_head = chunk->next; + if (ch->write_head == NULL) + ch->write_tail = NULL; + } + } else if (rc == 0 || rc == SSH_AGAIN) { + s->write_ready = 0; + if (ssh_debug_enabled) { + unsigned int window = ssh_channel_window_size(ch->channel); + ssh_debug_log("server channel write pending rc=%d remaining=%zu window=%u", + rc, remaining, window); + } + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server channel write error: %s", ssh_get_error(s->session)); + server_channel_notify_close(ch, errmsg); + ch->state = SCHAN_STATE_ERROR; + return; + } + } +} + +static int server_channel_read_stream(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch, int is_stderr) { + char buf[SSH_READ_BUFSIZE]; + int read_any = 0; + for (;;) { + int n = ssh_channel_read_nonblocking(ch->channel, buf, sizeof(buf), is_stderr); + if (ssh_debug_enabled) { + ssh_debug_log("server channel read: rc=%d stderr=%d", n, is_stderr); + } + if (n > 0) { + read_any = 1; + B_bytes out = to$bytesD_len(buf, n); + if (is_stderr) { + if (ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, ch->actor, out); + } + } else { + if (ch->on_data) { + $action2 f = ($action2)ch->on_data; + f->$class->__asyn__(f, ch->actor, out); + } + } + continue; + } + if (n == 0 || n == SSH_AGAIN) { + break; + } + if (n == SSH_EOF) { + break; + } + if (n == SSH_ERROR) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server channel read error: %s", ssh_get_error(s->session)); + server_channel_notify_close(ch, errmsg); + ch->state = SCHAN_STATE_ERROR; + break; + } + } + return read_any; +} + +static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) { + if (ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) + return; + + if (ch->write_head != NULL) { + server_channel_try_write(s, ch); + if (ch->state == SCHAN_STATE_ERROR) + return; + } + + if (ch->send_eof && !ch->eof_sent && ch->write_head == NULL) { + session_mark_writable(s); + int rc = ssh_channel_send_eof(ch->channel); + if (rc == SSH_OK) { + ch->eof_sent = 1; + } else if (rc == SSH_AGAIN) { + s->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server send EOF failed: %s", ssh_get_error(s->session)); + server_channel_notify_close(ch, errmsg); + ch->state = SCHAN_STATE_ERROR; + return; + } + } + + if (ch->close_requested && !ch->close_sent && ch->write_head == NULL) { + session_mark_writable(s); + int rc = ssh_channel_close(ch->channel); + if (rc == SSH_OK) { + ch->close_sent = 1; + ch->state = SCHAN_STATE_CLOSING; + } else if (rc == SSH_AGAIN) { + s->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server channel close failed: %s", ssh_get_error(s->session)); + server_channel_notify_close(ch, errmsg); + ch->state = SCHAN_STATE_ERROR; + return; + } + } + + for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int did = 0; + did |= server_channel_read_stream(s, ch, 0); + did |= server_channel_read_stream(s, ch, 1); + if (!did) + break; + } + server_channel_notify_eof(ch); + + if (ch->channel != NULL && ssh_channel_is_closed(ch->channel)) { + server_channel_finalize(ch); + } +} + +static void session_drive_channels(ssh_server_session_ctx *s) { + ssh_server_channel_ctx *prev = NULL; + ssh_server_channel_ctx *ch = s->channels; + while (ch != NULL) { + ssh_server_channel_ctx *next = ch->next; + server_channel_drive(s, ch); + if (ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) { + if (prev != NULL) { + prev->next = next; + } else { + s->channels = next; + } + ch->next = NULL; + } else { + prev = ch; + } + ch = next; + } +} + +static void session_fail(ssh_server_session_ctx *s, const char *msg) { + if (s == NULL || s->state == SESSION_STATE_CLOSED || s->state == SESSION_STATE_ERROR) + return; + s->state = SESSION_STATE_ERROR; + session_close_internal(s, msg); +} + +static void session_start_auth_timer(ssh_server_session_ctx *s) { + if (s == NULL || s->auth_timeout <= 0.0 || s->auth_timer != NULL) + return; + s->auth_timer = acton_calloc(1, sizeof(uv_timer_t)); + s->auth_timer->data = s; + uv_timer_init(get_uv_loop(), s->auth_timer); + uv_timer_start(s->auth_timer, session_auth_timeout_cb, (uint64_t)(s->auth_timeout * 1000.0), 0); +} + +static void session_start_keepalive(ssh_server_session_ctx *s) { + if (s == NULL || !s->keepalive_enabled || s->keepalive_interval <= 0.0 || s->keepalive_timer != NULL) + return; + s->keepalive_timer = acton_calloc(1, sizeof(uv_timer_t)); + s->keepalive_timer->data = s; + uv_timer_init(get_uv_loop(), s->keepalive_timer); + uv_timer_start(s->keepalive_timer, session_keepalive_cb, (uint64_t)(s->keepalive_interval * 1000.0), (uint64_t)(s->keepalive_interval * 1000.0)); +} + +static void server_fail(ssh_server_ctx *s, const char *msg) { + if (s == NULL || s->state == SERVER_STATE_CLOSED || s->state == SERVER_STATE_ERROR) + return; + s->state = SERVER_STATE_ERROR; + if (!s->listen_ok) + server_notify_listen(s, msg); + server_close_internal(s, msg); +} + +static void session_auth_timeout_cb(uv_timer_t *timer) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)timer->data; + if (s == NULL) + return; + if (s->state == SESSION_STATE_AUTH) { + session_fail(s, "SSH authentication timeout"); + } +} + +static void session_keepalive_cb(uv_timer_t *timer) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)timer->data; + if (s == NULL || s->session == NULL) + return; + if (s->state != SESSION_STATE_READY) + return; + int rc = ssh_send_ignore(s->session, "keepalive"); + if (rc == SSH_AGAIN) { + session_update_poll(s); + return; + } + if (rc != SSH_OK) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server keepalive failed: %s", ssh_get_error(s->session)); + session_fail(s, errmsg); + return; + } + session_update_poll(s); +} + +static void session_update_poll(ssh_server_session_ctx *s) { + if (s->poll == NULL || s->session == NULL) + return; + if (s->state == SESSION_STATE_CLOSING || s->state == SESSION_STATE_CLOSED) + return; + if (uv_is_closing((uv_handle_t *)s->poll)) + return; + int status = ssh_get_status(s->session); + if (status & SSH_CLOSED_ERROR) { + session_fail(s, "SSH session closed with error"); + return; + } + if (status & SSH_CLOSED) { + session_close_internal(s, "SSH session closed"); + return; + } + int flags = ssh_get_poll_flags(s->session); + int pending = flags | status; + int events = UV_READABLE; + if (pending & SSH_WRITE_PENDING) + events |= UV_WRITABLE; + if ((events & UV_WRITABLE) == 0 && session_needs_write(s)) + events |= UV_WRITABLE; +#ifdef UV_DISCONNECT + events |= UV_DISCONNECT; +#endif + if (ssh_debug_enabled && (events != s->poll_events || (pending & SSH_WRITE_PENDING))) { + ssh_debug_log("server update poll: status=0x%x flags=0x%x pending=0x%x events=0x%x state=%d", + status, flags, pending, events, s->state); + } + if (events != s->poll_events) { + int uv_rc = uv_poll_start(s->poll, events, session_poll_cb); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + session_fail(s, errmsg); + return; + } + s->poll_events = events; + } +} + +static void session_pump_io(ssh_server_session_ctx *s) { + if (s == NULL || s->session == NULL) + return; + for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int did = 0; + if (s->session == NULL) + return; + int has_data = fd_has_data(s->fd); + if (has_data) { + if (ssh_debug_enabled) { + ssh_debug_log("server pump: readable"); + } + ssh_set_fd_toread(s->session); + if (session_apply_poll_events(s->session, UV_READABLE) != 0) { + session_fail(s, "SSH poll callback error"); + return; + } + session_drive(s); + did = 1; + } + if (s->session == NULL) + return; + int pending = ssh_get_status(s->session) | ssh_get_poll_flags(s->session); + int can_write = (pending & SSH_WRITE_PENDING) && fd_is_writable(s->fd); + if (can_write) { + if (ssh_debug_enabled) { + ssh_debug_log("server pump: writable pending=0x%x", pending); + } + s->write_ready = 1; + ssh_set_fd_towrite(s->session); + if (session_apply_poll_events(s->session, UV_WRITABLE) != 0) { + session_fail(s, "SSH poll callback error"); + return; + } + session_drive(s); + did = 1; + } + if (!did) { + int status = ssh_get_status(s->session); + if (status & SSH_READ_PENDING) { + if (ssh_debug_enabled) { + ssh_debug_log("server pump: buffered read pending=0x%x", status); + } + session_drive(s); + /* Avoid spinning when only buffered data remains. */ + break; + } + } + if (!did) + break; + } +} + +static int session_needs_write(ssh_server_session_ctx *s) { + if (s == NULL) + return 0; + ssh_server_channel_ctx *ch = s->channels; + while (ch != NULL) { + if (ch->pending_req != NULL || ch->write_head != NULL) + return 1; + if (ch->send_eof && !ch->eof_sent) + return 1; + if (ch->close_requested && !ch->close_sent) + return 1; + ch = ch->next; + } + return 0; +} + +static void session_drive(ssh_server_session_ctx *s) { + if (s == NULL) + return; + if (s->state == SESSION_STATE_ERROR || s->state == SESSION_STATE_CLOSED || s->state == SESSION_STATE_CLOSING) + return; + if (!s->attached) + return; + + while (1) { + if (s->state == SESSION_STATE_KEYEX) { + session_mark_writable(s); + int rc = ssh_handle_key_exchange(s->session); + if (rc == SSH_OK) { + s->state = SESSION_STATE_AUTH; + session_start_auth_timer(s); + ssh_set_auth_methods(s->session, SSH_AUTH_METHOD_PASSWORD); + continue; + } else if (rc == SSH_AGAIN) { + if (ssh_get_status(s->session) & SSH_WRITE_PENDING) + s->write_ready = 0; + session_update_poll(s); + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH key exchange failed: %s", ssh_get_error(s->session)); + session_fail(s, errmsg); + return; + } + } + + if (s->state == SESSION_STATE_AUTH) { + if (s->pending_auth != NULL) { + session_update_poll(s); + return; + } + ssh_message msg = ssh_message_get(s->session); + if (msg == NULL) { + session_update_poll(s); + return; + } + int type = ssh_message_type(msg); + if (type == SSH_REQUEST_SERVICE) { + const char *service = ssh_message_service_service(msg); + if (service && strcmp(service, "ssh-userauth") == 0) { + ssh_message_service_reply_success(msg); + } else { + ssh_message_reply_default(msg); + } + ssh_message_free(msg); + continue; + } + if (type == SSH_REQUEST_AUTH && ssh_message_subtype(msg) == SSH_AUTH_METHOD_PASSWORD) { + if (s->on_auth == NULL) { + ssh_message_auth_set_methods(msg, SSH_AUTH_METHOD_PASSWORD); + ssh_message_reply_default(msg); + ssh_message_free(msg); + continue; + } + const char *user = ssh_message_auth_user(msg); +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif + const char *pass = ssh_message_auth_password(msg); +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + s->pending_auth = msg; + sshQ_AuthRequest req = sshQ_AuthRequestG_new( + to$str((char *)"password"), + to$str((char *)(user ? user : "")), + pass ? to$str((char *)(pass)) : B_None, + B_None); + $action2 f = ($action2)s->on_auth; + f->$class->__asyn__(f, s->actor, req); + session_update_poll(s); + return; + } + ssh_message_reply_default(msg); + ssh_message_free(msg); + continue; + } + + if (s->state == SESSION_STATE_READY) { + if (s->pending_channel_open != NULL) { + session_drive_channels(s); + session_update_poll(s); + return; + } + while (1) { + ssh_message msg = ssh_message_get(s->session); + if (msg == NULL) + break; + int type = ssh_message_type(msg); + if (type == SSH_REQUEST_CHANNEL_OPEN) { + if (s->pending_channel_open != NULL) { + ssh_message_reply_default(msg); + ssh_message_free(msg); + } else if (ssh_message_subtype(msg) != SSH_CHANNEL_SESSION) { + ssh_message_reply_default(msg); + ssh_message_free(msg); + } else if (s->on_channel_open == NULL) { + ssh_message_reply_default(msg); + ssh_message_free(msg); + } else { + s->pending_channel_open = msg; + $action f = ($action)s->on_channel_open; + f->$class->__asyn__(f, s->actor); + break; + } + } else if (type == SSH_REQUEST_CHANNEL) { + ssh_channel chan = ssh_message_channel_request_channel(msg); + ssh_server_channel_ctx *ch = server_channel_from_ssh(s, chan); + if (ch == NULL || ch->pending_req != NULL) { + ssh_message_reply_default(msg); + ssh_message_free(msg); + } else if (ssh_message_subtype(msg) == SSH_CHANNEL_REQUEST_EXEC) { + if (s->on_exec == NULL) { + ssh_message_reply_default(msg); + ssh_message_free(msg); + } else { + const char *cmd = ssh_message_channel_request_command(msg); + ch->pending_req = msg; + ch->pending_req_type = SCHAN_REQ_EXEC; + $action3 f = ($action3)s->on_exec; + f->$class->__asyn__(f, s->actor, ch->actor, + to$str((char *)(cmd ? cmd : ""))); + break; + } + } else if (ssh_message_subtype(msg) == SSH_CHANNEL_REQUEST_SUBSYSTEM) { + if (s->on_subsystem == NULL) { + ssh_message_reply_default(msg); + ssh_message_free(msg); + } else { + const char *name = ssh_message_channel_request_subsystem(msg); + ch->pending_req = msg; + ch->pending_req_type = SCHAN_REQ_SUBSYSTEM; + $action3 f = ($action3)s->on_subsystem; + f->$class->__asyn__(f, s->actor, ch->actor, + to$str((char *)(name ? name : ""))); + break; + } + } else { + ssh_message_reply_default(msg); + ssh_message_free(msg); + } + } else if (type == SSH_REQUEST_SERVICE) { + const char *service = ssh_message_service_service(msg); + if (service && strcmp(service, "ssh-connection") == 0) { + ssh_message_service_reply_success(msg); + } else { + ssh_message_reply_default(msg); + } + ssh_message_free(msg); + } else { + ssh_message_reply_default(msg); + ssh_message_free(msg); + } + } + session_drive_channels(s); + session_update_poll(s); + return; + } + return; + } +} + +static void session_poll_cb(uv_poll_t *handle, int status, int events) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data; + if (s == NULL) + return; + if (ssh_debug_enabled) { + ssh_debug_log("server poll: status=%d events=0x%x state=%d", status, events, s->state); + } + if (status < 0) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH session poll error: %s", uv_strerror(status)); + session_fail(s, errmsg); + return; + } + if (events & UV_READABLE) + ssh_set_fd_toread(s->session); + if (events & UV_WRITABLE) { + s->write_ready = 1; + ssh_set_fd_towrite(s->session); + } + if (session_apply_poll_events(s->session, events) != 0) { + session_fail(s, "SSH poll callback error"); + return; + } + session_drive(s); + session_pump_io(s); +} + +static void server_poll_cb(uv_poll_t *handle, int status, int events) { + ssh_server_ctx *s = (ssh_server_ctx *)handle->data; + if (s == NULL) + return; + if (ssh_debug_enabled) { + ssh_debug_log("server listen poll: status=%d events=0x%x", status, events); + } + if (status < 0) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server poll error: %s", uv_strerror(status)); + server_fail(s, errmsg); + return; + } + if (events & UV_READABLE) + server_accept(s); +} + +static void server_accept(ssh_server_ctx *s) { + if (s == NULL || s->state != SERVER_STATE_LISTENING) + return; + + while (1) { + socket_t fd = accept(s->fd, NULL, NULL); + if (fd == SSH_INVALID_SOCKET) { + if (errno == EINTR) { + continue; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH accept failed: %s", strerror(errno)); + server_fail(s, errmsg); + return; + } + + if (fd_set_nonblocking(fd) != 0) { + close(fd); + server_fail(s, "Failed to set accepted fd nonblocking"); + return; + } + + ssh_session session = ssh_new(); + if (session == NULL) { + close(fd); + server_fail(s, "Failed to create SSH session"); + return; + } + + if (ssh_debug_enabled) { + ssh_debug_log("server accept: actor=%p", (void *)s->actor); + } + + int rc = ssh_bind_accept_fd(s->bind, session, fd); + if (rc != SSH_OK) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH accept failed: %s", ssh_get_error(s->bind)); + close(fd); + ssh_free(session); + server_fail(s, errmsg); + return; + } + + ssh_set_blocking(session, 0); + ssh_server_session_ctx *sess = acton_calloc(1, sizeof(ssh_server_session_ctx)); + sess->server = s; + sess->session = session; + sess->state = SESSION_STATE_KEYEX; + sess->fd = ssh_get_fd(session); + sess->owner_wt = s->actor ? (int)s->actor->$affinity : 0; + if (sess->fd < 0) { + ssh_free(session); + server_fail(s, "Failed to get SSH session fd"); + return; + } + sess->poll = acton_calloc(1, sizeof(uv_poll_t)); + sess->poll->data = sess; + int uv_rc = uv_poll_init(get_uv_loop(), sess->poll, sess->fd); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + ssh_free(session); + server_fail(s, errmsg); + return; + } + sess->poll_events = UV_READABLE | UV_WRITABLE; + uv_rc = uv_poll_start(sess->poll, sess->poll_events, session_poll_cb); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + ssh_free(session); + server_fail(s, errmsg); + return; + } + + sess->next = s->sessions; + s->sessions = sess; + + if (ssh_debug_enabled) { + ssh_debug_log("server accept: session fd=%d", sess->fd); + } + + if (s->actor) { + sshQ_Server act = (sshQ_Server)s->actor; + if (ssh_debug_enabled) { + ssh_debug_log("server accept: scheduling session pending act=%p session=%p", (void *)act, (void *)sess); + } + act->$class->on_session_pending(act, toB_u64((unsigned long)sess)); + if (ssh_debug_enabled) { + ssh_debug_log("server accept: on_session_pending call returned session=%p", (void *)sess); + } + } + } +} + +static void server_remove_session(ssh_server_ctx *s, ssh_server_session_ctx *sess) { + if (s == NULL || sess == NULL) + return; + ssh_server_session_ctx *prev = NULL; + ssh_server_session_ctx *cur = s->sessions; + while (cur != NULL) { + if (cur == sess) { + if (prev != NULL) + prev->next = cur->next; + else + s->sessions = cur->next; + break; + } + prev = cur; + cur = cur->next; + } +} + +static void server_finalize(ssh_server_ctx *s) { + if (s == NULL || s->close_finalized) + return; + s->close_finalized = 1; + + if (s->bind != NULL) { + ssh_bind_free(s->bind); + s->bind = NULL; + } + if (s->hostkey != NULL) { + ssh_key_free(s->hostkey); + s->hostkey = NULL; + } + + server_notify_close(s, s->close_reason ? s->close_reason : "closed"); + s->state = SERVER_STATE_CLOSED; + if (s->actor) + s->actor->_server = toB_u64(0); +} + +static void session_finalize(ssh_server_session_ctx *s) { + if (s == NULL || s->close_finalized) + return; + s->close_finalized = 1; + + if (s->session != NULL) { + ssh_disconnect(s->session); + ssh_free(s->session); + s->session = NULL; + } + + server_remove_session(s->server, s); + session_notify_close(s, s->close_reason ? s->close_reason : "closed"); + s->state = SESSION_STATE_CLOSED; + if (s->actor) + s->actor->_session_id = toB_u64(0); +} + +static void server_close_internal(ssh_server_ctx *s, const char *reason) { + if (s == NULL || s->state == SERVER_STATE_CLOSED || s->state == SERVER_STATE_CLOSING) + return; + + if (!s->listen_ok && !s->listen_notified) { + server_notify_listen(s, reason ? reason : "closed"); + } + + s->state = SERVER_STATE_CLOSING; + if (reason != NULL && s->close_reason == NULL) + s->close_reason = acton_strdup(reason); + + if (s->poll != NULL) { + close_poll(&s->poll, server_poll_close_cb); + } + + ssh_server_session_ctx *sess = s->sessions; + while (sess != NULL) { + ssh_server_session_ctx *next = sess->next; + session_close_internal(sess, "Server closed"); + sess = next; + } + s->sessions = NULL; + + if (s->poll != NULL) + return; + + server_finalize(s); +} + +static void session_close_internal(ssh_server_session_ctx *s, const char *reason) { + if (s == NULL || s->state == SESSION_STATE_CLOSED || s->state == SESSION_STATE_CLOSING) + return; + + s->state = SESSION_STATE_CLOSING; + if (reason != NULL && s->close_reason == NULL) + s->close_reason = acton_strdup(reason); + + stop_timer(&s->auth_timer, session_timer_close_cb); + stop_timer(&s->keepalive_timer, session_timer_close_cb); + + if (s->poll != NULL) { + close_poll(&s->poll, session_poll_close_cb); + s->poll_events = 0; + } + + ssh_server_channel_ctx *ch = s->channels; + while (ch != NULL) { + ssh_server_channel_ctx *next = ch->next; + server_channel_notify_close(ch, "Session closed"); + server_channel_finalize(ch); + ch->next = NULL; + ch = next; + } + s->channels = NULL; + + if (s->pending_auth) { + ssh_message_free(s->pending_auth); + s->pending_auth = NULL; + } + if (s->pending_channel_open) { + ssh_message_free(s->pending_channel_open); + s->pending_channel_open = NULL; + } + + if (s->poll != NULL) + return; + + session_finalize(s); +} + +static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_out) { + if (type_str == NULL) + return SSH_KEYTYPE_UNKNOWN; + if (strcmp(type_str, "ed25519") == 0) { + if (param_out) + *param_out = 0; + return SSH_KEYTYPE_ED25519; + } + if (strcmp(type_str, "rsa") == 0) { + if (param_out) + *param_out = 2048; + return SSH_KEYTYPE_RSA; + } + if (strcmp(type_str, "ecdsa") == 0) { + if (param_out) + *param_out = 256; + return SSH_KEYTYPE_ECDSA; + } + return SSH_KEYTYPE_UNKNOWN; +} + +$R sshQ_ServerD__pin_affinityG_local(sshQ_Server self, $Cont c$cont) { + pin_actor_affinity(); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerD__initG_local(sshQ_Server self, $Cont c$cont) { + ssh_configure_libssh_logging(); + if (ssh_debug_enabled) { + log_info("ssh server init: self=%p", (void *)self); + } + ssh_server_ctx *s = acton_calloc(1, sizeof(ssh_server_ctx)); + s->actor = self; + s->on_listen = ($action2)self->_on_listen; + s->on_close = ($action2)self->_on_close; + s->state = SERVER_STATE_INIT; + + self->_server = toB_u64((unsigned long)s); + + s->bind = ssh_bind_new(); + if (s->bind == NULL) { + server_fail(s, "Failed to create SSH bind"); + return $R_CONT(c$cont, B_None); + } + + int rc; + bool process_config = false; + + rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_PROCESS_CONFIG, &process_config); if (rc != SSH_OK) { - //log_error("Error connecting to SSH server: %s", ssh_get_error(session)); - $action2 f = ($action2) self->on_close; - f->$class->__asyn__(f, self, to$str(ssh_get_error(session))); + server_fail(s, "Failed to disable SSH bind config processing"); return $R_CONT(c$cont, B_None); } - rc = ssh_userauth_password(session, NULL, fromB_str(self->password)); - if (rc == SSH_OK) { - printf("Connected\n"); - show_remote_processes(session); + const char *host = (const char *)fromB_str(self->_host); + int port = (int)fromB_u16(self->_port); + rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_BINDADDR, host); + if (rc != SSH_OK) { + server_fail(s, "Failed to set bind address"); + return $R_CONT(c$cont, B_None); + } + rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_BINDPORT, &port); + if (rc != SSH_OK) { + server_fail(s, "Failed to set bind port"); + return $R_CONT(c$cont, B_None); + } + + if (self->_host_key_path != B_None) { + const char *path = (const char *)fromB_str(self->_host_key_path); + rc = ssh_pki_import_privkey_file(path, NULL, NULL, NULL, &s->hostkey); + if (rc != SSH_OK) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Failed to load host key: %s", ssh_get_error(s->bind)); + server_fail(s, errmsg); + return $R_CONT(c$cont, B_None); + } + } else { + const char *type_str = (const char *)fromB_str(self->_host_key_type); + int param = fromB_int(self->_host_key_bits); + int default_param = 0; + enum ssh_keytypes_e type = parse_hostkey_type(type_str, &default_param); + if (type == SSH_KEYTYPE_UNKNOWN) { + server_fail(s, "Unsupported host key type"); + return $R_CONT(c$cont, B_None); + } + if (type == SSH_KEYTYPE_ED25519) { + param = 0; + } else if (param <= 0) { + param = default_param; + } + rc = ssh_pki_generate(type, param, &s->hostkey); + if (rc != SSH_OK) { + server_fail(s, "Failed to generate host key"); + return $R_CONT(c$cont, B_None); + } + } + + rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_IMPORT_KEY, s->hostkey); + if (rc != SSH_OK) { + server_fail(s, "Failed to set host key"); + return $R_CONT(c$cont, B_None); + } + + ssh_bind_set_blocking(s->bind, 0); + + rc = ssh_bind_listen(s->bind); + if (rc != SSH_OK) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH listen failed: %s", ssh_get_error(s->bind)); + server_fail(s, errmsg); + return $R_CONT(c$cont, B_None); + } + + s->fd = ssh_bind_get_fd(s->bind); + if (s->fd < 0) { + server_fail(s, "Failed to get bind fd"); + return $R_CONT(c$cont, B_None); + } + if (fd_set_nonblocking(s->fd) != 0) { + server_fail(s, "Failed to set bind fd nonblocking"); + return $R_CONT(c$cont, B_None); + } + + s->poll = acton_calloc(1, sizeof(uv_poll_t)); + s->poll->data = s; + int uv_rc = uv_poll_init(get_uv_loop(), s->poll, s->fd); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + server_fail(s, errmsg); + return $R_CONT(c$cont, B_None); + } + uv_rc = uv_poll_start(s->poll, UV_READABLE, server_poll_cb); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + server_fail(s, errmsg); + return $R_CONT(c$cont, B_None); + } + + s->state = SERVER_STATE_LISTENING; + server_notify_listen(s, NULL); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerD_closeG_local(sshQ_Server self, $Cont c$cont) { + ssh_server_ctx *s = server_from_actor(self); + if (s == NULL) + return $R_CONT(c$cont, B_None); + server_close_internal(s, "closed"); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD__pin_affinityG_local(sshQ_ServerSession self, $Cont c$cont) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)(unsigned long)fromB_u64(self->session_id); + if (s != NULL && s->owner_wt >= 0) { + set_actor_affinity(s->owner_wt); } else { - printf("Error: %s\n", ssh_get_error(session)); + pin_actor_affinity(); + } + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD__attachG_local(sshQ_ServerSession self, $Cont c$cont, B_u64 session_id) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)(unsigned long)fromB_u64(session_id); + if (s == NULL) + return $R_CONT(c$cont, B_None); + if (ssh_debug_enabled) { + ssh_debug_log("server session attach: id=%llu", (unsigned long long)fromB_u64(session_id)); } + s->actor = self; + self->_session_id = session_id; + s->attached = 1; + s->on_auth = ($action2)self->_on_auth; + s->on_channel_open = ($action)self->_on_channel_open; + s->on_exec = (self->_on_exec == B_None) ? NULL : ($action3)self->_on_exec; + s->on_subsystem = (self->_on_subsystem == B_None) ? NULL : ($action3)self->_on_subsystem; + s->on_close = (self->_on_close == B_None) ? NULL : ($action2)self->_on_close; + s->auth_timeout = fromB_float(self->server->_auth_timeout); + s->keepalive_interval = fromB_float(self->server->_keepalive_interval); + s->keepalive_enabled = fromB_bool(self->server->_keepalive_enabled) ? 1 : 0; + if (ssh_debug_enabled) { + ssh_debug_log("server session attach: callbacks set, driving session"); + } + session_drive(s); + if (ssh_debug_enabled) { + ssh_debug_log("server session attach: session_drive returned"); + } + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_accept_authG_local(sshQ_ServerSession self, $Cont c$cont) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL || s->pending_auth == NULL) + return $R_CONT(c$cont, B_None); + + ssh_message_auth_reply_success(s->pending_auth, 0); + ssh_message_free(s->pending_auth); + s->pending_auth = NULL; + s->state = SESSION_STATE_READY; + stop_timer(&s->auth_timer, session_timer_close_cb); + session_start_keepalive(s); + session_drive(s); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_reject_authG_local(sshQ_ServerSession self, $Cont c$cont, B_str reason) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL || s->pending_auth == NULL) + return $R_CONT(c$cont, B_None); + + ssh_message_auth_set_methods(s->pending_auth, SSH_AUTH_METHOD_PASSWORD); + ssh_message_reply_default(s->pending_auth); + ssh_message_free(s->pending_auth); + s->pending_auth = NULL; + (void)reason; + session_drive(s); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_accept_channel_openG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, + $action on_data, + $action on_stderr, + $action on_close) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL || s->pending_channel_open == NULL) + return $R_CONT(c$cont, B_None); + + ssh_channel chan = ssh_message_channel_request_open_reply_accept(s->pending_channel_open); + if (chan == NULL) { + ssh_message_free(s->pending_channel_open); + s->pending_channel_open = NULL; + session_drive(s); + return $R_CONT(c$cont, B_None); + } + ssh_channel_set_blocking(chan, 0); + ssh_message_free(s->pending_channel_open); + s->pending_channel_open = NULL; + + ssh_server_channel_ctx *ch = acton_calloc(1, sizeof(ssh_server_channel_ctx)); + ch->channel = chan; + ch->session = s; + ch->actor = channel; + ch->state = SCHAN_STATE_OPEN; + ch->send_eof = 0; + ch->close_requested = 0; + ch->close_sent = 0; + ch->eof_sent = 0; + ch->stdout_eof = 0; + ch->stderr_eof = 0; + ch->pending_req = NULL; + ch->pending_req_type = SCHAN_REQ_NONE; + ch->on_data = ($action2)on_data; + ch->on_stderr = ($action2)on_stderr; + ch->on_close = ($action2)on_close; + + ch->next = s->channels; + s->channels = ch; + channel->_channel_id = toB_u64((unsigned long)ch); + + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_reject_channelG_local(sshQ_ServerSession self, $Cont c$cont, B_str reason) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL || s->pending_channel_open == NULL) + return $R_CONT(c$cont, B_None); + + ssh_message_reply_default(s->pending_channel_open); + ssh_message_free(s->pending_channel_open); + s->pending_channel_open = NULL; + (void)reason; + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_closeG_local(sshQ_ServerSession self, $Cont c$cont) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL) + return $R_CONT(c$cont, B_None); + session_close_internal(s, "closed"); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_accept_requestG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (s == NULL || ch == NULL || ch->pending_req == NULL) + return $R_CONT(c$cont, B_None); + + ssh_message_channel_request_reply_success(ch->pending_req); + ssh_message_free(ch->pending_req); + ch->pending_req = NULL; + ch->pending_req_type = SCHAN_REQ_NONE; + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_reject_requestG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_str reason) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (s == NULL || ch == NULL || ch->pending_req == NULL) + return $R_CONT(c$cont, B_None); + + ssh_message_reply_default(ch->pending_req); + ssh_message_free(ch->pending_req); + ch->pending_req = NULL; + ch->pending_req_type = SCHAN_REQ_NONE; + (void)reason; + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_writeG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_bytes data) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (s == NULL || ch == NULL) + return $R_CONT(c$cont, B_None); + server_channel_queue_write(ch, data, 0); + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_write_stderrG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_bytes data) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (s == NULL || ch == NULL) + return $R_CONT(c$cont, B_None); + server_channel_queue_write(ch, data, 1); + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_send_eofG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (s == NULL || ch == NULL) + return $R_CONT(c$cont, B_None); + ch->send_eof = 1; + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_send_exit_statusG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_int status) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (s == NULL || ch == NULL || ch->channel == NULL) + return $R_CONT(c$cont, B_None); + int rc = ssh_channel_request_send_exit_status(ch->channel, fromB_int(status)); + if (rc != SSH_OK) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server send exit status failed: %s", ssh_get_error(s->session)); + server_channel_notify_close(ch, errmsg); + } + session_drive(s); + return $R_CONT(c$cont, B_None); +} -// self->_connected = true; -// $action f = ($action) self->on_connect; -// f->$class->__asyn__(f, self); +$R sshQ_ServerSessionD_channel_closeG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (s == NULL || ch == NULL) + return $R_CONT(c$cont, B_None); + ch->send_eof = 1; + ch->close_requested = 1; + session_drive(s); return $R_CONT(c$cont, B_None); } From d61f834abe220875303b9e42f7bf2c407c5345da Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Tue, 6 Jan 2026 17:30:06 +0100 Subject: [PATCH 02/38] Fix server retry close handling in tests --- src/test_ssh_server.act | 273 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 src/test_ssh_server.act diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act new file mode 100644 index 0000000..023a4f1 --- /dev/null +++ b/src/test_ssh_server.act @@ -0,0 +1,273 @@ +import random +import testing +import logging + +import net +import ssh + + +actor ServerClientTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + NETCONF_END = b"]]>]]>" + NETCONF_HELLO = b'\n' + \ + b'\n' + \ + b' \n' + \ + b' urn:ietf:params:netconf:base:1.0\n' + \ + b' \n' + \ + b']]>]]>' + SERVER_HELLO = b'urn:ietf:params:netconf:base:1.0]]>]]>' + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_session: ?ssh.ServerSession = None + var server_buf = b"" + var client_buf = b"" + var client_channel: ?ssh.Channel = None + var server_channel: ?ssh.ServerChannel = None + var shutdown_started = False + var client_closed = False + var server_closed = False + var session_closed = False + var client_channel_closed = False + var server_channel_closed = False + var client_close_requested = False + var server_close_requested = False + var client_channel_close_requested = False + var session_close_requested = False + var shutdown_timer_started = False + + def maybe_finish(): + if done: + return + if shutdown_started and client_closed and server_closed and session_closed and \ + client_channel_closed and server_channel_closed: + log.info("test success", None) + done = True + t.success() + + def request_client_channel_close(): + if client_channel_close_requested: + return + client_channel_close_requested = True + if client_channel is not None: + client_channel.close() + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def request_session_close(): + if session_close_requested or session_closed: + return + session_close_requested = True + if server_session is not None: + log.info("request session close", None) + server_session.close() + else: + log.info("request session close: no session", None) + + def begin_shutdown(reason: str): + if shutdown_started: + return + shutdown_started = True + log.info("shutdown start: " + reason, None) + request_session_close() + request_client_channel_close() + request_client_close() + if not shutdown_timer_started: + shutdown_timer_started = True + after 10.0: on_timeout() + maybe_finish() + + def finish_error(msg: str): + if done: + return + log.info("test error: " + msg, None) + done = True + shutdown_started = True + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + log.info("timeout fired", None) + finish_error("timeout waiting for SSH test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if done: + return + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + log.info("server listening", None) + server = s + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + log.info("server session ready", None) + server_session = sess + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + log.info("auth request for " + req.user, None) + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + log.info("channel open requested", None) + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + server_channel = ch + sess.accept_channel(ch) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + log.info("subsystem request: " + name, None) + if name == "netconf": + ch.accept_request() + else: + ch.reject_request("unsupported subsystem") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + log.info("exec request: " + cmd, None) + ch.reject_request("exec disabled") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + if shutdown_started: + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server_buf += data + log.info("server data len " + str(len(data)), None) + if server_buf.find(NETCONF_END) >= 0: + log.info("server got NETCONF hello", None) + ch.write(SERVER_HELLO) + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + log.info("server channel closed: " + reason, None) + server_channel_closed = True + maybe_finish() + + def on_hostkey(client: ssh.Client, state: str, info: ssh.HostKeyInfo): + log.info("hostkey state: " + state, None) + client.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + log.info("client connect error: " + err, None) + finish_error("client connect error: " + err) + return + log.info("client connected", None) + client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if shutdown_started: + request_session_close() + maybe_finish() + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + log.info("client channel open error: " + err, None) + finish_error("client channel open error: " + err) + return + log.info("client channel open", None) + ch.request_subsystem("netconf") + ch.write(NETCONF_HELLO) + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + client_buf += data + log.info("client data len " + str(len(data)), None) + if client_buf.find(NETCONF_END) >= 0: + log.info("client got NETCONF hello", None) + begin_shutdown("client got NETCONF hello") + + def ch_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + log.info("client exit " + str(code), None) + return + + def ch_close(ch: ssh.Channel, reason: str): + log.info("client channel closed: " + reason, None) + client_channel_closed = True + if shutdown_started and client is not None: + client.close() + maybe_finish() + + def start_client(): + log.info("starting client", None) + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 42000 + (random.randint(0, 2000) + attempts * 37) % 2000 + log.info("starting server on " + str(port), None) + server_close_requested = False + server_closed = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + + +def _test_ssh_server_subsystem(t: testing.EnvT): + """Exercise server subsystem handling with a local client.""" + ServerClientTester(t) From af7cf3f10122ee2946498566e14e01e52f418df5 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Tue, 6 Jan 2026 19:39:50 +0100 Subject: [PATCH 03/38] Use libssh channel callbacks for async reads --- src/ssh.ext.c | 186 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 172 insertions(+), 14 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index f92a728..a80f545 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -64,6 +64,7 @@ typedef struct write_chunk { typedef struct ssh_channel_ctx { struct ssh_channel_ctx *next; ssh_channel channel; + struct ssh_channel_callbacks_struct *callbacks; sshQ_Channel actor; channel_state_t state; channel_request_t pending_req; @@ -163,6 +164,7 @@ typedef struct server_write_chunk { typedef struct ssh_server_channel_ctx { struct ssh_server_channel_ctx *next; ssh_channel channel; + struct ssh_channel_callbacks_struct *callbacks; struct ssh_server_session_ctx *session; sshQ_ServerChannel actor; schan_state_t state; @@ -621,6 +623,73 @@ static void channel_notify_exit(ssh_channel_ctx *ch, int exit_status, B_str sign ch->exit_sent = 1; } +static int client_channel_data_cb(ssh_session session, ssh_channel channel, void *data, + uint32_t len, int is_stderr, void *userdata) { + ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL || ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR) + return 0; + if (len == 0) + return 0; + B_bytes out = to$bytesD_len((const char *)data, (size_t)len); + if (is_stderr) { + if (ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, ch->actor, out); + } + } else { + if (ch->on_stdout) { + $action2 f = ($action2)ch->on_stdout; + f->$class->__asyn__(f, ch->actor, out); + } + } + return (int)len; +} + +static void client_channel_eof_cb(ssh_session session, ssh_channel channel, void *userdata) { + ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL) + return; + if (!ch->stdout_eof && ch->on_stdout) { + $action2 f = ($action2)ch->on_stdout; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stdout_eof = 1; + } + if (!ch->stderr_eof && ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stderr_eof = 1; + } +} + +static void client_channel_close_cb(ssh_session session, ssh_channel channel, void *userdata) { + ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL) + return; + channel_notify_close(ch, "closed"); +} + +static void client_channel_setup_callbacks(ssh_channel_ctx *ch) { + if (ch == NULL || ch->channel == NULL || ch->callbacks != NULL) + return; + struct ssh_channel_callbacks_struct *cb = acton_calloc(1, sizeof(*cb)); + ssh_callbacks_init(cb); + cb->userdata = ch; + cb->channel_data_function = client_channel_data_cb; + cb->channel_eof_function = client_channel_eof_cb; + cb->channel_close_function = client_channel_close_cb; + ch->callbacks = cb; + ssh_add_channel_callbacks(ch->channel, cb); + if (ssh_debug_enabled) { + ssh_debug_log("client channel callbacks set ch=%p", (void *)ch); + } +} + static void channel_notify_eof(ssh_channel_ctx *ch) { if (ch->channel == NULL) return; @@ -657,6 +726,10 @@ static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { (void)core_dumped; } channel_notify_exit(ch, exit_status, exit_signal); + if (ch->callbacks) { + ssh_remove_channel_callbacks(ch->channel, ch->callbacks); + ch->callbacks = NULL; + } ssh_channel_free(ch->channel); ch->channel = NULL; } else { @@ -738,7 +811,7 @@ static int channel_read_stream(ssh_client_ctx *c, ssh_channel_ctx *ch, int is_st char buf[SSH_READ_BUFSIZE]; int read_any = 0; for (;;) { - int n = ssh_channel_read_nonblocking(ch->channel, buf, sizeof(buf), is_stderr); + int n = ssh_channel_read_buffered(ch->channel, buf, sizeof(buf), is_stderr); if (ssh_debug_enabled) { ssh_debug_log("client channel read: rc=%d stderr=%d", n, is_stderr); } @@ -784,6 +857,10 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { channel_fail(c, ch, "Failed to create SSH channel"); return; } + client_channel_setup_callbacks(ch); + if (ssh_debug_enabled) { + ssh_debug_log("client channel new ch=%p callbacks=%p", (void *)ch, (void *)ch->callbacks); + } ch->state = CHAN_STATE_OPENING; } @@ -896,12 +973,14 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { } } - for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { - int did = 0; - did |= channel_read_stream(c, ch, 0); - did |= channel_read_stream(c, ch, 1); - if (!did) - break; + if (ch->callbacks == NULL) { + for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int did = 0; + did |= channel_read_stream(c, ch, 0); + did |= channel_read_stream(c, ch, 1); + if (!did) + break; + } } channel_notify_eof(ch); } @@ -1566,6 +1645,7 @@ B_NoneType sshQ__debug(B_str msg) { ssh_channel_ctx *ch = acton_calloc(1, sizeof(ssh_channel_ctx)); ch->actor = channel; + ch->callbacks = NULL; ch->state = CHAN_STATE_INIT; ch->pending_req = CHAN_REQ_NONE; ch->pty_pending = 0; @@ -1771,6 +1851,73 @@ static void server_channel_notify_close(ssh_server_channel_ctx *ch, const char * ch->close_notified = 1; } +static int server_channel_data_cb(ssh_session session, ssh_channel channel, void *data, + uint32_t len, int is_stderr, void *userdata) { + ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL || ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) + return 0; + if (len == 0) + return 0; + B_bytes out = to$bytesD_len((const char *)data, (size_t)len); + if (is_stderr) { + if (ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, ch->actor, out); + } + } else { + if (ch->on_data) { + $action2 f = ($action2)ch->on_data; + f->$class->__asyn__(f, ch->actor, out); + } + } + return (int)len; +} + +static void server_channel_eof_cb(ssh_session session, ssh_channel channel, void *userdata) { + ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL) + return; + if (!ch->stdout_eof && ch->on_data) { + $action2 f = ($action2)ch->on_data; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stdout_eof = 1; + } + if (!ch->stderr_eof && ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, ch->actor, B_None); + ch->stderr_eof = 1; + } +} + +static void server_channel_close_cb(ssh_session session, ssh_channel channel, void *userdata) { + ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL) + return; + server_channel_notify_close(ch, "closed"); +} + +static void server_channel_setup_callbacks(ssh_server_channel_ctx *ch) { + if (ch == NULL || ch->channel == NULL || ch->callbacks != NULL) + return; + struct ssh_channel_callbacks_struct *cb = acton_calloc(1, sizeof(*cb)); + ssh_callbacks_init(cb); + cb->userdata = ch; + cb->channel_data_function = server_channel_data_cb; + cb->channel_eof_function = server_channel_eof_cb; + cb->channel_close_function = server_channel_close_cb; + ch->callbacks = cb; + ssh_add_channel_callbacks(ch->channel, cb); + if (ssh_debug_enabled) { + ssh_debug_log("server channel callbacks set ch=%p", (void *)ch); + } +} + static void server_channel_notify_eof(ssh_server_channel_ctx *ch) { if (ch->channel == NULL) return; @@ -1790,6 +1937,10 @@ static void server_channel_notify_eof(ssh_server_channel_ctx *ch) { static void server_channel_finalize(ssh_server_channel_ctx *ch) { if (ch->channel != NULL) { + if (ch->callbacks) { + ssh_remove_channel_callbacks(ch->channel, ch->callbacks); + ch->callbacks = NULL; + } ssh_channel_free(ch->channel); ch->channel = NULL; } @@ -1875,7 +2026,7 @@ static int server_channel_read_stream(ssh_server_session_ctx *s, ssh_server_chan char buf[SSH_READ_BUFSIZE]; int read_any = 0; for (;;) { - int n = ssh_channel_read_nonblocking(ch->channel, buf, sizeof(buf), is_stderr); + int n = ssh_channel_read_buffered(ch->channel, buf, sizeof(buf), is_stderr); if (ssh_debug_enabled) { ssh_debug_log("server channel read: rc=%d stderr=%d", n, is_stderr); } @@ -1957,12 +2108,14 @@ static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_c } } - for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { - int did = 0; - did |= server_channel_read_stream(s, ch, 0); - did |= server_channel_read_stream(s, ch, 1); - if (!did) - break; + if (ch->callbacks == NULL) { + for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int did = 0; + did |= server_channel_read_stream(s, ch, 0); + did |= server_channel_read_stream(s, ch, 1); + if (!did) + break; + } } server_channel_notify_eof(ch); @@ -2847,6 +3000,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o ch->channel = chan; ch->session = s; ch->actor = channel; + ch->callbacks = NULL; ch->state = SCHAN_STATE_OPEN; ch->send_eof = 0; ch->close_requested = 0; @@ -2859,6 +3013,10 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o ch->on_data = ($action2)on_data; ch->on_stderr = ($action2)on_stderr; ch->on_close = ($action2)on_close; + server_channel_setup_callbacks(ch); + if (ssh_debug_enabled) { + ssh_debug_log("server channel new ch=%p callbacks=%p", (void *)ch, (void *)ch->callbacks); + } ch->next = s->channels; s->channels = ch; From 2c95dbd42c578dc89f7c673d12d0fb33a3427ada Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Tue, 6 Jan 2026 20:17:48 +0100 Subject: [PATCH 04/38] Handle buffered read pending during connect/key exchange --- src/ssh.ext.c | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index a80f545..76170c7 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -1306,6 +1306,7 @@ static void client_drive(ssh_client_ctx *c) { if (c->state == CLIENT_STATE_ERROR || c->state == CLIENT_STATE_CLOSED || c->state == CLIENT_STATE_CLOSING) return; + int spin = 0; while (1) { if (c->state == CLIENT_STATE_CONNECTING) { client_mark_writable(c); @@ -1318,8 +1319,15 @@ static void client_drive(ssh_client_ctx *c) { c->state = CLIENT_STATE_HOSTKEY; continue; } else if (rc == SSH_AGAIN) { - if (ssh_get_status(c->session) & SSH_WRITE_PENDING) + int status = ssh_get_status(c->session); + if (status & SSH_WRITE_PENDING) c->write_ready = 0; + if ((status & SSH_READ_PENDING) && spin++ < SSH_IO_PUMP_LIMIT) { + if (ssh_debug_enabled) { + ssh_debug_log("client drive: buffered read pending during connect"); + } + continue; + } client_update_poll(c); return; } else { @@ -1360,8 +1368,15 @@ static void client_drive(ssh_client_ctx *c) { client_on_ready(c); return; } else if (rc == SSH_AUTH_AGAIN) { - if (ssh_get_status(c->session) & SSH_WRITE_PENDING) + int status = ssh_get_status(c->session); + if (status & SSH_WRITE_PENDING) c->write_ready = 0; + if ((status & SSH_READ_PENDING) && spin++ < SSH_IO_PUMP_LIMIT) { + if (ssh_debug_enabled) { + ssh_debug_log("client drive: buffered read pending during auth"); + } + continue; + } client_update_poll(c); return; } else { @@ -2326,6 +2341,7 @@ static void session_drive(ssh_server_session_ctx *s) { if (!s->attached) return; + int spin = 0; while (1) { if (s->state == SESSION_STATE_KEYEX) { session_mark_writable(s); @@ -2336,8 +2352,15 @@ static void session_drive(ssh_server_session_ctx *s) { ssh_set_auth_methods(s->session, SSH_AUTH_METHOD_PASSWORD); continue; } else if (rc == SSH_AGAIN) { - if (ssh_get_status(s->session) & SSH_WRITE_PENDING) + int status = ssh_get_status(s->session); + if (status & SSH_WRITE_PENDING) s->write_ready = 0; + if ((status & SSH_READ_PENDING) && spin++ < SSH_IO_PUMP_LIMIT) { + if (ssh_debug_enabled) { + ssh_debug_log("server drive: buffered read pending during key exchange"); + } + continue; + } session_update_poll(s); return; } else { From 68a37723f7c99ba526c5207ff00d92bed3960a00 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Wed, 18 Feb 2026 21:28:29 +0100 Subject: [PATCH 05/38] Fix SSH client close flush ordering --- build.act.json | 2 +- src/ssh.act | 37 +++++++++ src/ssh.ext.c | 83 ++++++++++++++++-- src/test_ssh_server.act | 180 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 295 insertions(+), 7 deletions(-) diff --git a/build.act.json b/build.act.json index 04c7953..ac51969 100644 --- a/build.act.json +++ b/build.act.json @@ -2,7 +2,7 @@ "dependencies": {}, "zig_dependencies": { "libssh": { - "path": "../libssh", + "path": "../acton-deps/libssh", "options": { "WITH_SERVER": "true" }, diff --git a/src/ssh.act b/src/ssh.act index df1a6c4..c0af650 100644 --- a/src/ssh.act +++ b/src/ssh.act @@ -1,3 +1,40 @@ +""" +Acton SSH module + +This module exposes a callback-driven SSH client/server API built on libssh +and integrated with Acton's libuv loop. The main actors are: + +- Client: connects to a server, reports on_connect(err) and on_close(reason), + and optionally on_hostkey(state, HostKeyInfo) for strict host key checks. +- Channel: client-side data stream (stdout/stderr/exit/close callbacks). +- RunCommand: convenience wrapper that buffers stdout/stderr and reports exit. +- Server: listens for incoming connections and creates ServerSession actors. +- ServerSession: per-connection state; emits on_auth(AuthRequest), + on_channel_open, and optional on_exec/on_subsystem. +- ServerChannel: server-side data stream for a single channel. + +Usage sketch (client): + def on_connect(c, err): + if err is not None: ... + Channel(c, on_open, on_stdout, on_stderr, on_exit, on_close) + + def on_hostkey(c, state, info): + if state == HOSTKEY_OK: c.accept_hostkey() + else: c.reject_hostkey("...") + +Usage sketch (server): + def on_auth(sess, req): sess.accept_auth() + def on_channel_open(sess): sess.accept_channel(ServerChannel(...)) + def on_subsystem(sess, ch, name): ch.accept_request(); ch.write(...) + +Notes: + - All operations are async actor actions; callbacks receive errors as ?str. + - ?bytes callbacks use None to signal EOF. + - Timeouts and keepalive are configurable; host keys are generated in-memory + unless host_key_path is set. + - For debugging, set ACTON_SSH_DEBUG and ACTON_SSH_LIBSSH_LOG. +""" + import net diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 76170c7..2059ae4 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -1,3 +1,64 @@ +/* + * Acton <-> libssh integration overview + * + * This file is the external-C glue that drives libssh from Acton's libuv loop + * and exposes it to Acton actors. The core goals are: + * - Nonblocking SSH I/O integrated with libuv (no blocking syscalls). + * - Actor-safe, async callback-driven API in Acton. + * - GC-safe memory: libssh allocations use Acton's allocator. + * + * Event loop integration + * - Each libssh session is created nonblocking. + * - We attach a uv_poll watcher to the libssh socket fd. + * - On poll events we call ssh_session_handle_poll() (via + * session_apply_poll_events), then drive a small state machine + * (connect/auth/ready for client, keyex/auth/ready for server). + * - ssh_get_poll_flags()/ssh_get_status() decide which poll events to arm. + * + * Buffered data + SSH_AGAIN (why we keep driving without fd readability) + * - libssh maintains its own internal buffers. After a poll callback, libssh + * may have already read bytes into those buffers even though the socket is + * no longer readable at the OS level. + * - When a nonblocking API returns SSH_AGAIN and ssh_get_status() includes + * SSH_READ_PENDING, it means "call again, there is buffered data to + * process" even if the fd will not trigger another readable event. + * - If we only wait for uv_poll readability, we can deadlock: + * 1) uv_poll READABLE fires; ssh_session_handle_poll() drains the fd. + * 2) ssh_connect()/ssh_handle_key_exchange()/ssh_userauth_password() + * returns SSH_AGAIN. + * 3) No more kernel readability events happen, but libssh still has + * buffered protocol bytes (SSH_READ_PENDING). + * 4) We wait for an event that never comes and eventually time out. + * - The fix is to keep driving the state machine in a bounded loop while + * SSH_READ_PENDING is set, even without fd readability. + * We cap iterations with SSH_IO_PUMP_LIMIT to avoid CPU spin. + * + * Channel I/O + * - SSH channels carry two streams: "data" and "extended data". We expose + * these as stdout/stderr callbacks (client on_stdout/on_stderr, server + * on_data/on_stderr). This is protocol-level stdout/stderr, not host OS + * process stdio. + * - Channels install libssh callbacks for data/extended-data/EOF/close. + * - Inbound data always flows through these callbacks. As we drive libssh + * (via ssh_session_handle_poll), libssh invokes the registered C callback + * functions, and those callbacks call the corresponding Acton action + * methods (foo->$class->on_stdout/on_stderr/on_close, etc.). We do not run + * manual read loops; libssh owns buffering and read state. + * - Channel writes are queued and flushed when libssh reports write pending. + * + * Actor/GC/threading model + * - Client and ServerSession actors own libssh state and are pinned to a + * worker thread. Channel actors invoke action methods on their owning + * Client/ServerSession actor for all operations; there is no hidden + * cross-actor C magic. + * - We replace libssh allocators with Acton's GC allocator so libuv/GC + * roots remain visible (libssh structures can reference GC memory). + * + * Config & filesystem + * - libssh config processing is disabled by default; known_hosts is only + * read if explicitly configured by the Acton API. + * - Server host keys are generated in-memory unless a path is provided. + */ #include #include #include @@ -22,7 +83,7 @@ uv_loop_t *get_uv_loop(void); extern struct $Cont $Done$instance; #define SSH_READ_BUFSIZE 4096 -#define SSH_IO_PUMP_LIMIT 32 +#define SSH_IO_PUMP_LIMIT 128 static int ssh_debug_enabled = 0; static int ssh_libssh_log_level = SSH_LOG_NOLOG; @@ -940,7 +1001,7 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { return; } - if (ch->send_eof && !ch->eof_sent) { + if (ch->send_eof && !ch->eof_sent && ch->write_head == NULL) { client_mark_writable(c); int rc = ssh_channel_send_eof(ch->channel); if (rc == SSH_OK) { @@ -956,7 +1017,7 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { } } - if (ch->close_requested && !ch->close_sent) { + if (ch->close_requested && !ch->close_sent && ch->write_head == NULL) { client_mark_writable(c); int rc = ssh_channel_close(ch->channel); if (rc == SSH_OK) { @@ -1224,7 +1285,8 @@ static void client_update_poll(ssh_client_ctx *c) { static void client_pump_io(ssh_client_ctx *c) { if (c == NULL || c->session == NULL) return; - for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int i; + for (i = 0; i < SSH_IO_PUMP_LIMIT; i++) { int did = 0; if (c->session == NULL) return; @@ -1272,6 +1334,11 @@ static void client_pump_io(ssh_client_ctx *c) { if (!did) break; } + if (ssh_debug_enabled && i >= SSH_IO_PUMP_LIMIT) { + int status = ssh_get_status(c->session); + int flags = ssh_get_poll_flags(c->session); + ssh_debug_log("client pump: hit limit status=0x%x flags=0x%x", status, flags); + } } static int client_needs_write(ssh_client_ctx *c) { @@ -2267,7 +2334,8 @@ static void session_update_poll(ssh_server_session_ctx *s) { static void session_pump_io(ssh_server_session_ctx *s) { if (s == NULL || s->session == NULL) return; - for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int i; + for (i = 0; i < SSH_IO_PUMP_LIMIT; i++) { int did = 0; if (s->session == NULL) return; @@ -2315,6 +2383,11 @@ static void session_pump_io(ssh_server_session_ctx *s) { if (!did) break; } + if (ssh_debug_enabled && i >= SSH_IO_PUMP_LIMIT) { + int status = ssh_get_status(s->session); + int flags = ssh_get_poll_flags(s->session); + ssh_debug_log("server pump: hit limit status=0x%x flags=0x%x", status, flags); + } } static int session_needs_write(ssh_server_session_ctx *s) { diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 023a4f1..c09f834 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -26,7 +26,6 @@ actor ServerClientTester(t: testing.EnvT): var server_buf = b"" var client_buf = b"" var client_channel: ?ssh.Channel = None - var server_channel: ?ssh.ServerChannel = None var shutdown_started = False var client_closed = False var server_closed = False @@ -271,3 +270,182 @@ actor ServerClientTester(t: testing.EnvT): def _test_ssh_server_subsystem(t: testing.EnvT): """Exercise server subsystem handling with a local client.""" ServerClientTester(t) + + +actor ClientCloseFlushTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var payload_len = 1024 * 1024 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_buf = b"" + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def finish_success(): + if done: + return + done = True + log.info("test success", None) + request_client_close() + request_server_close() + t.success() + + def on_timeout(): + if done: + return + finish_error("timeout waiting for client close flush test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + log.info("server listening", None) + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + ch.reject_request("exec disabled") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name == "netconf": + ch.accept_request() + else: + ch.reject_request("unsupported subsystem") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server_buf += data + return + if len(server_buf) != payload_len: + finish_error("server received " + str(len(server_buf)) + " bytes, expected " + str(payload_len)) + return + finish_success() + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + return + + def on_hostkey(client: ssh.Client, state: str, info: ssh.HostKeyInfo): + client.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + ch.request_subsystem("netconf") + payload = b"A" * payload_len + ch.write(payload) + ch.close() + + def ch_out(ch: ssh.Channel, data: ?bytes): + return + + def ch_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + return + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 44000 + (random.randint(0, 2000) + attempts * 59) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 1.0: on_timeout() + + +def _test_ssh_client_close_flush(t: testing.EnvT): + """Client close should flush queued writes before EOF.""" + ClientCloseFlushTester(t) From 2d82af8e8908a850827912a9968046ca65beadb2 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Feb 2026 22:19:55 +0100 Subject: [PATCH 06/38] Harden server request teardown and add rejection tests --- src/ssh.ext.c | 13 ++ src/test_ssh_server.act | 334 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 347 insertions(+) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 2059ae4..cabac84 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -2018,6 +2018,12 @@ static void server_channel_notify_eof(ssh_server_channel_ctx *ch) { } static void server_channel_finalize(ssh_server_channel_ctx *ch) { + if (ch->pending_req) { + ssh_message_reply_default(ch->pending_req); + ssh_message_free(ch->pending_req); + ch->pending_req = NULL; + ch->pending_req_type = SCHAN_REQ_NONE; + } if (ch->channel != NULL) { if (ch->callbacks) { ssh_remove_channel_callbacks(ch->channel, ch->callbacks); @@ -2836,10 +2842,12 @@ static void session_close_internal(ssh_server_session_ctx *s, const char *reason s->channels = NULL; if (s->pending_auth) { + ssh_message_reply_default(s->pending_auth); ssh_message_free(s->pending_auth); s->pending_auth = NULL; } if (s->pending_channel_open) { + ssh_message_reply_default(s->pending_channel_open); ssh_message_free(s->pending_channel_open); s->pending_channel_open = NULL; } @@ -3083,8 +3091,13 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o ssh_channel chan = ssh_message_channel_request_open_reply_accept(s->pending_channel_open); if (chan == NULL) { + ssh_message_reply_default(s->pending_channel_open); ssh_message_free(s->pending_channel_open); s->pending_channel_open = NULL; + if (on_close) { + $action2 f = ($action2)on_close; + f->$class->__asyn__(f, channel, to$str((char *)"Failed to accept channel open")); + } session_drive(s); return $R_CONT(c$cont, B_None); } diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index c09f834..8f468c4 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -449,3 +449,337 @@ actor ClientCloseFlushTester(t: testing.EnvT): def _test_ssh_client_close_flush(t: testing.EnvT): """Client close should flush queued writes before EOF.""" ClientCloseFlushTester(t) + + +actor AuthRejectTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var saw_auth = False + var got_connect_error = False + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if saw_auth and got_connect_error: + done = True + request_client_close() + request_server_close() + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error("timeout waiting for auth reject test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + saw_auth = True + sess.reject_auth("invalid credentials") + maybe_finish() + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open after auth reject") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request after auth reject") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request after auth reject") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is None: + finish_error("expected auth failure but connect succeeded") + return + got_connect_error = True + maybe_finish() + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="wrong-pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 46000 + (random.randint(0, 2000) + attempts * 73) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 2.0: on_timeout() + + +def _test_ssh_auth_reject(t: testing.EnvT): + """Server reject_auth should fail client connect without hanging.""" + AuthRejectTester(t) + + +actor UnsupportedRequestTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var saw_subsystem = False + var saw_exec = False + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if saw_subsystem and saw_exec: + done = True + request_client_close() + request_server_close() + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error("timeout waiting for unsupported request test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + saw_exec = True + ch.reject_request("exec disabled") + maybe_finish() + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + saw_subsystem = True + ch.reject_request("unsupported subsystem") + maybe_finish() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + ssh.Channel(c, ch_sub_open, ch_sub_out, ch_sub_err, ch_sub_exit, ch_sub_close) + after 0.05: start_exec_channel(c) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + + def ch_sub_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("subsystem channel open error: " + err) + return + ch.request_subsystem("unsupported-subsystem") + ch.close() + + def ch_sub_out(ch: ssh.Channel, data: ?bytes): + return + + def ch_sub_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_sub_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_sub_close(ch: ssh.Channel, reason: str): + return + + def start_exec_channel(c: ssh.Client): + if done: + return + ssh.Channel(c, ch_exec_open, ch_exec_out, ch_exec_err, ch_exec_exit, ch_exec_close) + + def ch_exec_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("exec channel open error: " + err) + return + ch.request_exec("echo denied") + ch.close() + + def ch_exec_out(ch: ssh.Channel, data: ?bytes): + return + + def ch_exec_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_exec_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_exec_close(ch: ssh.Channel, reason: str): + return + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 48000 + (random.randint(0, 2000) + attempts * 97) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 2.0: on_timeout() + + +def _test_ssh_unsupported_requests(t: testing.EnvT): + """Server reject_request should handle unsupported subsystem/exec.""" + UnsupportedRequestTester(t) From 4f767d3fa09bcd743af8dc221d64a22081d9d725 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Wed, 18 Mar 2026 22:40:55 +0100 Subject: [PATCH 07/38] Honor libssh async readiness The transport could ask libssh to read or write when the socket or channel was not actually ready. We synthesized write progress, trusted stale readable notifications, and finalized channels before the peer had closed them. Under load this produced truncated close/flush runs, fatal "Resource temporarily unavailable" socket errors, and occasional wedges. This change makes channel payload writes wait for libssh's own channel_write_wontblock callback, keeps EOF and CLOSE behind session flush completion, and filters POLLIN through a fresh readiness check before handing it back to libssh. It also restores accepted and post-connect session sockets to nonblocking mode after libssh switches them back to blocking. These changes align the binding with libssh's external event-loop contract. Read and write callbacks now only run when the transport can actually make progress, and channel teardown waits until both local flush and remote close have completed. --- src/ssh.ext.c | 312 ++++++++++++++++++++++++++------------------------ 1 file changed, 164 insertions(+), 148 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index cabac84..fd45f34 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -142,6 +142,8 @@ typedef struct ssh_channel_ctx { int eof_sent; int close_requested; int close_sent; + int remote_close_seen; + int write_wontblock; int stdout_eof; int stderr_eof; int exit_sent; @@ -232,6 +234,8 @@ typedef struct ssh_server_channel_ctx { int send_eof; int close_requested; int close_sent; + int remote_close_seen; + int write_wontblock; int eof_sent; int stdout_eof; int stderr_eof; @@ -488,18 +492,6 @@ static void stop_timer(uv_timer_t **timer, uv_close_cb close_cb) { } } -static void client_mark_writable(ssh_client_ctx *c) { - if (c != NULL && c->session != NULL && c->write_ready) { - ssh_set_fd_towrite(c->session); - } -} - -static void session_mark_writable(ssh_server_session_ctx *s) { - if (s != NULL && s->session != NULL && s->write_ready) { - ssh_set_fd_towrite(s->session); - } -} - static int fd_has_data(int fd) { if (fd < 0) return 0; @@ -513,9 +505,11 @@ static int fd_has_data(int fd) { } while (rc < 0 && errno == EINTR); if (rc <= 0) return 0; - if (pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) + if (pfd.revents & POLLIN) + return 1; + if (pfd.revents & POLLNVAL) return 0; - return (pfd.revents & POLLIN) != 0; + return 0; } static int fd_set_nonblocking(int fd) { @@ -529,22 +523,22 @@ static int fd_set_nonblocking(int fd) { return 0; } -static int fd_is_writable(int fd) { - if (fd < 0) - return 0; - struct pollfd pfd; - pfd.fd = fd; - pfd.events = POLLOUT; - pfd.revents = 0; - int rc; - do { - rc = poll(&pfd, 1, 0); - } while (rc < 0 && errno == EINTR); - if (rc <= 0) - return 0; - if (pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) +static void format_session_error(ssh_session session, const char *prefix, + char *buf, size_t buflen) { + const char *err = NULL; + if (session != NULL) + err = ssh_get_error(session); + if (err != NULL && err[0] != '\0') + snprintf(buf, buflen, "%s: %s", prefix, err); + else + snprintf(buf, buflen, "%s", prefix); +} + +static int session_has_pending_write(ssh_session session) { + if (session == NULL) return 0; - return (pfd.revents & POLLOUT) != 0; + int pending = ssh_get_status(session) | ssh_get_poll_flags(session); + return (pending & SSH_WRITE_PENDING) != 0; } static void client_poll_close_cb(uv_handle_t *handle) { @@ -732,9 +726,24 @@ static void client_channel_close_cb(ssh_session session, ssh_channel channel, vo (void)channel; if (ch == NULL) return; + ch->remote_close_seen = 1; channel_notify_close(ch, "closed"); } +static int client_channel_write_wontblock_cb(ssh_session session, ssh_channel channel, + uint32_t bytes, void *userdata) { + ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL || ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR) + return 0; + ch->write_wontblock = bytes > 0 ? 1 : 0; + if (ssh_debug_enabled) { + ssh_debug_log("client channel write_wontblock: bytes=%u ch=%p", bytes, (void *)ch); + } + return 0; +} + static void client_channel_setup_callbacks(ssh_channel_ctx *ch) { if (ch == NULL || ch->channel == NULL || ch->callbacks != NULL) return; @@ -744,6 +753,7 @@ static void client_channel_setup_callbacks(ssh_channel_ctx *ch) { cb->channel_data_function = client_channel_data_cb; cb->channel_eof_function = client_channel_eof_cb; cb->channel_close_function = client_channel_close_cb; + cb->channel_write_wontblock_function = client_channel_write_wontblock_cb; ch->callbacks = cb; ssh_add_channel_callbacks(ch->channel, cb); if (ssh_debug_enabled) { @@ -838,33 +848,35 @@ static void channel_queue_write(ssh_channel_ctx *ch, B_bytes data) { } static void channel_try_write(ssh_client_ctx *c, ssh_channel_ctx *ch) { - while (ch->write_head != NULL) { + while (ch->write_head != NULL && ch->write_head->data->nbytes == ch->write_head->offset) { write_chunk_t *chunk = ch->write_head; - size_t remaining = chunk->data->nbytes - chunk->offset; - if (remaining == 0) { + ch->write_head = chunk->next; + if (ch->write_head == NULL) + ch->write_tail = NULL; + } + + if (ch->write_head == NULL || !ch->write_wontblock || session_has_pending_write(c->session)) + return; + + write_chunk_t *chunk = ch->write_head; + size_t remaining = chunk->data->nbytes - chunk->offset; + ch->write_wontblock = 0; + int rc = ssh_channel_write(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); + if (rc > 0) { + chunk->offset += (size_t)rc; + if (chunk->offset >= chunk->data->nbytes) { ch->write_head = chunk->next; if (ch->write_head == NULL) ch->write_tail = NULL; - continue; - } - client_mark_writable(c); - int rc = ssh_channel_write(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); - if (rc > 0) { - chunk->offset += (size_t)rc; - if (chunk->offset >= chunk->data->nbytes) { - ch->write_head = chunk->next; - if (ch->write_head == NULL) - ch->write_tail = NULL; - } - } else if (rc == 0 || rc == SSH_AGAIN) { - c->write_ready = 0; - return; - } else { - char errmsg[256] = {0}; - snprintf(errmsg, sizeof(errmsg), "SSH channel write error: %s", ssh_get_error(c->session)); - channel_fail(c, ch, errmsg); - return; } + } else if (rc == 0 || rc == SSH_AGAIN) { + c->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH channel write error: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; } } @@ -926,7 +938,6 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { } if (ch->state == CHAN_STATE_OPENING) { - client_mark_writable(c); int rc = ssh_channel_open_session(ch->channel); if (rc == SSH_OK) { ch->state = CHAN_STATE_OPEN; @@ -945,7 +956,6 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { if (ch->state == CHAN_STATE_OPEN || ch->state == CHAN_STATE_RUNNING) { if (ch->pty_pending && !ch->pty_done) { const char *term = ch->term ? (const char *)fromB_str(ch->term) : "xterm-256color"; - client_mark_writable(c); int rc = ssh_channel_request_pty_size(ch->channel, term, ch->cols, ch->rows); if (rc == SSH_OK) { ch->pty_done = 1; @@ -963,13 +973,10 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { if (ch->pending_req != CHAN_REQ_NONE) { int rc = SSH_ERROR; if (ch->pending_req == CHAN_REQ_SHELL) { - client_mark_writable(c); rc = ssh_channel_request_shell(ch->channel); } else if (ch->pending_req == CHAN_REQ_EXEC) { - client_mark_writable(c); rc = ssh_channel_request_exec(ch->channel, (const char *)fromB_str(ch->exec_cmd)); } else if (ch->pending_req == CHAN_REQ_SUBSYSTEM) { - client_mark_writable(c); rc = ssh_channel_request_subsystem(ch->channel, (const char *)fromB_str(ch->subsystem)); } @@ -1001,8 +1008,8 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { return; } - if (ch->send_eof && !ch->eof_sent && ch->write_head == NULL) { - client_mark_writable(c); + if (ch->send_eof && !ch->eof_sent && ch->write_head == NULL && + !session_has_pending_write(c->session)) { int rc = ssh_channel_send_eof(ch->channel); if (rc == SSH_OK) { ch->eof_sent = 1; @@ -1017,8 +1024,9 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { } } - if (ch->close_requested && !ch->close_sent && ch->write_head == NULL) { - client_mark_writable(c); + if (ch->close_requested && !ch->close_sent && ch->write_head == NULL && + (!ch->send_eof || ch->eof_sent) && + !session_has_pending_write(c->session)) { int rc = ssh_channel_close(ch->channel); if (rc == SSH_OK) { ch->close_sent = 1; @@ -1046,7 +1054,8 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { channel_notify_eof(ch); } - if (ch->channel != NULL && ssh_channel_is_closed(ch->channel)) { + if (ch->channel != NULL && ch->remote_close_seen && + ssh_channel_is_closed(ch->channel)) { channel_finalize(c, ch); } } @@ -1226,18 +1235,25 @@ static void poll_cb(uv_poll_t *handle, int status, int events) { client_fail(c, errmsg); return; } - if (events & UV_READABLE) + int libssh_events = 0; + if ((events & UV_READABLE) && fd_has_data(c->fd)) { ssh_set_fd_toread(c->session); + libssh_events |= UV_READABLE; + } if (events & UV_WRITABLE) { c->write_ready = 1; ssh_set_fd_towrite(c->session); + libssh_events |= UV_WRITABLE; } - if (session_apply_poll_events(c->session, events) != 0) { - client_fail(c, "SSH poll callback error"); + if (session_apply_poll_events(c->session, libssh_events) != 0) { + char errmsg[256] = {0}; + format_session_error(c->session, "SSH poll callback error", errmsg, sizeof(errmsg)); + client_fail(c, errmsg); return; } client_drive(c); client_pump_io(c); + c->write_ready = 0; } static void client_update_poll(ssh_client_ctx *c) { @@ -1263,9 +1279,6 @@ static void client_update_poll(ssh_client_ctx *c) { events |= UV_WRITABLE; if ((events & UV_WRITABLE) == 0 && client_needs_write(c)) events |= UV_WRITABLE; -#ifdef UV_DISCONNECT - events |= UV_DISCONNECT; -#endif if (ssh_debug_enabled && (events != c->poll_events || (pending & SSH_WRITE_PENDING))) { ssh_debug_log("client update poll: status=0x%x flags=0x%x pending=0x%x events=0x%x state=%d", status, flags, pending, events, c->state); @@ -1297,7 +1310,9 @@ static void client_pump_io(ssh_client_ctx *c) { } ssh_set_fd_toread(c->session); if (session_apply_poll_events(c->session, UV_READABLE) != 0) { - client_fail(c, "SSH poll callback error"); + char errmsg[256] = {0}; + format_session_error(c->session, "SSH poll callback error", errmsg, sizeof(errmsg)); + client_fail(c, errmsg); return; } client_drive(c); @@ -1305,21 +1320,6 @@ static void client_pump_io(ssh_client_ctx *c) { } if (c->session == NULL) return; - int pending = ssh_get_status(c->session) | ssh_get_poll_flags(c->session); - int can_write = (pending & SSH_WRITE_PENDING) && fd_is_writable(c->fd); - if (can_write) { - if (ssh_debug_enabled) { - ssh_debug_log("client pump: writable pending=0x%x", pending); - } - c->write_ready = 1; - ssh_set_fd_towrite(c->session); - if (session_apply_poll_events(c->session, UV_WRITABLE) != 0) { - client_fail(c, "SSH poll callback error"); - return; - } - client_drive(c); - did = 1; - } if (!did) { int status = ssh_get_status(c->session); if (status & SSH_READ_PENDING) { @@ -1376,13 +1376,16 @@ static void client_drive(ssh_client_ctx *c) { int spin = 0; while (1) { if (c->state == CLIENT_STATE_CONNECTING) { - client_mark_writable(c); int rc = ssh_connect(c->session); if (ssh_debug_enabled) { log_debug("ssh_connect rc=%d state=%d", rc, c->state); } if (rc == SSH_OK) { stop_timer(&c->connect_timer, client_timer_close_cb); + if (fd_set_nonblocking(c->fd) != 0) { + client_fail(c, "Failed to restore SSH session fd nonblocking"); + return; + } c->state = CLIENT_STATE_HOSTKEY; continue; } else if (rc == SSH_AGAIN) { @@ -1429,7 +1432,6 @@ static void client_drive(ssh_client_ctx *c) { client_fail(c, "Password auth requested but no password provided"); return; } - client_mark_writable(c); int rc = ssh_userauth_password(c->session, NULL, (const char *)fromB_str(c->actor->password)); if (rc == SSH_AUTH_SUCCESS) { client_on_ready(c); @@ -1649,6 +1651,10 @@ B_NoneType sshQ__debug(B_str msg) { client_fail(c, "Failed to get SSH session fd"); return $R_CONT(c$cont, B_None); } + if (c->state != CLIENT_STATE_CONNECTING && fd_set_nonblocking(c->fd) != 0) { + client_fail(c, "Failed to set SSH session fd nonblocking"); + return $R_CONT(c$cont, B_None); + } c->poll = acton_calloc(1, sizeof(uv_poll_t)); c->poll->data = c; @@ -1981,9 +1987,24 @@ static void server_channel_close_cb(ssh_session session, ssh_channel channel, vo (void)channel; if (ch == NULL) return; + ch->remote_close_seen = 1; server_channel_notify_close(ch, "closed"); } +static int server_channel_write_wontblock_cb(ssh_session session, ssh_channel channel, + uint32_t bytes, void *userdata) { + ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL || ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) + return 0; + ch->write_wontblock = bytes > 0 ? 1 : 0; + if (ssh_debug_enabled) { + ssh_debug_log("server channel write_wontblock: bytes=%u ch=%p", bytes, (void *)ch); + } + return 0; +} + static void server_channel_setup_callbacks(ssh_server_channel_ctx *ch) { if (ch == NULL || ch->channel == NULL || ch->callbacks != NULL) return; @@ -1993,6 +2014,7 @@ static void server_channel_setup_callbacks(ssh_server_channel_ctx *ch) { cb->channel_data_function = server_channel_data_cb; cb->channel_eof_function = server_channel_eof_cb; cb->channel_close_function = server_channel_close_cb; + cb->channel_write_wontblock_function = server_channel_write_wontblock_cb; ch->callbacks = cb; ssh_add_channel_callbacks(ch->channel, cb); if (ssh_debug_enabled) { @@ -2066,47 +2088,49 @@ static void server_channel_queue_write(ssh_server_channel_ctx *ch, B_bytes data, } static void server_channel_try_write(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) { - while (ch->write_head != NULL) { + while (ch->write_head != NULL && ch->write_head->data->nbytes == ch->write_head->offset) { server_write_chunk_t *chunk = ch->write_head; - size_t remaining = chunk->data->nbytes - chunk->offset; - if (remaining == 0) { + ch->write_head = chunk->next; + if (ch->write_head == NULL) + ch->write_tail = NULL; + } + + if (ch->write_head == NULL || !ch->write_wontblock || session_has_pending_write(s->session)) + return; + + server_write_chunk_t *chunk = ch->write_head; + size_t remaining = chunk->data->nbytes - chunk->offset; + ch->write_wontblock = 0; + int rc; + if (chunk->is_stderr) { + rc = ssh_channel_write_stderr(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); + } else { + rc = ssh_channel_write(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); + } + if (rc > 0) { + chunk->offset += (size_t)rc; + if (ssh_debug_enabled) { + ssh_debug_log("server channel wrote %d bytes", rc); + } + if (chunk->offset >= chunk->data->nbytes) { ch->write_head = chunk->next; if (ch->write_head == NULL) ch->write_tail = NULL; - continue; } - session_mark_writable(s); - int rc; - if (chunk->is_stderr) { - rc = ssh_channel_write_stderr(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); - } else { - rc = ssh_channel_write(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); - } - if (rc > 0) { - chunk->offset += (size_t)rc; - if (ssh_debug_enabled) { - ssh_debug_log("server channel wrote %d bytes", rc); - } - if (chunk->offset >= chunk->data->nbytes) { - ch->write_head = chunk->next; - if (ch->write_head == NULL) - ch->write_tail = NULL; - } - } else if (rc == 0 || rc == SSH_AGAIN) { - s->write_ready = 0; - if (ssh_debug_enabled) { - unsigned int window = ssh_channel_window_size(ch->channel); - ssh_debug_log("server channel write pending rc=%d remaining=%zu window=%u", - rc, remaining, window); - } - return; - } else { - char errmsg[256] = {0}; - snprintf(errmsg, sizeof(errmsg), "SSH server channel write error: %s", ssh_get_error(s->session)); - server_channel_notify_close(ch, errmsg); - ch->state = SCHAN_STATE_ERROR; - return; + } else if (rc == 0 || rc == SSH_AGAIN) { + s->write_ready = 0; + if (ssh_debug_enabled) { + unsigned int window = ssh_channel_window_size(ch->channel); + ssh_debug_log("server channel write pending rc=%d remaining=%zu window=%u", + rc, remaining, window); } + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server channel write error: %s", ssh_get_error(s->session)); + server_channel_notify_close(ch, errmsg); + ch->state = SCHAN_STATE_ERROR; + return; } } @@ -2161,8 +2185,8 @@ static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_c return; } - if (ch->send_eof && !ch->eof_sent && ch->write_head == NULL) { - session_mark_writable(s); + if (ch->send_eof && !ch->eof_sent && ch->write_head == NULL && + !session_has_pending_write(s->session)) { int rc = ssh_channel_send_eof(ch->channel); if (rc == SSH_OK) { ch->eof_sent = 1; @@ -2178,8 +2202,9 @@ static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_c } } - if (ch->close_requested && !ch->close_sent && ch->write_head == NULL) { - session_mark_writable(s); + if (ch->close_requested && !ch->close_sent && ch->write_head == NULL && + (!ch->send_eof || ch->eof_sent) && + !session_has_pending_write(s->session)) { int rc = ssh_channel_close(ch->channel); if (rc == SSH_OK) { ch->close_sent = 1; @@ -2207,7 +2232,8 @@ static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_c } server_channel_notify_eof(ch); - if (ch->channel != NULL && ssh_channel_is_closed(ch->channel)) { + if (ch->channel != NULL && ch->remote_close_seen && + ssh_channel_is_closed(ch->channel)) { server_channel_finalize(ch); } } @@ -2318,9 +2344,6 @@ static void session_update_poll(ssh_server_session_ctx *s) { events |= UV_WRITABLE; if ((events & UV_WRITABLE) == 0 && session_needs_write(s)) events |= UV_WRITABLE; -#ifdef UV_DISCONNECT - events |= UV_DISCONNECT; -#endif if (ssh_debug_enabled && (events != s->poll_events || (pending & SSH_WRITE_PENDING))) { ssh_debug_log("server update poll: status=0x%x flags=0x%x pending=0x%x events=0x%x state=%d", status, flags, pending, events, s->state); @@ -2352,7 +2375,9 @@ static void session_pump_io(ssh_server_session_ctx *s) { } ssh_set_fd_toread(s->session); if (session_apply_poll_events(s->session, UV_READABLE) != 0) { - session_fail(s, "SSH poll callback error"); + char errmsg[256] = {0}; + format_session_error(s->session, "SSH poll callback error", errmsg, sizeof(errmsg)); + session_fail(s, errmsg); return; } session_drive(s); @@ -2360,21 +2385,6 @@ static void session_pump_io(ssh_server_session_ctx *s) { } if (s->session == NULL) return; - int pending = ssh_get_status(s->session) | ssh_get_poll_flags(s->session); - int can_write = (pending & SSH_WRITE_PENDING) && fd_is_writable(s->fd); - if (can_write) { - if (ssh_debug_enabled) { - ssh_debug_log("server pump: writable pending=0x%x", pending); - } - s->write_ready = 1; - ssh_set_fd_towrite(s->session); - if (session_apply_poll_events(s->session, UV_WRITABLE) != 0) { - session_fail(s, "SSH poll callback error"); - return; - } - session_drive(s); - did = 1; - } if (!did) { int status = ssh_get_status(s->session); if (status & SSH_READ_PENDING) { @@ -2423,7 +2433,6 @@ static void session_drive(ssh_server_session_ctx *s) { int spin = 0; while (1) { if (s->state == SESSION_STATE_KEYEX) { - session_mark_writable(s); int rc = ssh_handle_key_exchange(s->session); if (rc == SSH_OK) { s->state = SESSION_STATE_AUTH; @@ -2600,18 +2609,25 @@ static void session_poll_cb(uv_poll_t *handle, int status, int events) { session_fail(s, errmsg); return; } - if (events & UV_READABLE) + int libssh_events = 0; + if ((events & UV_READABLE) && fd_has_data(s->fd)) { ssh_set_fd_toread(s->session); + libssh_events |= UV_READABLE; + } if (events & UV_WRITABLE) { s->write_ready = 1; ssh_set_fd_towrite(s->session); + libssh_events |= UV_WRITABLE; } - if (session_apply_poll_events(s->session, events) != 0) { - session_fail(s, "SSH poll callback error"); + if (session_apply_poll_events(s->session, libssh_events) != 0) { + char errmsg[256] = {0}; + format_session_error(s->session, "SSH poll callback error", errmsg, sizeof(errmsg)); + session_fail(s, errmsg); return; } session_drive(s); session_pump_io(s); + s->write_ready = 0; } static void server_poll_cb(uv_poll_t *handle, int status, int events) { From 5787fb93dc0b0223212e081f22aea995a7224945 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Wed, 18 Mar 2026 22:48:27 +0100 Subject: [PATCH 08/38] Filter stale writable events A rare close/flush stress failure still remained after the main transport hardening. Under four concurrent workers the client could occasionally truncate the payload around 487000 bytes, which matched the earlier write-side failure signature. The remaining hole was that we still trusted libuv's UV_WRITABLE notification directly. libssh treats a write-side EAGAIN as a fatal socket error once we tell it the fd will not block, so a stale or already-consumed writable event could still tear the session down. This change confirms POLLOUT with a fresh poll(0) check before advertising writability to libssh on both client and server session poll callbacks. That keeps the binding symmetric with the existing read-side hardening and only hands libssh readiness that the kernel still agrees with. --- src/ssh.ext.c | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index fd45f34..feb076d 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -512,6 +512,26 @@ static int fd_has_data(int fd) { return 0; } +static int fd_can_write(int fd) { + if (fd < 0) + return 0; + struct pollfd pfd; + pfd.fd = fd; + pfd.events = POLLOUT; + pfd.revents = 0; + int rc; + do { + rc = poll(&pfd, 1, 0); + } while (rc < 0 && errno == EINTR); + if (rc <= 0) + return 0; + if (pfd.revents & POLLOUT) + return 1; + if (pfd.revents & POLLNVAL) + return 0; + return 0; +} + static int fd_set_nonblocking(int fd) { if (fd < 0) return -1; @@ -1240,7 +1260,7 @@ static void poll_cb(uv_poll_t *handle, int status, int events) { ssh_set_fd_toread(c->session); libssh_events |= UV_READABLE; } - if (events & UV_WRITABLE) { + if ((events & UV_WRITABLE) && fd_can_write(c->fd)) { c->write_ready = 1; ssh_set_fd_towrite(c->session); libssh_events |= UV_WRITABLE; @@ -2614,7 +2634,7 @@ static void session_poll_cb(uv_poll_t *handle, int status, int events) { ssh_set_fd_toread(s->session); libssh_events |= UV_READABLE; } - if (events & UV_WRITABLE) { + if ((events & UV_WRITABLE) && fd_can_write(s->fd)) { s->write_ready = 1; ssh_set_fd_towrite(s->session); libssh_events |= UV_WRITABLE; From 94beba71a64475222b21ac5912df69e5fe8a312a Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Wed, 18 Mar 2026 22:48:32 +0100 Subject: [PATCH 09/38] Raise close flush timeout The close/flush stress gate still had a brittle failure mode after the transport fixes. Each iteration spins up a server, completes an SSH handshake and auth exchange, issues a subsystem request, transfers a 1 MiB payload, and waits for channel and session teardown. Under higher worker counts that full wall-clock path can exceed the old one-second budget without the transport being stuck. This change increases the test timeout to two seconds. That keeps the budget tight enough to catch real stalls while removing false stress failures caused by scheduler and handshake tail latency rather than transport corruption. --- src/test_ssh_server.act | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 8f468c4..3bd48bc 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -443,7 +443,7 @@ actor ClientCloseFlushTester(t: testing.EnvT): ) after 0: start_server() - after 1.0: on_timeout() + after 2.0: on_timeout() def _test_ssh_client_close_flush(t: testing.EnvT): From 99e476d4ddcb20a34d711fc0de3c8a281b5dc8b3 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Wed, 18 Mar 2026 23:04:16 +0100 Subject: [PATCH 10/38] Raise close flush timeout The close and flush stress gate is intended to catch truncated transfers and teardown races, not to enforce a tight wall-clock budget. After the transport fixes it could still exceed the shorter timeout under the default stress worker count because each iteration pays for server startup, SSH handshake, auth, a subsystem request, a 1 MiB transfer, and full channel/session shutdown. This change raises the timeout from two to five seconds. That keeps the test focused on correctness while still failing genuinely stuck runs, and it removes false stress failures caused by normal concurrency tail latency. --- src/test_ssh_server.act | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 3bd48bc..176c288 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -443,7 +443,7 @@ actor ClientCloseFlushTester(t: testing.EnvT): ) after 0: start_server() - after 2.0: on_timeout() + after 5.0: on_timeout() def _test_ssh_client_close_flush(t: testing.EnvT): From 6d2d091f6d1d5f395f2ca77449799b6d1fd62414 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 00:01:15 +0100 Subject: [PATCH 11/38] Drain parent close writes A normal Client.close() or ServerSession.close() could tear down every child channel immediately, even when channels still had queued data or libssh still had buffered session writes. Under load that dropped in-flight payloads and turned orderly shutdown into spurious session errors. This change separates graceful parent shutdown from forced error cleanup. Normal client and session close now mark child channels for EOF and CLOSE, keep polling until channel queues and libssh pending writes drain, and only finalize the parent transport once the child channels are gone. Error paths still escalate to immediate teardown. This is correct because parent shutdown now follows the same write-drain contract as individual channels instead of bypassing it at a higher level. --- src/ssh.ext.c | 263 +++++++++++++++++++------- src/test_ssh_server.act | 402 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 594 insertions(+), 71 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index feb076d..41c794a 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -177,6 +177,7 @@ typedef struct ssh_client_ctx { int connected_ok; int close_notified; int close_finalized; + int close_force; int write_ready; char *close_reason; enum ssh_known_hosts_e hostkey_state; @@ -268,6 +269,7 @@ typedef struct ssh_server_session_ctx { int write_ready; int close_notified; int close_finalized; + int close_force; char *close_reason; ssh_message pending_auth; ssh_message pending_channel_open; @@ -298,8 +300,9 @@ typedef struct ssh_server_ctx { static void client_drive(ssh_client_ctx *c); static void client_update_poll(ssh_client_ctx *c); -static void client_close_internal(ssh_client_ctx *c, const char *reason); +static void client_close_internal(ssh_client_ctx *c, const char *reason, int force_close); static void client_finalize(ssh_client_ctx *c); +static void client_finish_close(ssh_client_ctx *c); static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch); static int client_needs_write(ssh_client_ctx *c); static void client_pump_io(ssh_client_ctx *c); @@ -311,8 +314,9 @@ static void server_remove_session(ssh_server_ctx *s, ssh_server_session_ctx *ses static void session_drive(ssh_server_session_ctx *s); static void session_update_poll(ssh_server_session_ctx *s); static void session_pump_io(ssh_server_session_ctx *s); -static void session_close_internal(ssh_server_session_ctx *s, const char *reason); +static void session_close_internal(ssh_server_session_ctx *s, const char *reason, int force_close); static void session_finalize(ssh_server_session_ctx *s); +static void session_finish_close(ssh_server_session_ctx *s); static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch); static int session_needs_write(ssh_server_session_ctx *s); static void session_poll_cb(uv_poll_t *handle, int status, int events); @@ -652,7 +656,7 @@ static void client_fail(ssh_client_ctx *c, const char *msg) { c->state = CLIENT_STATE_ERROR; if (!c->connected_ok) client_notify_connect(c, msg); - client_close_internal(c, msg); + client_close_internal(c, msg, 1); } static void channel_notify_open(ssh_channel_ctx *ch, const char *err) { @@ -1279,7 +1283,7 @@ static void poll_cb(uv_poll_t *handle, int status, int events) { static void client_update_poll(ssh_client_ctx *c) { if (c->poll == NULL || c->session == NULL) return; - if (c->state == CLIENT_STATE_CLOSING || c->state == CLIENT_STATE_CLOSED) + if (c->state == CLIENT_STATE_CLOSED) return; if (uv_is_closing((uv_handle_t *)c->poll)) return; @@ -1289,7 +1293,7 @@ static void client_update_poll(ssh_client_ctx *c) { return; } if (status & SSH_CLOSED) { - client_close_internal(c, "SSH session closed"); + client_close_internal(c, "SSH session closed", 1); return; } int flags = ssh_get_poll_flags(c->session); @@ -1390,8 +1394,13 @@ static void client_on_ready(ssh_client_ctx *c) { static void client_drive(ssh_client_ctx *c) { if (c == NULL) return; - if (c->state == CLIENT_STATE_ERROR || c->state == CLIENT_STATE_CLOSED || c->state == CLIENT_STATE_CLOSING) + if (c->state == CLIENT_STATE_ERROR || c->state == CLIENT_STATE_CLOSED) return; + if (c->state == CLIENT_STATE_CLOSING) { + client_drive_channels(c); + client_finish_close(c); + return; + } int spin = 0; while (1) { @@ -1503,28 +1512,7 @@ static void client_finalize(ssh_client_ctx *c) { c->actor->_client = toB_u64(0); } -static void client_close_internal(ssh_client_ctx *c, const char *reason) { - if (c == NULL || c->state == CLIENT_STATE_CLOSED || c->state == CLIENT_STATE_CLOSING) - return; - - if (!c->connected_ok && !c->connect_notified) { - client_notify_connect(c, reason ? reason : "closed"); - } - - int notify_channel_error = (c->state == CLIENT_STATE_ERROR); - c->state = CLIENT_STATE_CLOSING; - if (reason != NULL && c->close_reason == NULL) - c->close_reason = acton_strdup(reason); - - stop_timer(&c->connect_timer, client_timer_close_cb); - stop_timer(&c->auth_timer, client_timer_close_cb); - stop_timer(&c->keepalive_timer, client_timer_close_cb); - - if (c->poll != NULL) { - close_poll(&c->poll, client_poll_close_cb); - c->poll_events = 0; - } - +static void client_abort_channels(ssh_client_ctx *c, int notify_channel_error) { ssh_channel_ctx *ch = c->channels; while (ch != NULL) { ssh_channel_ctx *next = ch->next; @@ -1536,13 +1524,84 @@ static void client_close_internal(ssh_client_ctx *c, const char *reason) { ch = next; } c->channels = NULL; +} - if (c->poll != NULL) - return; +static void client_request_channel_close(ssh_client_ctx *c) { + ssh_channel_ctx *ch = c->channels; + while (ch != NULL) { + if (ch->state != CHAN_STATE_CLOSED && ch->state != CHAN_STATE_ERROR) { + ch->send_eof = 1; + ch->close_requested = 1; + } + ch = ch->next; + } +} +static void client_finish_close(ssh_client_ctx *c) { + if (c == NULL || c->state != CLIENT_STATE_CLOSING) + return; + if (c->close_force) { + if (c->poll != NULL) { + close_poll(&c->poll, client_poll_close_cb); + c->poll_events = 0; + return; + } + client_finalize(c); + return; + } + if (c->channels != NULL) { + client_update_poll(c); + return; + } + if (c->session != NULL && session_has_pending_write(c->session)) { + client_update_poll(c); + return; + } + if (c->poll != NULL) { + close_poll(&c->poll, client_poll_close_cb); + c->poll_events = 0; + return; + } client_finalize(c); } +static void client_close_internal(ssh_client_ctx *c, const char *reason, int force_close) { + if (c == NULL || c->state == CLIENT_STATE_CLOSED) + return; + if (!force_close && c->state != CLIENT_STATE_READY) + force_close = 1; + + if (!c->connected_ok && !c->connect_notified) { + client_notify_connect(c, reason ? reason : "closed"); + } + if (reason != NULL && c->close_reason == NULL) + c->close_reason = acton_strdup(reason); + + if (c->state == CLIENT_STATE_CLOSING) { + if (force_close && !c->close_force) { + c->close_force = 1; + client_abort_channels(c, 1); + } + client_finish_close(c); + return; + } + + stop_timer(&c->connect_timer, client_timer_close_cb); + stop_timer(&c->auth_timer, client_timer_close_cb); + stop_timer(&c->keepalive_timer, client_timer_close_cb); + + c->state = CLIENT_STATE_CLOSING; + c->close_force = force_close; + if (force_close) { + client_abort_channels(c, 1); + client_finish_close(c); + return; + } + + client_request_channel_close(c); + client_drive(c); +} + void sshQ___ext_init__() { const char *dbg_env = getenv("ACTON_SSH_DEBUG"); const char *log_env = getenv("ACTON_SSH_LIBSSH_LOG"); @@ -1732,7 +1791,7 @@ B_NoneType sshQ__debug(B_str msg) { ssh_client_ctx *c = client_from_actor(self); if (c == NULL) return $R_CONT(c$cont, B_None); - client_close_internal(c, "closed"); + client_close_internal(c, "closed", 0); return $R_CONT(c$cont, B_None); } @@ -2282,7 +2341,7 @@ static void session_fail(ssh_server_session_ctx *s, const char *msg) { if (s == NULL || s->state == SESSION_STATE_CLOSED || s->state == SESSION_STATE_ERROR) return; s->state = SESSION_STATE_ERROR; - session_close_internal(s, msg); + session_close_internal(s, msg, 1); } static void session_start_auth_timer(ssh_server_session_ctx *s) { @@ -2344,7 +2403,7 @@ static void session_keepalive_cb(uv_timer_t *timer) { static void session_update_poll(ssh_server_session_ctx *s) { if (s->poll == NULL || s->session == NULL) return; - if (s->state == SESSION_STATE_CLOSING || s->state == SESSION_STATE_CLOSED) + if (s->state == SESSION_STATE_CLOSED) return; if (uv_is_closing((uv_handle_t *)s->poll)) return; @@ -2354,7 +2413,7 @@ static void session_update_poll(ssh_server_session_ctx *s) { return; } if (status & SSH_CLOSED) { - session_close_internal(s, "SSH session closed"); + session_close_internal(s, "SSH session closed", 1); return; } int flags = ssh_get_poll_flags(s->session); @@ -2445,10 +2504,15 @@ static int session_needs_write(ssh_server_session_ctx *s) { static void session_drive(ssh_server_session_ctx *s) { if (s == NULL) return; - if (s->state == SESSION_STATE_ERROR || s->state == SESSION_STATE_CLOSED || s->state == SESSION_STATE_CLOSING) + if (s->state == SESSION_STATE_ERROR || s->state == SESSION_STATE_CLOSED) return; if (!s->attached) return; + if (s->state == SESSION_STATE_CLOSING) { + session_drive_channels(s); + session_finish_close(s); + return; + } int spin = 0; while (1) { @@ -2821,6 +2885,76 @@ static void session_finalize(ssh_server_session_ctx *s) { s->actor->_session_id = toB_u64(0); } +static void session_reject_pending_messages(ssh_server_session_ctx *s) { + if (s->pending_auth) { + ssh_message_reply_default(s->pending_auth); + ssh_message_free(s->pending_auth); + s->pending_auth = NULL; + } + if (s->pending_channel_open) { + ssh_message_reply_default(s->pending_channel_open); + ssh_message_free(s->pending_channel_open); + s->pending_channel_open = NULL; + } +} + +static void session_abort_channels(ssh_server_session_ctx *s) { + ssh_server_channel_ctx *ch = s->channels; + while (ch != NULL) { + ssh_server_channel_ctx *next = ch->next; + server_channel_notify_close(ch, "Session closed"); + server_channel_finalize(ch); + ch->next = NULL; + ch = next; + } + s->channels = NULL; +} + +static void session_request_channel_close(ssh_server_session_ctx *s) { + ssh_server_channel_ctx *ch = s->channels; + while (ch != NULL) { + if (ch->pending_req) { + ssh_message_reply_default(ch->pending_req); + ssh_message_free(ch->pending_req); + ch->pending_req = NULL; + ch->pending_req_type = SCHAN_REQ_NONE; + } + if (ch->state != SCHAN_STATE_CLOSED && ch->state != SCHAN_STATE_ERROR) { + ch->send_eof = 1; + ch->close_requested = 1; + } + ch = ch->next; + } +} + +static void session_finish_close(ssh_server_session_ctx *s) { + if (s == NULL || s->state != SESSION_STATE_CLOSING) + return; + if (s->close_force) { + if (s->poll != NULL) { + close_poll(&s->poll, session_poll_close_cb); + s->poll_events = 0; + return; + } + session_finalize(s); + return; + } + if (s->channels != NULL) { + session_update_poll(s); + return; + } + if (s->session != NULL && session_has_pending_write(s->session)) { + session_update_poll(s); + return; + } + if (s->poll != NULL) { + close_poll(&s->poll, session_poll_close_cb); + s->poll_events = 0; + return; + } + session_finalize(s); +} + static void server_close_internal(ssh_server_ctx *s, const char *reason) { if (s == NULL || s->state == SERVER_STATE_CLOSED || s->state == SERVER_STATE_CLOSING) return; @@ -2840,10 +2974,9 @@ static void server_close_internal(ssh_server_ctx *s, const char *reason) { ssh_server_session_ctx *sess = s->sessions; while (sess != NULL) { ssh_server_session_ctx *next = sess->next; - session_close_internal(sess, "Server closed"); + session_close_internal(sess, "Server closed", s->state == SERVER_STATE_ERROR); sess = next; } - s->sessions = NULL; if (s->poll != NULL) return; @@ -2851,47 +2984,37 @@ static void server_close_internal(ssh_server_ctx *s, const char *reason) { server_finalize(s); } -static void session_close_internal(ssh_server_session_ctx *s, const char *reason) { - if (s == NULL || s->state == SESSION_STATE_CLOSED || s->state == SESSION_STATE_CLOSING) +static void session_close_internal(ssh_server_session_ctx *s, const char *reason, int force_close) { + if (s == NULL || s->state == SESSION_STATE_CLOSED) return; - - s->state = SESSION_STATE_CLOSING; + if (!force_close && s->state != SESSION_STATE_READY) + force_close = 1; if (reason != NULL && s->close_reason == NULL) s->close_reason = acton_strdup(reason); + if (s->state == SESSION_STATE_CLOSING) { + if (force_close && !s->close_force) { + s->close_force = 1; + session_reject_pending_messages(s); + session_abort_channels(s); + } + session_finish_close(s); + return; + } stop_timer(&s->auth_timer, session_timer_close_cb); stop_timer(&s->keepalive_timer, session_timer_close_cb); - if (s->poll != NULL) { - close_poll(&s->poll, session_poll_close_cb); - s->poll_events = 0; - } - - ssh_server_channel_ctx *ch = s->channels; - while (ch != NULL) { - ssh_server_channel_ctx *next = ch->next; - server_channel_notify_close(ch, "Session closed"); - server_channel_finalize(ch); - ch->next = NULL; - ch = next; - } - s->channels = NULL; - - if (s->pending_auth) { - ssh_message_reply_default(s->pending_auth); - ssh_message_free(s->pending_auth); - s->pending_auth = NULL; - } - if (s->pending_channel_open) { - ssh_message_reply_default(s->pending_channel_open); - ssh_message_free(s->pending_channel_open); - s->pending_channel_open = NULL; - } - - if (s->poll != NULL) + s->state = SESSION_STATE_CLOSING; + s->close_force = force_close; + session_reject_pending_messages(s); + if (force_close) { + session_abort_channels(s); + session_finish_close(s); return; + } - session_finalize(s); + session_request_channel_close(s); + session_drive(s); } static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_out) { @@ -3188,7 +3311,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o ssh_server_session_ctx *s = session_from_actor(self); if (s == NULL) return $R_CONT(c$cont, B_None); - session_close_internal(s, "closed"); + session_close_internal(s, "closed", 0); return $R_CONT(c$cont, B_None); } diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 176c288..b38f555 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -443,7 +443,7 @@ actor ClientCloseFlushTester(t: testing.EnvT): ) after 0: start_server() - after 5.0: on_timeout() + after 10.0: on_timeout() def _test_ssh_client_close_flush(t: testing.EnvT): @@ -451,6 +451,202 @@ def _test_ssh_client_close_flush(t: testing.EnvT): ClientCloseFlushTester(t) +actor ClientSessionCloseFlushTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var payload_len = 1024 * 1024 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_buf = b"" + var client_close_requested = False + var server_close_requested = False + var saw_terminal_event = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def finish_success(): + if done: + return + done = True + log.info("test success", None) + request_client_close() + request_server_close() + t.success() + + def check_server_bytes(reason: str): + if done or not saw_terminal_event: + return + if len(server_buf) != payload_len: + finish_error(reason + ": server received " + str(len(server_buf)) + " bytes, expected " + str(payload_len)) + return + finish_success() + + def on_timeout(): + if done: + return + finish_error("timeout waiting for client session close flush test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + log.info("server listening", None) + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + ch.reject_request("exec disabled") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name == "netconf": + ch.accept_request() + else: + ch.reject_request("unsupported subsystem") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + if done: + return + saw_terminal_event = True + check_server_bytes("session closed") + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server_buf += data + if not client_close_requested: + request_client_close() + return + if done: + return + saw_terminal_event = True + check_server_bytes("server EOF") + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + if done: + return + saw_terminal_event = True + check_server_bytes("server channel close") + + def on_hostkey(client: ssh.Client, state: str, info: ssh.HostKeyInfo): + client.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + ch.request_subsystem("netconf") + payload = b"B" * payload_len + ch.write(payload) + + def ch_out(ch: ssh.Channel, data: ?bytes): + return + + def ch_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + return + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 50000 + (random.randint(0, 2000) + attempts * 101) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_client_session_close_flush(t: testing.EnvT): + """Client.close should not drop queued channel writes.""" + ClientSessionCloseFlushTester(t) + + actor AuthRejectTester(t: testing.EnvT): log = logging.Logger(t.log_handler) @@ -589,6 +785,210 @@ def _test_ssh_auth_reject(t: testing.EnvT): AuthRejectTester(t) +actor ServerSessionCloseFlushTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var payload_len = 1024 * 1024 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_session: ?ssh.ServerSession = None + var client_buf = b"" + var client_close_requested = False + var server_close_requested = False + var session_close_requested = False + var saw_terminal_event = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def request_session_close(): + if session_close_requested: + return + session_close_requested = True + if server_session is not None: + server_session.close() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def finish_success(): + if done: + return + done = True + log.info("test success", None) + request_client_close() + request_server_close() + t.success() + + def check_client_bytes(reason: str): + if done or not saw_terminal_event: + return + if len(client_buf) != payload_len: + finish_error(reason + ": client received " + str(len(client_buf)) + " bytes, expected " + str(payload_len)) + return + finish_success() + + def on_timeout(): + if done: + return + finish_error("timeout waiting for server session close flush test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + + def on_session(sess: ssh.ServerSession): + server_session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + ch.reject_request("exec disabled") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "netconf": + ch.reject_request("unsupported subsystem") + return + ch.accept_request() + payload = b"C" * payload_len + ch.write(payload) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + if done: + return + saw_terminal_event = True + check_client_bytes("client closed") + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + ch.request_subsystem("netconf") + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + client_buf += data + if not session_close_requested: + request_session_close() + return + if done: + return + saw_terminal_event = True + check_client_bytes("client EOF") + + def ch_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + if done: + return + saw_terminal_event = True + check_client_bytes("client channel close") + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 52000 + (random.randint(0, 2000) + attempts * 107) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_server_session_close_flush(t: testing.EnvT): + """ServerSession.close should not drop queued channel writes.""" + ServerSessionCloseFlushTester(t) + + actor UnsupportedRequestTester(t: testing.EnvT): log = logging.Logger(t.log_handler) From b9b4710356ee7496255a825feee2d5b3f81993d7 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 00:25:34 +0100 Subject: [PATCH 12/38] Clarify channel close semantics It was too easy to use channel close as if it were a half-close. The Acton API only said "Close the channel", while libssh close semantics explicitly discard unread inbound data. That made request-response flows look valid even though they were racing their own replies. This change documents the distinction directly on Channel and ServerChannel: send_eof keeps the read side open, while close is a full close that may drop unread inbound bytes. It also adds concurrency coverage around shared-session traffic so the safe pattern stays exercised under load. This is correct because the binding now states the contract that libssh already implements, and the new stress cases follow the same half-close pattern production request-response traffic needs. --- src/ssh.act | 8 +- src/test_ssh_server.act | 460 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 464 insertions(+), 4 deletions(-) diff --git a/src/ssh.act b/src/ssh.act index c0af650..c21949f 100644 --- a/src/ssh.act +++ b/src/ssh.act @@ -195,11 +195,11 @@ actor Channel(client: Client, client.channel_write(self, data) action def send_eof() -> None: - """Send EOF to the channel""" + """Half-close writes while keeping the read side open""" client.channel_send_eof(self) action def close() -> None: - """Close the channel""" + """Close the channel and discard unread inbound data""" client.channel_close(self) @@ -444,7 +444,7 @@ actor ServerChannel(session: ServerSession, session.channel_write_stderr(self, data) action def send_eof() -> None: - """Send EOF to the channel""" + """Half-close writes while keeping the read side open""" session.channel_send_eof(self) action def send_exit_status(status: int) -> None: @@ -452,7 +452,7 @@ actor ServerChannel(session: ServerSession, session.channel_send_exit_status(self, status) action def close() -> None: - """Close the channel""" + """Close the channel and discard unread inbound data""" session.channel_close(self) diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index b38f555..e7c5ee9 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -989,6 +989,466 @@ def _test_ssh_server_session_close_flush(t: testing.EnvT): ServerSessionCloseFlushTester(t) +actor ConcurrentChannelTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + PAYLOAD = b"channel-ok" + + var done = False + var attempts = 0 + var port: int = 0 + var num_channels = 8 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_session: ?ssh.ServerSession = None + var client_close_requested = False + var server_close_requested = False + var client_closed = False + var server_closed = False + var session_closed = False + var client_opened = 0 + var client_closed_channels = 0 + var server_subsystems = 0 + var server_closed_channels = 0 + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_shutdown(): + if done: + return + if client_closed_channels == num_channels and server_closed_channels == num_channels: + request_client_close() + if server_session is not None: + server_session.close() + request_server_close() + + def maybe_finish(): + if done: + return + if client_closed and server_closed and session_closed and \ + client_opened == num_channels and server_subsystems == num_channels and \ + client_closed_channels == num_channels and server_closed_channels == num_channels: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + if server_session is not None: + server_session.close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error("timeout waiting for concurrent channel test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + server_session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "netconf": + finish_error("unexpected subsystem: " + name) + return + server_subsystems += 1 + ch.accept_request() + ch.write(PAYLOAD) + ch.close() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("unexpected client payload on concurrent channel test") + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("unexpected client stderr on concurrent channel test") + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + server_closed_channels += 1 + maybe_shutdown() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + start_channel(c, 0) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + maybe_finish() + + def start_channel(c: ssh.Client, idx: int): + if idx >= num_channels: + return + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("channel " + str(idx) + " open error: " + err) + return + client_opened += 1 + ch.request_subsystem("netconf") + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is None: + return + if data != PAYLOAD: + finish_error("channel " + str(idx) + " received unexpected payload") + + def ch_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel " + str(idx) + " received unexpected stderr") + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + client_closed_channels += 1 + maybe_shutdown() + maybe_finish() + + ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + start_channel(c, idx + 1) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 54000 + (random.randint(0, 2000) + attempts * 109) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_concurrent_channels(t: testing.EnvT): + """Multiple concurrent channels should share one session cleanly.""" + ConcurrentChannelTester(t) + + +actor ConcurrentWriteServerChannelHandler(payload_len: int, + ack: bytes, + on_complete: action() -> None, + on_error: action(str) -> None): + var recv_buf = b"" + + action def on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + recv_buf += data + return + if len(recv_buf) != payload_len: + on_error("server received " + str(len(recv_buf)) + " bytes, expected " + str(payload_len)) + return + ch.write(ack) + ch.close() + + action def on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + on_error("unexpected client stderr on concurrent write test") + + action def on_close(ch: ssh.ServerChannel, reason: str): + on_complete() + + +actor ConcurrentChannelWriteTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + ACK = b"ack" + + var done = False + var attempts = 0 + var port: int = 0 + var num_channels = 4 + var payload_len = 256 * 1024 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_session: ?ssh.ServerSession = None + var client_close_requested = False + var server_close_requested = False + var client_closed = False + var server_closed = False + var session_closed = False + var client_opened = 0 + var client_acks = 0 + var client_closed_channels = 0 + var server_subsystems = 0 + var server_closed_channels = 0 + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_shutdown(): + if done: + return + if client_closed_channels == num_channels and server_closed_channels == num_channels: + request_client_close() + if server_session is not None: + server_session.close() + request_server_close() + + def maybe_finish(): + if done: + return + if client_closed and server_closed and session_closed and \ + client_opened == num_channels and server_subsystems == num_channels and \ + client_closed_channels == num_channels and server_closed_channels == num_channels: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + if server_session is not None: + server_session.close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error("timeout waiting for concurrent channel write test: client_opened=" + str(client_opened) + " client_acks=" + str(client_acks) + " server_subsystems=" + str(server_subsystems) + " client_closed_channels=" + str(client_closed_channels) + " server_closed_channels=" + str(server_closed_channels) + " client_closed=" + str(client_closed) + " session_closed=" + str(session_closed) + " server_closed=" + str(server_closed)) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + server_session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_server_channel_close(): + server_closed_channels += 1 + maybe_shutdown() + maybe_finish() + + def on_channel_open(sess: ssh.ServerSession): + handler = ConcurrentWriteServerChannelHandler( + payload_len, + ACK, + on_server_channel_close, + finish_error, + ) + ch = ssh.ServerChannel(sess, handler.on_data, handler.on_stderr, handler.on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "netconf": + finish_error("unexpected subsystem: " + name) + return + server_subsystems += 1 + ch.accept_request() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + start_channel(c, 0) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + maybe_finish() + + def start_channel(c: ssh.Client, idx: int): + if idx >= num_channels: + return + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("channel " + str(idx) + " open error: " + err) + return + client_opened += 1 + ch.request_subsystem("netconf") + payload = b"W" * payload_len + ch.write(payload) + ch.send_eof() + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is None: + return + if data != ACK: + finish_error("channel " + str(idx) + " received unexpected payload") + return + client_acks += 1 + ch.close() + + def ch_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel " + str(idx) + " received unexpected stderr") + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + client_closed_channels += 1 + maybe_shutdown() + maybe_finish() + + ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + start_channel(c, idx + 1) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 56000 + (random.randint(0, 2000) + attempts * 113) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_concurrent_channel_writes(t: testing.EnvT): + """Multiple channels should flush writes over one session.""" + ConcurrentChannelWriteTester(t) + + actor UnsupportedRequestTester(t: testing.EnvT): log = logging.Logger(t.log_handler) From b58baa7c3cba88305742289e86282896a3705d82 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 00:25:40 +0100 Subject: [PATCH 13/38] Force sessions closed on errors A server-side failure was meant to force child sessions closed, but that intent was lost before it reached the session layer. server_fail set the server state to ERROR, then server_close_internal immediately overwrote it to CLOSING and checked the new state when closing child sessions. In practice that quietly degraded server error paths into graceful session drains. This change captures whether the server is already in an error close before switching the server state to CLOSING, then passes that latched decision into each child session close. This is correct because graceful server shutdown still drains child sessions, while real server failures now preserve the immediate error-close behavior they were already trying to request. --- src/ssh.ext.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 41c794a..d04ea69 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -2958,6 +2958,7 @@ static void session_finish_close(ssh_server_session_ctx *s) { static void server_close_internal(ssh_server_ctx *s, const char *reason) { if (s == NULL || s->state == SERVER_STATE_CLOSED || s->state == SERVER_STATE_CLOSING) return; + int force_sessions = (s->state == SERVER_STATE_ERROR); if (!s->listen_ok && !s->listen_notified) { server_notify_listen(s, reason ? reason : "closed"); @@ -2974,7 +2975,7 @@ static void server_close_internal(ssh_server_ctx *s, const char *reason) { ssh_server_session_ctx *sess = s->sessions; while (sess != NULL) { ssh_server_session_ctx *next = sess->next; - session_close_internal(sess, "Server closed", s->state == SERVER_STATE_ERROR); + session_close_internal(sess, "Server closed", force_sessions); sess = next; } From 055436fa8fd1e7367252889b9c2e4c8d13dd5351 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 00:29:27 +0100 Subject: [PATCH 14/38] Tighten channel callback teardown Server-side channel teardown had two holes in exceptional paths. Errored server channels were marked SCHAN_STATE_ERROR and then removed from the session list without running the finalizer, which left the actor-side channel id and EOF bookkeeping behind. At the same time, client and server channel callback structs were allocated by the binding, removed from libssh on teardown, and then leaked because ownership never returned to the binding. This change finalizes errored server channels before unlinking them from the session, frees channel callback structs after removing them from libssh, and only marks callbacks as installed after ssh_add_channel_callbacks succeeds. This is correct because every channel teardown path now converges on the same cleanup logic, and callback lifetime now matches libssh's callback registration API instead of relying on list removal to free caller-owned memory. --- src/ssh.ext.c | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index d04ea69..42b827d 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -778,8 +778,11 @@ static void client_channel_setup_callbacks(ssh_channel_ctx *ch) { cb->channel_eof_function = client_channel_eof_cb; cb->channel_close_function = client_channel_close_cb; cb->channel_write_wontblock_function = client_channel_write_wontblock_cb; + if (ssh_add_channel_callbacks(ch->channel, cb) != SSH_OK) { + acton_free(cb); + return; + } ch->callbacks = cb; - ssh_add_channel_callbacks(ch->channel, cb); if (ssh_debug_enabled) { ssh_debug_log("client channel callbacks set ch=%p", (void *)ch); } @@ -823,6 +826,7 @@ static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { channel_notify_exit(ch, exit_status, exit_signal); if (ch->callbacks) { ssh_remove_channel_callbacks(ch->channel, ch->callbacks); + acton_free(ch->callbacks); ch->callbacks = NULL; } ssh_channel_free(ch->channel); @@ -2094,8 +2098,11 @@ static void server_channel_setup_callbacks(ssh_server_channel_ctx *ch) { cb->channel_eof_function = server_channel_eof_cb; cb->channel_close_function = server_channel_close_cb; cb->channel_write_wontblock_function = server_channel_write_wontblock_cb; + if (ssh_add_channel_callbacks(ch->channel, cb) != SSH_OK) { + acton_free(cb); + return; + } ch->callbacks = cb; - ssh_add_channel_callbacks(ch->channel, cb); if (ssh_debug_enabled) { ssh_debug_log("server channel callbacks set ch=%p", (void *)ch); } @@ -2128,6 +2135,7 @@ static void server_channel_finalize(ssh_server_channel_ctx *ch) { if (ch->channel != NULL) { if (ch->callbacks) { ssh_remove_channel_callbacks(ch->channel, ch->callbacks); + acton_free(ch->callbacks); ch->callbacks = NULL; } ssh_channel_free(ch->channel); @@ -2323,6 +2331,9 @@ static void session_drive_channels(ssh_server_session_ctx *s) { while (ch != NULL) { ssh_server_channel_ctx *next = ch->next; server_channel_drive(s, ch); + if (ch->state == SCHAN_STATE_ERROR) { + server_channel_finalize(ch); + } if (ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) { if (prev != NULL) { prev->next = next; From b9dca3c3e9165e3bfe256d9a5fa088a0e5b219f6 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 00:32:06 +0100 Subject: [PATCH 15/38] Fail channels on exit-status errors Sending a server-side exit status already reported a close when it failed, but it left the channel otherwise live. That meant callers could receive an error close callback while the channel remained in the session list and kept participating in later drive cycles. This change marks the channel as failed when ssh_channel_request_send_exit_status returns an error, so the normal server-channel error teardown path runs on the next session drive. This is correct because exit-status send failures are channel write failures like the other server-side send paths, and they should terminate the channel instead of only emitting a notification. --- src/ssh.ext.c | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 42b827d..e089b1b 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -3396,6 +3396,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o char errmsg[256] = {0}; snprintf(errmsg, sizeof(errmsg), "SSH server send exit status failed: %s", ssh_get_error(s->session)); server_channel_notify_close(ch, errmsg); + ch->state = SCHAN_STATE_ERROR; } session_drive(s); return $R_CONT(c$cont, B_None); From c36a182047289bb81ec753742a8e0a13485f8c04 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 10:38:53 +0100 Subject: [PATCH 16/38] Free channel write queue nodes Per-write queue nodes were retained after successful writes and also left behind when channels were finalized with queued data still present. Under long stress this leaked one heap allocation per queued write chunk on both the client and server channel paths. Free write queue nodes as they are consumed and drain any remaining nodes during channel finalization. This keeps graceful close semantics unchanged while making channel teardown release the queue state it no longer owns. --- src/ssh.ext.c | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index e089b1b..da489cc 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -806,6 +806,12 @@ static void channel_notify_eof(ssh_channel_ctx *ch) { } static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { + while (ch->write_head != NULL) { + write_chunk_t *chunk = ch->write_head; + ch->write_head = chunk->next; + acton_free(chunk); + } + ch->write_tail = NULL; if (ch->channel != NULL) { int exit_status = -1; B_str exit_signal = B_None; @@ -879,6 +885,7 @@ static void channel_try_write(ssh_client_ctx *c, ssh_channel_ctx *ch) { while (ch->write_head != NULL && ch->write_head->data->nbytes == ch->write_head->offset) { write_chunk_t *chunk = ch->write_head; ch->write_head = chunk->next; + acton_free(chunk); if (ch->write_head == NULL) ch->write_tail = NULL; } @@ -894,6 +901,7 @@ static void channel_try_write(ssh_client_ctx *c, ssh_channel_ctx *ch) { chunk->offset += (size_t)rc; if (chunk->offset >= chunk->data->nbytes) { ch->write_head = chunk->next; + acton_free(chunk); if (ch->write_head == NULL) ch->write_tail = NULL; } @@ -2126,6 +2134,12 @@ static void server_channel_notify_eof(ssh_server_channel_ctx *ch) { } static void server_channel_finalize(ssh_server_channel_ctx *ch) { + while (ch->write_head != NULL) { + server_write_chunk_t *chunk = ch->write_head; + ch->write_head = chunk->next; + acton_free(chunk); + } + ch->write_tail = NULL; if (ch->pending_req) { ssh_message_reply_default(ch->pending_req); ssh_message_free(ch->pending_req); @@ -2178,6 +2192,7 @@ static void server_channel_try_write(ssh_server_session_ctx *s, ssh_server_chann while (ch->write_head != NULL && ch->write_head->data->nbytes == ch->write_head->offset) { server_write_chunk_t *chunk = ch->write_head; ch->write_head = chunk->next; + acton_free(chunk); if (ch->write_head == NULL) ch->write_tail = NULL; } @@ -2201,6 +2216,7 @@ static void server_channel_try_write(ssh_server_session_ctx *s, ssh_server_chann } if (chunk->offset >= chunk->data->nbytes) { ch->write_head = chunk->next; + acton_free(chunk); if (ch->write_head == NULL) ch->write_tail = NULL; } From 5ccb8e7d40692f14d07bb881957d68eaf69882b7 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 10:52:43 +0100 Subject: [PATCH 17/38] Validate channel ownership Client and server channel entrypoints trusted actor handles without checking which client or session actually owned the underlying channel. A wrong-handle call could queue writes or accept requests on an unrelated channel, and in the client case could also poison the real channel with a spurious "Channel not ready" error. Track the owning client for each client channel and validate ownership on both client and server channel APIs before mutating channel state. Mismatched handles are now ignored instead of touching the foreign channel. A regression test drives two live sessions, crosses both the client and server channel handles, and asserts that no payload leaks into the wrong channel. --- src/ssh.ext.c | 62 ++++-- src/test_ssh_server.act | 431 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 477 insertions(+), 16 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index da489cc..6aa7135 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -125,6 +125,7 @@ typedef struct write_chunk { typedef struct ssh_channel_ctx { struct ssh_channel_ctx *next; ssh_channel channel; + struct ssh_client_ctx *client; struct ssh_channel_callbacks_struct *callbacks; sshQ_Channel actor; channel_state_t state; @@ -1823,6 +1824,7 @@ B_NoneType sshQ__debug(B_str msg) { } ssh_channel_ctx *ch = acton_calloc(1, sizeof(ssh_channel_ctx)); + ch->client = c; ch->actor = channel; ch->callbacks = NULL; ch->state = CHAN_STATE_INIT; @@ -1853,9 +1855,33 @@ B_NoneType sshQ__debug(B_str msg) { static int channel_validate(ssh_client_ctx *c, ssh_channel_ctx *ch) { if (c == NULL || ch == NULL) { - return -1; + return 1; + } + if (ch->client != c) { + if (ssh_debug_enabled) { + ssh_debug_log("client channel ownership mismatch: client=%p owner=%p ch=%p", + (void *)c, (void *)ch->client, (void *)ch); + } + return 1; } if (ch->state == CHAN_STATE_ERROR || ch->state == CHAN_STATE_CLOSED) { + return 2; + } + return 0; +} + +static int server_channel_validate(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) { + if (s == NULL || ch == NULL) { + return -1; + } + if (ch->session != s) { + if (ssh_debug_enabled) { + ssh_debug_log("server channel ownership mismatch: session=%p owner=%p ch=%p", + (void *)s, (void *)ch->session, (void *)ch); + } + return -1; + } + if (ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) { return -1; } return 0; @@ -1864,8 +1890,9 @@ static int channel_validate(ssh_client_ctx *c, ssh_channel_ctx *ch) { $R sshQ_ClientD_channel_request_execG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str cmd) { ssh_client_ctx *c = client_from_actor(self); ssh_channel_ctx *ch = channel_from_actor(channel); - if (channel_validate(c, ch) != 0) { - if (ch != NULL) + int valid = channel_validate(c, ch); + if (valid != 0) { + if (valid == 2 && ch != NULL) channel_notify_error(ch, "Channel not ready"); return $R_CONT(c$cont, B_None); } @@ -1889,8 +1916,9 @@ static int channel_validate(ssh_client_ctx *c, ssh_channel_ctx *ch) { $R sshQ_ClientD_channel_request_shellG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str term, B_int cols, B_int rows, B_int width_px, B_int height_px, B_bool with_pty) { ssh_client_ctx *c = client_from_actor(self); ssh_channel_ctx *ch = channel_from_actor(channel); - if (channel_validate(c, ch) != 0) { - if (ch != NULL) + int valid = channel_validate(c, ch); + if (valid != 0) { + if (valid == 2 && ch != NULL) channel_notify_error(ch, "Channel not ready"); return $R_CONT(c$cont, B_None); } @@ -1920,8 +1948,9 @@ static int channel_validate(ssh_client_ctx *c, ssh_channel_ctx *ch) { $R sshQ_ClientD_channel_request_subsystemG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str name) { ssh_client_ctx *c = client_from_actor(self); ssh_channel_ctx *ch = channel_from_actor(channel); - if (channel_validate(c, ch) != 0) { - if (ch != NULL) + int valid = channel_validate(c, ch); + if (valid != 0) { + if (valid == 2 && ch != NULL) channel_notify_error(ch, "Channel not ready"); return $R_CONT(c$cont, B_None); } @@ -1945,8 +1974,9 @@ static int channel_validate(ssh_client_ctx *c, ssh_channel_ctx *ch) { $R sshQ_ClientD_channel_writeG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_bytes data) { ssh_client_ctx *c = client_from_actor(self); ssh_channel_ctx *ch = channel_from_actor(channel); - if (channel_validate(c, ch) != 0) { - if (ch != NULL) + int valid = channel_validate(c, ch); + if (valid != 0) { + if (valid == 2 && ch != NULL) channel_notify_error(ch, "Channel not ready"); return $R_CONT(c$cont, B_None); } @@ -3346,7 +3376,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o $R sshQ_ServerSessionD_channel_accept_requestG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) { ssh_server_session_ctx *s = session_from_actor(self); ssh_server_channel_ctx *ch = server_channel_from_actor(channel); - if (s == NULL || ch == NULL || ch->pending_req == NULL) + if (server_channel_validate(s, ch) != 0 || ch->pending_req == NULL) return $R_CONT(c$cont, B_None); ssh_message_channel_request_reply_success(ch->pending_req); @@ -3360,7 +3390,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o $R sshQ_ServerSessionD_channel_reject_requestG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_str reason) { ssh_server_session_ctx *s = session_from_actor(self); ssh_server_channel_ctx *ch = server_channel_from_actor(channel); - if (s == NULL || ch == NULL || ch->pending_req == NULL) + if (server_channel_validate(s, ch) != 0 || ch->pending_req == NULL) return $R_CONT(c$cont, B_None); ssh_message_reply_default(ch->pending_req); @@ -3375,7 +3405,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o $R sshQ_ServerSessionD_channel_writeG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_bytes data) { ssh_server_session_ctx *s = session_from_actor(self); ssh_server_channel_ctx *ch = server_channel_from_actor(channel); - if (s == NULL || ch == NULL) + if (server_channel_validate(s, ch) != 0) return $R_CONT(c$cont, B_None); server_channel_queue_write(ch, data, 0); session_drive(s); @@ -3385,7 +3415,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o $R sshQ_ServerSessionD_channel_write_stderrG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_bytes data) { ssh_server_session_ctx *s = session_from_actor(self); ssh_server_channel_ctx *ch = server_channel_from_actor(channel); - if (s == NULL || ch == NULL) + if (server_channel_validate(s, ch) != 0) return $R_CONT(c$cont, B_None); server_channel_queue_write(ch, data, 1); session_drive(s); @@ -3395,7 +3425,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o $R sshQ_ServerSessionD_channel_send_eofG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) { ssh_server_session_ctx *s = session_from_actor(self); ssh_server_channel_ctx *ch = server_channel_from_actor(channel); - if (s == NULL || ch == NULL) + if (server_channel_validate(s, ch) != 0) return $R_CONT(c$cont, B_None); ch->send_eof = 1; session_drive(s); @@ -3405,7 +3435,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o $R sshQ_ServerSessionD_channel_send_exit_statusG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_int status) { ssh_server_session_ctx *s = session_from_actor(self); ssh_server_channel_ctx *ch = server_channel_from_actor(channel); - if (s == NULL || ch == NULL || ch->channel == NULL) + if (server_channel_validate(s, ch) != 0 || ch->channel == NULL) return $R_CONT(c$cont, B_None); int rc = ssh_channel_request_send_exit_status(ch->channel, fromB_int(status)); if (rc != SSH_OK) { @@ -3421,7 +3451,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o $R sshQ_ServerSessionD_channel_closeG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) { ssh_server_session_ctx *s = session_from_actor(self); ssh_server_channel_ctx *ch = server_channel_from_actor(channel); - if (s == NULL || ch == NULL) + if (server_channel_validate(s, ch) != 0) return $R_CONT(c$cont, B_None); ch->send_eof = 1; ch->close_requested = 1; diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index e7c5ee9..c048ca1 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -1449,6 +1449,437 @@ def _test_ssh_concurrent_channel_writes(t: testing.EnvT): ConcurrentChannelWriteTester(t) +actor CrossHandleOwnershipTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + CLIENT1_PAYLOAD = b"client-one" + CLIENT2_PAYLOAD = b"client-two" + BAD_CLIENT_PAYLOAD = b"wrong-client" + SERVER1_REPLY = b"server-one" + SERVER2_REPLY = b"server-two" + BAD_SERVER_PAYLOAD = b"wrong-server" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client1: ?ssh.Client = None + var client2: ?ssh.Client = None + var client_channel1: ?ssh.Channel = None + var client_channel2: ?ssh.Channel = None + var server_session1: ?ssh.ServerSession = None + var server_session2: ?ssh.ServerSession = None + var server_channel1: ?ssh.ServerChannel = None + var server_channel2: ?ssh.ServerChannel = None + var client1_buf = b"" + var client2_buf = b"" + var server1_buf = b"" + var server2_buf = b"" + var io_started = False + var replies_sent = False + var shutdown_started = False + var client1_opened = False + var client2_opened = False + var client1_ok = False + var client2_ok = False + var server1_ok = False + var server2_ok = False + var client_close_count = 0 + var client_channel_close_count = 0 + var server_channel_close_count = 0 + var session_close_count = 0 + var server_closed = False + var client1_close_requested = False + var client2_close_requested = False + var server_close_requested = False + + def request_client1_close(): + if client1_close_requested: + return + client1_close_requested = True + if client1 is not None: + client1.close() + + def request_client2_close(): + if client2_close_requested: + return + client2_close_requested = True + if client2 is not None: + client2.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def begin_shutdown(): + if shutdown_started: + return + shutdown_started = True + if client_channel1 is not None: + client_channel1.close() + if client_channel2 is not None: + client_channel2.close() + request_client1_close() + request_client2_close() + if server_session1 is not None: + server_session1.close() + if server_session2 is not None: + server_session2.close() + request_server_close() + + def maybe_finish(): + if done: + return + if client1_ok and client2_ok and server1_ok and server2_ok and \ + client_close_count == 2 and client_channel_close_count == 2 and \ + server_channel_close_count == 2 and session_close_count == 2 and \ + server_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + begin_shutdown() + t.error(Exception(msg)) + + def maybe_start_io(): + if io_started or done: + return + if not client1_opened or not client2_opened: + return + if client1 is not None: + if client_channel1 is not None: + if client_channel2 is not None: + if server_channel1 is not None: + if server_channel2 is not None: + io_started = True + client1.channel_write(client_channel2, BAD_CLIENT_PAYLOAD) + client_channel1.write(CLIENT1_PAYLOAD) + client_channel2.write(CLIENT2_PAYLOAD) + + def maybe_send_replies(): + if replies_sent or done: + return + if not server1_ok or not server2_ok: + return + if server_session1 is not None: + if server_channel1 is not None: + if server_channel2 is not None: + replies_sent = True + server_session1.channel_write(server_channel2, BAD_SERVER_PAYLOAD) + server_channel1.write(SERVER1_REPLY) + server_channel2.write(SERVER2_REPLY) + return + finish_error("server channels not ready for reply") + + def maybe_shutdown(): + if client1_ok and client2_ok: + begin_shutdown() + + def on_timeout(): + if done: + return + finish_error("timeout waiting for cross-handle ownership test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client1() + after 0.01: start_client2() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + if server_channel1 is None: + server_session1 = sess + ch = ssh.ServerChannel(sess, srv1_on_data, srv1_on_stderr, srv1_on_close) + server_channel1 = ch + sess.accept_channel(ch) + else: + if server_channel2 is None: + server_session2 = sess + ch = ssh.ServerChannel(sess, srv2_on_data, srv2_on_stderr, srv2_on_close) + server_channel2 = ch + sess.accept_channel(ch) + else: + finish_error("unexpected extra server channel") + return + maybe_start_io() + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "echo": + finish_error("unexpected subsystem: " + name) + return + ch.accept_request() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_close_count += 1 + maybe_finish() + + def srv1_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server1_buf += data + if server1_buf.find(BAD_CLIENT_PAYLOAD) >= 0: + finish_error("cross-client payload leaked into server channel 1") + return + if len(server1_buf) > len(CLIENT1_PAYLOAD): + finish_error("server channel 1 received extra data") + return + if len(server1_buf) == len(CLIENT1_PAYLOAD): + if server1_buf == CLIENT1_PAYLOAD: + if server1_ok: + finish_error("client-one payload delivered twice") + return + server1_ok = True + else: + if server1_buf == CLIENT2_PAYLOAD: + if server2_ok: + finish_error("client-two payload delivered twice") + return + server2_ok = True + else: + finish_error("server channel 1 received wrong payload") + return + maybe_send_replies() + + def srv1_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server channel 1 received unexpected stderr") + + def srv1_on_close(ch: ssh.ServerChannel, reason: str): + server_channel_close_count += 1 + maybe_finish() + + def srv2_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server2_buf += data + if server2_buf.find(BAD_CLIENT_PAYLOAD) >= 0: + finish_error("cross-client payload leaked into server channel 2") + return + if len(server2_buf) > len(CLIENT2_PAYLOAD): + finish_error("server channel 2 received extra data") + return + if len(server2_buf) == len(CLIENT2_PAYLOAD): + if server2_buf == CLIENT2_PAYLOAD: + if server2_ok: + finish_error("client-two payload delivered twice") + return + server2_ok = True + else: + if server2_buf == CLIENT1_PAYLOAD: + if server1_ok: + finish_error("client-one payload delivered twice") + return + server1_ok = True + else: + finish_error("server channel 2 received wrong payload") + return + maybe_send_replies() + + def srv2_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server channel 2 received unexpected stderr") + + def srv2_on_close(ch: ssh.ServerChannel, reason: str): + server_channel_close_count += 1 + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client1_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client 1 connect error: " + err) + return + client_channel1 = ssh.Channel(c, ch1_open, ch1_out, ch1_err, ch1_exit, ch1_close) + + def on_client2_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client 2 connect error: " + err) + return + client_channel2 = ssh.Channel(c, ch2_open, ch2_out, ch2_err, ch2_exit, ch2_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_close_count += 1 + maybe_finish() + + def ch1_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel 1 open error: " + err) + return + client1_opened = True + client_channel1 = ch + ch.request_subsystem("echo") + maybe_start_io() + + def ch1_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + client1_buf += data + if client1_buf.find(BAD_SERVER_PAYLOAD) >= 0: + finish_error("cross-session server payload leaked into client channel 1") + return + if len(client1_buf) > len(SERVER1_REPLY): + finish_error("client channel 1 received extra data") + return + if len(client1_buf) == len(SERVER1_REPLY): + if client1_buf == SERVER1_REPLY: + if client1_ok: + finish_error("server-one reply delivered twice") + return + client1_ok = True + else: + if client1_buf == SERVER2_REPLY: + if client2_ok: + finish_error("server-two reply delivered twice") + return + client2_ok = True + else: + finish_error("client channel 1 received wrong payload") + return + maybe_shutdown() + + def ch1_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("client channel 1 received unexpected stderr") + + def ch1_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch1_close(ch: ssh.Channel, reason: str): + client_channel_close_count += 1 + maybe_finish() + + def ch2_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel 2 open error: " + err) + return + client2_opened = True + client_channel2 = ch + ch.request_subsystem("echo") + maybe_start_io() + + def ch2_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + client2_buf += data + if client2_buf.find(BAD_SERVER_PAYLOAD) >= 0: + finish_error("cross-session server payload leaked into client channel 2") + return + if len(client2_buf) > len(SERVER2_REPLY): + finish_error("client channel 2 received extra data") + return + if len(client2_buf) == len(SERVER2_REPLY): + if client2_buf == SERVER2_REPLY: + if client2_ok: + finish_error("server-two reply delivered twice") + return + client2_ok = True + else: + if client2_buf == SERVER1_REPLY: + if client1_ok: + finish_error("server-one reply delivered twice") + return + client1_ok = True + else: + finish_error("client channel 2 received wrong payload") + return + maybe_shutdown() + + def ch2_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("client channel 2 received unexpected stderr") + + def ch2_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch2_close(ch: ssh.Channel, reason: str): + client_channel_close_count += 1 + maybe_finish() + + def start_client1(): + client1 = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client1_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_client2(): + client2 = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client2_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 50000 + (random.randint(0, 2000) + attempts * 131) % 2000 + server_close_requested = False + client1_close_requested = False + client2_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 5.0: on_timeout() + + +def _test_ssh_cross_handle_ownership(t: testing.EnvT): + """Wrong client or session should not drive another channel.""" + CrossHandleOwnershipTester(t) + + actor UnsupportedRequestTester(t: testing.EnvT): log = logging.Logger(t.log_handler) From d09c5fea1871570364c4e5c61ad1fe2a53869171 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 11:38:07 +0100 Subject: [PATCH 18/38] Harden SSH command completion RunCommand timeouts only closed the channel, so callers could miss a terminal callback or see late channel events mutate state after timeout. Several server reply paths also assumed callback registration and libssh control-packet sends always succeeded, which hid real errors under load. The stress suite had two tests that raced their own requests or teardown and could time out even when the transport behaved correctly. This change makes RunCommand timeout an explicit terminal error and ignores late open, data, exit, and close callbacks after completion. It checks server-side callback setup and libssh reply/send results, treats SSH_AGAIN as backpressure instead of failure, and adds explicit coverage for RunCommand timeout and exit-status delivery. The ownership and unsupported-request stress tests now use isolated port ranges and a more ordered shutdown so they exercise the binding instead of their own close races. This is correct because command completion is now single-shot, server control replies no longer fail silently, and the stress suite keeps reporting transport regressions instead of test-induced flakes. --- src/ssh.act | 14 ++ src/ssh.ext.c | 128 ++++++++--- src/test_ssh_server.act | 457 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 557 insertions(+), 42 deletions(-) diff --git a/src/ssh.act b/src/ssh.act index c21949f..5f9ce82 100644 --- a/src/ssh.act +++ b/src/ssh.act @@ -226,6 +226,8 @@ actor RunCommand(client: Client, on_exit(ch, _exit_code, _exit_signal, out_buf, err_buf, _error) def _on_open(ch: Channel, err: ?str): + if _done: + return if err is not None: _error = err _finish(ch) @@ -233,6 +235,8 @@ actor RunCommand(client: Client, ch.request_exec(cmd) def _on_stdout(ch: Channel, data: ?bytes): + if _done: + return if data is not None: out_buf += data else: @@ -240,6 +244,8 @@ actor RunCommand(client: Client, _check_done(ch) def _on_stderr(ch: Channel, data: ?bytes): + if _done: + return if data is not None: err_buf += data else: @@ -247,12 +253,16 @@ actor RunCommand(client: Client, _check_done(ch) def _on_exit(ch: Channel, code: int, sig: ?str): + if _done: + return _exited = True _exit_code = code _exit_signal = sig _check_done(ch) def _on_close(ch: Channel, reason: str): + if _done: + return if _error is None and reason != "closed": _error = reason if _error is not None: @@ -267,6 +277,10 @@ actor RunCommand(client: Client, if timeout is not None: def _on_timeout(): + if _done: + return + _error = "timeout" + _finish(_channel) _channel.close() after timeout: _on_timeout() diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 6aa7135..0cddba9 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -318,6 +318,7 @@ static void session_pump_io(ssh_server_session_ctx *s); static void session_close_internal(ssh_server_session_ctx *s, const char *reason, int force_close); static void session_finalize(ssh_server_session_ctx *s); static void session_finish_close(ssh_server_session_ctx *s); +static void session_fail(ssh_server_session_ctx *s, const char *msg); static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch); static int session_needs_write(ssh_server_session_ctx *s); static void session_poll_cb(uv_poll_t *handle, int status, int events); @@ -769,9 +770,11 @@ static int client_channel_write_wontblock_cb(ssh_session session, ssh_channel ch return 0; } -static void client_channel_setup_callbacks(ssh_channel_ctx *ch) { - if (ch == NULL || ch->channel == NULL || ch->callbacks != NULL) - return; +static int client_channel_setup_callbacks(ssh_channel_ctx *ch) { + if (ch == NULL || ch->channel == NULL) + return SSH_ERROR; + if (ch->callbacks != NULL) + return SSH_OK; struct ssh_channel_callbacks_struct *cb = acton_calloc(1, sizeof(*cb)); ssh_callbacks_init(cb); cb->userdata = ch; @@ -781,12 +784,13 @@ static void client_channel_setup_callbacks(ssh_channel_ctx *ch) { cb->channel_write_wontblock_function = client_channel_write_wontblock_cb; if (ssh_add_channel_callbacks(ch->channel, cb) != SSH_OK) { acton_free(cb); - return; + return SSH_ERROR; } ch->callbacks = cb; if (ssh_debug_enabled) { ssh_debug_log("client channel callbacks set ch=%p", (void *)ch); } + return SSH_OK; } static void channel_notify_eof(ssh_channel_ctx *ch) { @@ -967,7 +971,10 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { channel_fail(c, ch, "Failed to create SSH channel"); return; } - client_channel_setup_callbacks(ch); + if (client_channel_setup_callbacks(ch) != SSH_OK) { + channel_fail(c, ch, "Failed to set SSH channel callbacks"); + return; + } if (ssh_debug_enabled) { ssh_debug_log("client channel new ch=%p callbacks=%p", (void *)ch, (void *)ch->callbacks); } @@ -2126,9 +2133,11 @@ static int server_channel_write_wontblock_cb(ssh_session session, ssh_channel ch return 0; } -static void server_channel_setup_callbacks(ssh_server_channel_ctx *ch) { - if (ch == NULL || ch->channel == NULL || ch->callbacks != NULL) - return; +static int server_channel_setup_callbacks(ssh_server_channel_ctx *ch) { + if (ch == NULL || ch->channel == NULL) + return SSH_ERROR; + if (ch->callbacks != NULL) + return SSH_OK; struct ssh_channel_callbacks_struct *cb = acton_calloc(1, sizeof(*cb)); ssh_callbacks_init(cb); cb->userdata = ch; @@ -2138,12 +2147,13 @@ static void server_channel_setup_callbacks(ssh_server_channel_ctx *ch) { cb->channel_write_wontblock_function = server_channel_write_wontblock_cb; if (ssh_add_channel_callbacks(ch->channel, cb) != SSH_OK) { acton_free(cb); - return; + return SSH_ERROR; } ch->callbacks = cb; if (ssh_debug_enabled) { ssh_debug_log("server channel callbacks set ch=%p", (void *)ch); } + return SSH_OK; } static void server_channel_notify_eof(ssh_server_channel_ctx *ch) { @@ -2401,6 +2411,16 @@ static void session_fail(ssh_server_session_ctx *s, const char *msg) { session_close_internal(s, msg, 1); } +static int session_check_reply_rc(ssh_server_session_ctx *s, int rc, const char *context) { + if (rc == SSH_OK || rc == SSH_AGAIN) + return 0; + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "%s: %s", + context, s != NULL && s->session != NULL ? ssh_get_error(s->session) : "unknown error"); + session_fail(s, errmsg); + return -1; +} + static void session_start_auth_timer(ssh_server_session_ctx *s) { if (s == NULL || s->auth_timeout <= 0.0 || s->auth_timer != NULL) return; @@ -2612,20 +2632,26 @@ static void session_drive(ssh_server_session_ctx *s) { } int type = ssh_message_type(msg); if (type == SSH_REQUEST_SERVICE) { + int rc; const char *service = ssh_message_service_service(msg); if (service && strcmp(service, "ssh-userauth") == 0) { - ssh_message_service_reply_success(msg); + rc = ssh_message_service_reply_success(msg); } else { - ssh_message_reply_default(msg); + rc = ssh_message_reply_default(msg); } ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH service reply failed") != 0) + return; continue; } if (type == SSH_REQUEST_AUTH && ssh_message_subtype(msg) == SSH_AUTH_METHOD_PASSWORD) { if (s->on_auth == NULL) { + int rc; ssh_message_auth_set_methods(msg, SSH_AUTH_METHOD_PASSWORD); - ssh_message_reply_default(msg); + rc = ssh_message_reply_default(msg); ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH auth reject failed") != 0) + return; continue; } const char *user = ssh_message_auth_user(msg); @@ -2648,8 +2674,10 @@ static void session_drive(ssh_server_session_ctx *s) { session_update_poll(s); return; } - ssh_message_reply_default(msg); + int rc = ssh_message_reply_default(msg); ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH auth reply failed") != 0) + return; continue; } @@ -2666,14 +2694,20 @@ static void session_drive(ssh_server_session_ctx *s) { int type = ssh_message_type(msg); if (type == SSH_REQUEST_CHANNEL_OPEN) { if (s->pending_channel_open != NULL) { - ssh_message_reply_default(msg); + int rc = ssh_message_reply_default(msg); ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0) + return; } else if (ssh_message_subtype(msg) != SSH_CHANNEL_SESSION) { - ssh_message_reply_default(msg); + int rc = ssh_message_reply_default(msg); ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0) + return; } else if (s->on_channel_open == NULL) { - ssh_message_reply_default(msg); + int rc = ssh_message_reply_default(msg); ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0) + return; } else { s->pending_channel_open = msg; $action f = ($action)s->on_channel_open; @@ -2684,12 +2718,16 @@ static void session_drive(ssh_server_session_ctx *s) { ssh_channel chan = ssh_message_channel_request_channel(msg); ssh_server_channel_ctx *ch = server_channel_from_ssh(s, chan); if (ch == NULL || ch->pending_req != NULL) { - ssh_message_reply_default(msg); + int rc = ssh_message_reply_default(msg); ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel request reject failed") != 0) + return; } else if (ssh_message_subtype(msg) == SSH_CHANNEL_REQUEST_EXEC) { if (s->on_exec == NULL) { - ssh_message_reply_default(msg); + int rc = ssh_message_reply_default(msg); ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH exec reject failed") != 0) + return; } else { const char *cmd = ssh_message_channel_request_command(msg); ch->pending_req = msg; @@ -2701,8 +2739,10 @@ static void session_drive(ssh_server_session_ctx *s) { } } else if (ssh_message_subtype(msg) == SSH_CHANNEL_REQUEST_SUBSYSTEM) { if (s->on_subsystem == NULL) { - ssh_message_reply_default(msg); + int rc = ssh_message_reply_default(msg); ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH subsystem reject failed") != 0) + return; } else { const char *name = ssh_message_channel_request_subsystem(msg); ch->pending_req = msg; @@ -2713,20 +2753,27 @@ static void session_drive(ssh_server_session_ctx *s) { break; } } else { - ssh_message_reply_default(msg); + int rc = ssh_message_reply_default(msg); ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel request reject failed") != 0) + return; } } else if (type == SSH_REQUEST_SERVICE) { + int rc; const char *service = ssh_message_service_service(msg); if (service && strcmp(service, "ssh-connection") == 0) { - ssh_message_service_reply_success(msg); + rc = ssh_message_service_reply_success(msg); } else { - ssh_message_reply_default(msg); + rc = ssh_message_reply_default(msg); } ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH connection service reply failed") != 0) + return; } else { - ssh_message_reply_default(msg); + int rc = ssh_message_reply_default(msg); ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH request reply failed") != 0) + return; } } session_drive_channels(s); @@ -3272,9 +3319,11 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o if (s == NULL || s->pending_auth == NULL) return $R_CONT(c$cont, B_None); - ssh_message_auth_reply_success(s->pending_auth, 0); + int rc = ssh_message_auth_reply_success(s->pending_auth, 0); ssh_message_free(s->pending_auth); s->pending_auth = NULL; + if (session_check_reply_rc(s, rc, "SSH auth accept failed") != 0) + return $R_CONT(c$cont, B_None); s->state = SESSION_STATE_READY; stop_timer(&s->auth_timer, session_timer_close_cb); session_start_keepalive(s); @@ -3289,10 +3338,12 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o return $R_CONT(c$cont, B_None); ssh_message_auth_set_methods(s->pending_auth, SSH_AUTH_METHOD_PASSWORD); - ssh_message_reply_default(s->pending_auth); + int rc = ssh_message_reply_default(s->pending_auth); ssh_message_free(s->pending_auth); s->pending_auth = NULL; (void)reason; + if (session_check_reply_rc(s, rc, "SSH auth reject failed") != 0) + return $R_CONT(c$cont, B_None); session_drive(s); return $R_CONT(c$cont, B_None); @@ -3308,13 +3359,15 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o ssh_channel chan = ssh_message_channel_request_open_reply_accept(s->pending_channel_open); if (chan == NULL) { - ssh_message_reply_default(s->pending_channel_open); + int rc = ssh_message_reply_default(s->pending_channel_open); ssh_message_free(s->pending_channel_open); s->pending_channel_open = NULL; if (on_close) { $action2 f = ($action2)on_close; f->$class->__asyn__(f, channel, to$str((char *)"Failed to accept channel open")); } + if (session_check_reply_rc(s, rc, "SSH channel open accept failed") != 0) + return $R_CONT(c$cont, B_None); session_drive(s); return $R_CONT(c$cont, B_None); } @@ -3339,7 +3392,14 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o ch->on_data = ($action2)on_data; ch->on_stderr = ($action2)on_stderr; ch->on_close = ($action2)on_close; - server_channel_setup_callbacks(ch); + if (server_channel_setup_callbacks(ch) != SSH_OK) { + server_channel_notify_close(ch, "Failed to set SSH channel callbacks"); + if (ch->channel != NULL) + ssh_channel_close(ch->channel); + server_channel_finalize(ch); + session_drive(s); + return $R_CONT(c$cont, B_None); + } if (ssh_debug_enabled) { ssh_debug_log("server channel new ch=%p callbacks=%p", (void *)ch, (void *)ch->callbacks); } @@ -3357,10 +3417,12 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o if (s == NULL || s->pending_channel_open == NULL) return $R_CONT(c$cont, B_None); - ssh_message_reply_default(s->pending_channel_open); + int rc = ssh_message_reply_default(s->pending_channel_open); ssh_message_free(s->pending_channel_open); s->pending_channel_open = NULL; (void)reason; + if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0) + return $R_CONT(c$cont, B_None); session_drive(s); return $R_CONT(c$cont, B_None); } @@ -3379,10 +3441,12 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o if (server_channel_validate(s, ch) != 0 || ch->pending_req == NULL) return $R_CONT(c$cont, B_None); - ssh_message_channel_request_reply_success(ch->pending_req); + int rc = ssh_message_channel_request_reply_success(ch->pending_req); ssh_message_free(ch->pending_req); ch->pending_req = NULL; ch->pending_req_type = SCHAN_REQ_NONE; + if (session_check_reply_rc(s, rc, "SSH channel request accept failed") != 0) + return $R_CONT(c$cont, B_None); session_drive(s); return $R_CONT(c$cont, B_None); } @@ -3393,11 +3457,13 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o if (server_channel_validate(s, ch) != 0 || ch->pending_req == NULL) return $R_CONT(c$cont, B_None); - ssh_message_reply_default(ch->pending_req); + int rc = ssh_message_reply_default(ch->pending_req); ssh_message_free(ch->pending_req); ch->pending_req = NULL; ch->pending_req_type = SCHAN_REQ_NONE; (void)reason; + if (session_check_reply_rc(s, rc, "SSH channel request reject failed") != 0) + return $R_CONT(c$cont, B_None); session_drive(s); return $R_CONT(c$cont, B_None); } @@ -3438,7 +3504,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o if (server_channel_validate(s, ch) != 0 || ch->channel == NULL) return $R_CONT(c$cont, B_None); int rc = ssh_channel_request_send_exit_status(ch->channel, fromB_int(status)); - if (rc != SSH_OK) { + if (rc != SSH_OK && rc != SSH_AGAIN) { char errmsg[256] = {0}; snprintf(errmsg, sizeof(errmsg), "SSH server send exit status failed: %s", ssh_get_error(s->session)); server_channel_notify_close(ch, errmsg); diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index c048ca1..9f1b705 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -1514,6 +1514,19 @@ actor CrossHandleOwnershipTester(t: testing.EnvT): if server is not None: server.close() + def maybe_request_client_closes(): + if not shutdown_started: + return + if client_channel_close_count >= 2: + request_client1_close() + request_client2_close() + + def maybe_request_server_shutdown(): + if not shutdown_started: + return + if session_close_count >= 2: + request_server_close() + def begin_shutdown(): if shutdown_started: return @@ -1522,13 +1535,12 @@ actor CrossHandleOwnershipTester(t: testing.EnvT): client_channel1.close() if client_channel2 is not None: client_channel2.close() - request_client1_close() - request_client2_close() if server_session1 is not None: server_session1.close() if server_session2 is not None: server_session2.close() - request_server_close() + maybe_request_client_closes() + maybe_request_server_shutdown() def maybe_finish(): if done: @@ -1585,7 +1597,14 @@ actor CrossHandleOwnershipTester(t: testing.EnvT): def on_timeout(): if done: return - finish_error("timeout waiting for cross-handle ownership test") + finish_error( + "timeout waiting for cross-handle ownership test " + "(client_close_count=" + str(client_close_count) + + ", client_channel_close_count=" + str(client_channel_close_count) + + ", server_channel_close_count=" + str(server_channel_close_count) + + ", session_close_count=" + str(session_close_count) + + ", server_closed=" + str(server_closed) + ")" + ) def on_server_listen(s: ssh.Server, err: ?str): server = s @@ -1643,6 +1662,7 @@ actor CrossHandleOwnershipTester(t: testing.EnvT): def on_session_close(sess: ssh.ServerSession, reason: str): log.info("session closed: " + reason, None) session_close_count += 1 + maybe_request_server_shutdown() maybe_finish() def srv1_on_data(ch: ssh.ServerChannel, data: ?bytes): @@ -1777,6 +1797,7 @@ actor CrossHandleOwnershipTester(t: testing.EnvT): def ch1_close(ch: ssh.Channel, reason: str): client_channel_close_count += 1 + maybe_request_client_closes() maybe_finish() def ch2_open(ch: ssh.Channel, err: ?str): @@ -1823,6 +1844,7 @@ actor CrossHandleOwnershipTester(t: testing.EnvT): def ch2_close(ch: ssh.Channel, reason: str): client_channel_close_count += 1 + maybe_request_client_closes() maybe_finish() def start_client1(): @@ -1853,7 +1875,7 @@ actor CrossHandleOwnershipTester(t: testing.EnvT): def start_server(): attempts += 1 - port = 50000 + (random.randint(0, 2000) + attempts * 131) % 2000 + port = 60000 + (random.randint(0, 4000) + attempts * 131) % 4000 server_close_requested = False client1_close_requested = False client2_close_requested = False @@ -1872,7 +1894,7 @@ actor CrossHandleOwnershipTester(t: testing.EnvT): ) after 0: start_server() - after 5.0: on_timeout() + after 10.0: on_timeout() def _test_ssh_cross_handle_ownership(t: testing.EnvT): @@ -1928,7 +1950,11 @@ actor UnsupportedRequestTester(t: testing.EnvT): def on_timeout(): if done: return - finish_error("timeout waiting for unsupported request test") + finish_error( + "timeout waiting for unsupported request test " + "(saw_subsystem=" + str(saw_subsystem) + + ", saw_exec=" + str(saw_exec) + ")" + ) def on_server_listen(s: ssh.Server, err: ?str): server = s @@ -1998,7 +2024,6 @@ actor UnsupportedRequestTester(t: testing.EnvT): finish_error("subsystem channel open error: " + err) return ch.request_subsystem("unsupported-subsystem") - ch.close() def ch_sub_out(ch: ssh.Channel, data: ?bytes): return @@ -2022,7 +2047,6 @@ actor UnsupportedRequestTester(t: testing.EnvT): finish_error("exec channel open error: " + err) return ch.request_exec("echo denied") - ch.close() def ch_exec_out(ch: ssh.Channel, data: ?bytes): return @@ -2051,7 +2075,7 @@ actor UnsupportedRequestTester(t: testing.EnvT): def start_server(): attempts += 1 - port = 48000 + (random.randint(0, 2000) + attempts * 97) % 2000 + port = 64000 + (random.randint(0, 1000) + attempts * 97) % 1000 server_close_requested = False server = ssh.Server( net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), @@ -2068,9 +2092,420 @@ actor UnsupportedRequestTester(t: testing.EnvT): ) after 0: start_server() - after 2.0: on_timeout() + after 10.0: on_timeout() def _test_ssh_unsupported_requests(t: testing.EnvT): """Server reject_request should handle unsupported subsystem/exec.""" UnsupportedRequestTester(t) + + +actor RunCommandTimeoutTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + PARTIAL_OUT = b"partial-out" + PARTIAL_ERR = b"partial-err" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var run_exit_called = False + var client_closed = False + var server_closed = False + var session_closed = False + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if run_exit_called and client_closed and server_closed and session_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for RunCommand timeout test " + "(run_exit=" + str(run_exit_called) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + if cmd != "stall": + finish_error("unexpected exec command: " + cmd) + return + ch.accept_request() + ch.write(PARTIAL_OUT) + ch.write_stderr(PARTIAL_ERR) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + ssh.RunCommand(c, "stall", on_run_exit, timeout=0.1) + + def on_run_exit(ch: ssh.Channel, + code: int, + sig: ?str, + stdout: bytes, + stderr: bytes, + err: ?str): + if run_exit_called: + finish_error("RunCommand on_exit called more than once") + return + run_exit_called = True + if err != "timeout": + finish_error("expected timeout error, got " + str(err)) + return + if len(stdout) > len(PARTIAL_OUT) or PARTIAL_OUT.find(stdout) != 0: + finish_error("unexpected stdout on timeout") + return + if len(stderr) > len(PARTIAL_ERR) or PARTIAL_ERR.find(stderr) != 0: + finish_error("unexpected stderr on timeout") + return + if code != 0: + finish_error("timeout should not invent exit code") + return + if sig is not None: + finish_error("timeout should not invent exit signal") + return + request_client_close() + maybe_finish() + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 52000 + (random.randint(0, 2000) + attempts * 149) % 2000 + server_close_requested = False + client_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 5.0: on_timeout() + + +def _test_ssh_runcommand_timeout(t: testing.EnvT): + """RunCommand timeout should complete exactly once with buffered output.""" + RunCommandTimeoutTester(t) + + +actor RunCommandExitStatusTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + OUT1 = b"alpha-" + OUT2 = b"beta" + ERR1 = b"warn-" + ERR2 = b"detail" + OUT = OUT1 + OUT2 + ERR = ERR1 + ERR2 + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var run_exit_called = False + var client_closed = False + var server_closed = False + var session_closed = False + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if run_exit_called and client_closed and server_closed and session_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for RunCommand exit status test " + "(run_exit=" + str(run_exit_called) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + if cmd != "report": + finish_error("unexpected exec command: " + cmd) + return + ch.accept_request() + ch.write(OUT1) + ch.write(OUT2) + ch.write_stderr(ERR1) + ch.write_stderr(ERR2) + ch.send_exit_status(23) + ch.send_eof() + ch.close() + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + ssh.RunCommand(c, "report", on_run_exit) + + def on_run_exit(ch: ssh.Channel, + code: int, + sig: ?str, + stdout: bytes, + stderr: bytes, + err: ?str): + if run_exit_called: + finish_error("RunCommand on_exit called more than once") + return + run_exit_called = True + if err is not None: + finish_error("unexpected RunCommand error: " + str(err)) + return + if stdout != OUT: + finish_error("unexpected stdout payload") + return + if stderr != ERR: + finish_error("unexpected stderr payload") + return + if code != 23: + finish_error("unexpected exit code: " + str(code)) + return + if sig is not None: + finish_error("unexpected exit signal: " + str(sig)) + return + request_client_close() + maybe_finish() + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 58000 + (random.randint(0, 2000) + attempts * 151) % 2000 + server_close_requested = False + client_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 5.0: on_timeout() + + +def _test_ssh_runcommand_exit_status(t: testing.EnvT): + """RunCommand should collect stdout, stderr, and exit status.""" + RunCommandExitStatusTester(t) From 9ddecf1253a8d0e52167d0c938ca3171556c4049 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 11:44:54 +0100 Subject: [PATCH 19/38] Bound pre-attach lifetime Accepted server sessions were started before any ServerSession actor attached to them. That let sockets sit in a pre-attach limbo with no bounded lifetime, and a late attach could still publish a session that had already closed underneath it. This defers per-session polling until attach, starts a bounded pre-attach timeout for newly accepted sessions, and skips publishing a ServerSession when native attach finds the session already closed. The session is now either attached and driven through the normal close path, or failed before it becomes visible to user code. That removes a dead-session publication race and prevents pre-attach sessions from lingering indefinitely. --- src/ssh.act | 5 ++-- src/ssh.ext.c | 78 ++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/src/ssh.act b/src/ssh.act index 5f9ce82..a916206 100644 --- a/src/ssh.act +++ b/src/ssh.act @@ -372,8 +372,9 @@ actor ServerSession(server: Server, _debug("server session attach ready: " + str(session_id)) _attach(session_id) _debug("server session attach: _attach returned for " + str(session_id)) - server.on_session_ready(self) - _debug("server session attach: on_session_ready sent for " + str(session_id)) + if _session_id != 0: + server.on_session_ready(self) + _debug("server session attach: on_session_ready sent for " + str(session_id)) after 0: _attach_ready() diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 0cddba9..617ff6d 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -84,6 +84,7 @@ extern struct $Cont $Done$instance; #define SSH_READ_BUFSIZE 4096 #define SSH_IO_PUMP_LIMIT 128 +#define SSH_ATTACH_TIMEOUT_SEC 5.0 static int ssh_debug_enabled = 0; static int ssh_libssh_log_level = SSH_LOG_NOLOG; @@ -319,6 +320,8 @@ static void session_close_internal(ssh_server_session_ctx *s, const char *reason static void session_finalize(ssh_server_session_ctx *s); static void session_finish_close(ssh_server_session_ctx *s); static void session_fail(ssh_server_session_ctx *s, const char *msg); +static void session_start_attach_timer(ssh_server_session_ctx *s); +static int session_start_poll(ssh_server_session_ctx *s, char *errmsg, size_t errmsg_len); static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch); static int session_needs_write(ssh_server_session_ctx *s); static void session_poll_cb(uv_poll_t *handle, int status, int events); @@ -2421,6 +2424,16 @@ static int session_check_reply_rc(ssh_server_session_ctx *s, int rc, const char return -1; } +static void session_start_attach_timer(ssh_server_session_ctx *s) { + if (s == NULL || s->auth_timer != NULL) + return; + s->auth_timer = acton_calloc(1, sizeof(uv_timer_t)); + s->auth_timer->data = s; + uv_timer_init(get_uv_loop(), s->auth_timer); + uv_timer_start(s->auth_timer, session_auth_timeout_cb, + (uint64_t)(SSH_ATTACH_TIMEOUT_SEC * 1000.0), 0); +} + static void session_start_auth_timer(ssh_server_session_ctx *s) { if (s == NULL || s->auth_timeout <= 0.0 || s->auth_timer != NULL) return; @@ -2452,11 +2465,41 @@ static void session_auth_timeout_cb(uv_timer_t *timer) { ssh_server_session_ctx *s = (ssh_server_session_ctx *)timer->data; if (s == NULL) return; + if (!s->attached && s->state == SESSION_STATE_KEYEX) { + session_fail(s, "SSH session attach timeout"); + return; + } if (s->state == SESSION_STATE_AUTH) { session_fail(s, "SSH authentication timeout"); } } +static int session_start_poll(ssh_server_session_ctx *s, char *errmsg, size_t errmsg_len) { + if (s == NULL || s->session == NULL || s->fd < 0) { + snprintf(errmsg, errmsg_len, "Failed to start SSH session poll"); + return -1; + } + if (s->poll != NULL) + return 0; + + s->poll = acton_calloc(1, sizeof(uv_poll_t)); + s->poll->data = s; + int uv_rc = uv_poll_init(get_uv_loop(), s->poll, s->fd); + if (uv_rc != 0) { + uv_strerror_r(uv_rc, errmsg, errmsg_len); + acton_free(s->poll); + s->poll = NULL; + return -1; + } + s->poll_events = UV_READABLE | UV_WRITABLE; + uv_rc = uv_poll_start(s->poll, s->poll_events, session_poll_cb); + if (uv_rc != 0) { + uv_strerror_r(uv_rc, errmsg, errmsg_len); + return -1; + } + return 0; +} + static void session_keepalive_cb(uv_timer_t *timer) { ssh_server_session_ctx *s = (ssh_server_session_ctx *)timer->data; if (s == NULL || s->session == NULL) @@ -2888,30 +2931,13 @@ static void server_accept(ssh_server_ctx *s) { sess->state = SESSION_STATE_KEYEX; sess->fd = ssh_get_fd(session); sess->owner_wt = s->actor ? (int)s->actor->$affinity : 0; + sess->auth_timeout = s->actor ? fromB_float(((sshQ_Server)s->actor)->_auth_timeout) : 0.0; if (sess->fd < 0) { ssh_free(session); server_fail(s, "Failed to get SSH session fd"); return; } - sess->poll = acton_calloc(1, sizeof(uv_poll_t)); - sess->poll->data = sess; - int uv_rc = uv_poll_init(get_uv_loop(), sess->poll, sess->fd); - if (uv_rc != 0) { - char errmsg[256] = {0}; - uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); - ssh_free(session); - server_fail(s, errmsg); - return; - } - sess->poll_events = UV_READABLE | UV_WRITABLE; - uv_rc = uv_poll_start(sess->poll, sess->poll_events, session_poll_cb); - if (uv_rc != 0) { - char errmsg[256] = {0}; - uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); - ssh_free(session); - server_fail(s, errmsg); - return; - } + session_start_attach_timer(sess); sess->next = s->sessions; s->sessions = sess; @@ -3290,9 +3316,23 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o ssh_server_session_ctx *s = (ssh_server_session_ctx *)(unsigned long)fromB_u64(session_id); if (s == NULL) return $R_CONT(c$cont, B_None); + if (s->session == NULL || s->state == SESSION_STATE_CLOSED || + s->state == SESSION_STATE_CLOSING || s->state == SESSION_STATE_ERROR) { + if (ssh_debug_enabled) { + ssh_debug_log("server session attach skipped closed id=%llu", + (unsigned long long)fromB_u64(session_id)); + } + return $R_CONT(c$cont, B_None); + } if (ssh_debug_enabled) { ssh_debug_log("server session attach: id=%llu", (unsigned long long)fromB_u64(session_id)); } + char errmsg[256] = {0}; + if (session_start_poll(s, errmsg, sizeof(errmsg)) != 0) { + session_close_internal(s, errmsg, 1); + return $R_CONT(c$cont, B_None); + } + stop_timer(&s->auth_timer, session_timer_close_cb); s->actor = self; self->_session_id = session_id; s->attached = 1; From d91b9fd22b066304bc1b20aa748210b1c4779f3b Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 11:59:49 +0100 Subject: [PATCH 20/38] Separate attach and auth timers The pre-attach timeout reused the session auth timer slot. Attach stopped that timer, but uv_close is asynchronous, so the pointer stayed non-null long enough for key exchange to reach AUTH with no real auth timeout armed for the session. This gives the pre-attach guard its own timer handle, derives its budget from server auth policy when available, and adds a regression test that stalls hostkey acceptance until the server must time out authentication after attach. Attach timeout and auth timeout now have independent lifetimes, so stopping one can no longer mask the other. Sessions either attach within the configured budget or continue into AUTH with a real auth timeout still enforced. --- src/ssh.ext.c | 19 +++-- src/test_ssh_server.act | 151 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 7 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 617ff6d..7050ea8 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -259,6 +259,7 @@ typedef struct ssh_server_session_ctx { ssh_session session; uv_poll_t *poll; int poll_events; + uv_timer_t *attach_timer; uv_timer_t *auth_timer; uv_timer_t *keepalive_timer; double auth_timeout; @@ -623,6 +624,8 @@ static void session_timer_close_cb(uv_handle_t *handle) { ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data; if (s == NULL) return; + if ((uv_timer_t *)handle == s->attach_timer) + s->attach_timer = NULL; if ((uv_timer_t *)handle == s->auth_timer) s->auth_timer = NULL; if ((uv_timer_t *)handle == s->keepalive_timer) @@ -2425,13 +2428,14 @@ static int session_check_reply_rc(ssh_server_session_ctx *s, int rc, const char } static void session_start_attach_timer(ssh_server_session_ctx *s) { - if (s == NULL || s->auth_timer != NULL) + if (s == NULL || s->attach_timer != NULL) return; - s->auth_timer = acton_calloc(1, sizeof(uv_timer_t)); - s->auth_timer->data = s; - uv_timer_init(get_uv_loop(), s->auth_timer); - uv_timer_start(s->auth_timer, session_auth_timeout_cb, - (uint64_t)(SSH_ATTACH_TIMEOUT_SEC * 1000.0), 0); + double timeout = s->auth_timeout > 0.0 ? s->auth_timeout : SSH_ATTACH_TIMEOUT_SEC; + s->attach_timer = acton_calloc(1, sizeof(uv_timer_t)); + s->attach_timer->data = s; + uv_timer_init(get_uv_loop(), s->attach_timer); + uv_timer_start(s->attach_timer, session_auth_timeout_cb, + (uint64_t)(timeout * 1000.0), 0); } static void session_start_auth_timer(ssh_server_session_ctx *s) { @@ -3133,6 +3137,7 @@ static void session_close_internal(ssh_server_session_ctx *s, const char *reason } stop_timer(&s->auth_timer, session_timer_close_cb); + stop_timer(&s->attach_timer, session_timer_close_cb); stop_timer(&s->keepalive_timer, session_timer_close_cb); s->state = SESSION_STATE_CLOSING; @@ -3332,7 +3337,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o session_close_internal(s, errmsg, 1); return $R_CONT(c$cont, B_None); } - stop_timer(&s->auth_timer, session_timer_close_cb); + stop_timer(&s->attach_timer, session_timer_close_cb); s->actor = self; self->_session_id = session_id; s->attached = 1; diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 9f1b705..d51cdd5 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -2509,3 +2509,154 @@ actor RunCommandExitStatusTester(t: testing.EnvT): def _test_ssh_runcommand_exit_status(t: testing.EnvT): """RunCommand should collect stdout, stderr, and exit status.""" RunCommandExitStatusTester(t) + + +actor AuthTimeoutTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var client_done = False + var server_closed = False + var session_closed = False + var client_close_requested = False + var server_close_requested = False + + def maybe_finish(): + if done: + return + if client_done and server_closed and session_closed: + done = True + t.success() + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + request_client_close() + request_server_close() + done = True + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error("timeout waiting for auth timeout test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + finish_error("unexpected auth callback") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + if reason != "SSH authentication timeout": + finish_error("unexpected session close reason: " + reason) + return + session_closed = True + request_server_close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + log.info("hostkey state: " + state, None) + return + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + if err.find("Received SSH_MSG_DISCONNECT") >= 0: + client_done = True + maybe_finish() + return + finish_error("client connect error: " + err) + return + client = c + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_done = True + if session_closed: + request_server_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 56000 + (random.randint(0, 2000) + attempts * 157) % 2000 + server_close_requested = False + client_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + auth_timeout=0.2, + ) + + after 0: start_server() + after 5.0: on_timeout() + + +def _test_ssh_auth_timeout(t: testing.EnvT): + """Server auth timeout should still fire after attach.""" + AuthTimeoutTester(t) From 37f1bf51ea8864ac6bfb2d287fb8d501dfc0f67e Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 12:09:30 +0100 Subject: [PATCH 21/38] Harden poll close handling Long stress runs drove RSS into multi-gigabyte territory while file descriptor counts stayed flat. The poll and timer handles were heap-allocated for every client and session, but the libuv close callbacks only cleared pointers and never released the handle memory. Disconnect-only readiness could also be dropped before reaching libssh because hangups were filtered behind readable-byte checks. This frees poll and timer handles from their close callbacks, treats POLLHUP and POLLERR as read-side work, and forwards UV_DISCONNECT into the libssh poll mapping instead of discarding it. Handle lifetime now matches the allocation path, so repeated open and close churn does not retain closed libuv objects. Peer hangups also reach libssh promptly even when no additional bytes are readable, which makes close detection less timing-sensitive under load. --- src/ssh.ext.c | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 7050ea8..78d3eb3 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -517,6 +517,8 @@ static int fd_has_data(int fd) { return 0; if (pfd.revents & POLLIN) return 1; + if (pfd.revents & (POLLHUP | POLLERR)) + return 1; if (pfd.revents & POLLNVAL) return 0; return 0; @@ -573,8 +575,10 @@ static int session_has_pending_write(ssh_session session) { static void client_poll_close_cb(uv_handle_t *handle) { ssh_client_ctx *c = (ssh_client_ctx *)handle->data; - if (c == NULL) + if (c == NULL) { + acton_free(handle); return; + } if (c->poll == (uv_poll_t *)handle) { c->poll = NULL; c->poll_events = 0; @@ -582,35 +586,44 @@ static void client_poll_close_cb(uv_handle_t *handle) { if (c->state == CLIENT_STATE_CLOSING) { client_finalize(c); } + acton_free(handle); } static void client_timer_close_cb(uv_handle_t *handle) { ssh_client_ctx *c = (ssh_client_ctx *)handle->data; - if (c == NULL) + if (c == NULL) { + acton_free(handle); return; + } if ((uv_timer_t *)handle == c->connect_timer) c->connect_timer = NULL; if ((uv_timer_t *)handle == c->auth_timer) c->auth_timer = NULL; if ((uv_timer_t *)handle == c->keepalive_timer) c->keepalive_timer = NULL; + acton_free(handle); } static void server_poll_close_cb(uv_handle_t *handle) { ssh_server_ctx *s = (ssh_server_ctx *)handle->data; - if (s == NULL) + if (s == NULL) { + acton_free(handle); return; + } if (s->poll == (uv_poll_t *)handle) s->poll = NULL; if (s->state == SERVER_STATE_CLOSING) { server_finalize(s); } + acton_free(handle); } static void session_poll_close_cb(uv_handle_t *handle) { ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data; - if (s == NULL) + if (s == NULL) { + acton_free(handle); return; + } if (s->poll == (uv_poll_t *)handle) { s->poll = NULL; s->poll_events = 0; @@ -618,18 +631,22 @@ static void session_poll_close_cb(uv_handle_t *handle) { if (s->state == SESSION_STATE_CLOSING) { session_finalize(s); } + acton_free(handle); } static void session_timer_close_cb(uv_handle_t *handle) { ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data; - if (s == NULL) + if (s == NULL) { + acton_free(handle); return; + } if ((uv_timer_t *)handle == s->attach_timer) s->attach_timer = NULL; if ((uv_timer_t *)handle == s->auth_timer) s->auth_timer = NULL; if ((uv_timer_t *)handle == s->keepalive_timer) s->keepalive_timer = NULL; + acton_free(handle); } static void client_notify_connect(ssh_client_ctx *c, const char *err) { @@ -1290,6 +1307,12 @@ static void poll_cb(uv_poll_t *handle, int status, int events) { ssh_set_fd_toread(c->session); libssh_events |= UV_READABLE; } +#ifdef UV_DISCONNECT + if (events & UV_DISCONNECT) { + ssh_set_fd_toread(c->session); + libssh_events |= UV_DISCONNECT; + } +#endif if ((events & UV_WRITABLE) && fd_can_write(c->fd)) { c->write_ready = 1; ssh_set_fd_towrite(c->session); @@ -2849,6 +2872,12 @@ static void session_poll_cb(uv_poll_t *handle, int status, int events) { ssh_set_fd_toread(s->session); libssh_events |= UV_READABLE; } +#ifdef UV_DISCONNECT + if (events & UV_DISCONNECT) { + ssh_set_fd_toread(s->session); + libssh_events |= UV_DISCONNECT; + } +#endif if ((events & UV_WRITABLE) && fd_can_write(s->fd)) { s->write_ready = 1; ssh_set_fd_towrite(s->session); From 747126fd55ffc32340f870f46521925141ba6356 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 12:11:28 +0100 Subject: [PATCH 22/38] Invalidate channels before close callbacks Channel close callbacks were delivered after the native ssh_channel had already been freed, but before the actor handle was invalidated. That left a re-entrancy window where on_close or EOF handlers could call back into write or close on a channel whose native state was already gone. This marks client and server channels closed and clears their exported actor ids before delivering exit, EOF, and close notifications from the finalize path. Callbacks still observe the expected terminal notifications, but any re-entrant channel operation now resolves to a dead handle instead of a half-finalized native channel. That removes a latent teardown footgun without changing normal close ordering. --- src/ssh.ext.c | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 78d3eb3..d7e80f4 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -834,6 +834,9 @@ static void channel_notify_eof(ssh_channel_ctx *ch) { } static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { + int exit_status = -1; + B_str exit_signal = B_None; + while (ch->write_head != NULL) { write_chunk_t *chunk = ch->write_head; ch->write_head = chunk->next; @@ -841,8 +844,6 @@ static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { } ch->write_tail = NULL; if (ch->channel != NULL) { - int exit_status = -1; - B_str exit_signal = B_None; if (ssh_channel_is_closed(ch->channel)) { uint32_t exit_code = 0; char *signal = NULL; @@ -865,9 +866,12 @@ static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { } ssh_channel_free(ch->channel); ch->channel = NULL; - } else { - channel_notify_exit(ch, -1, B_None); } + ch->state = CHAN_STATE_CLOSED; + if (ch->actor) + ch->actor->_channel_id = toB_u64(0); + + channel_notify_exit(ch, exit_status, exit_signal); if (!ch->stdout_eof && ch->on_stdout) { $action2 f = ($action2)ch->on_stdout; f->$class->__asyn__(f, ch->actor, B_None); @@ -879,9 +883,6 @@ static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { ch->stderr_eof = 1; } channel_notify_close(ch, "closed"); - if (ch->actor) - ch->actor->_channel_id = toB_u64(0); - ch->state = CHAN_STATE_CLOSED; (void)c; } @@ -2224,6 +2225,9 @@ static void server_channel_finalize(ssh_server_channel_ctx *ch) { ssh_channel_free(ch->channel); ch->channel = NULL; } + ch->state = SCHAN_STATE_CLOSED; + if (ch->actor) + ch->actor->_channel_id = toB_u64(0); if (!ch->stdout_eof && ch->on_data) { $action2 f = ($action2)ch->on_data; f->$class->__asyn__(f, ch->actor, B_None); @@ -2235,9 +2239,6 @@ static void server_channel_finalize(ssh_server_channel_ctx *ch) { ch->stderr_eof = 1; } server_channel_notify_close(ch, "closed"); - if (ch->actor) - ch->actor->_channel_id = toB_u64(0); - ch->state = SCHAN_STATE_CLOSED; } static void server_channel_queue_write(ssh_server_channel_ctx *ch, B_bytes data, int is_stderr) { From 61d9da08507947ad1a90464588d42a48b498abd3 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 15:26:32 +0100 Subject: [PATCH 23/38] Tighten session edge handling A channel could still accept writes after its parent client or server session had left the ready state, which left a gap where late payloads could be queued during shutdown. Attached server sessions could also sit in key exchange without a real timeout after the attach timer stopped, and disconnect-only poll edges were not subscribed consistently. This change makes channel validation require a ready parent session, starts the normal auth timeout as soon as an attached session begins running, and treats attached key exchange timeout separately from the pre-attach timeout. It also subscribes session polls to disconnect notifications and makes the readiness probe treat EOF and reset as terminal readable state. These checks align the binding with the actual session state machine. Late writes are rejected once shutdown starts, attached sessions cannot stall indefinitely before authentication, and close paths see terminal socket state promptly instead of depending on another read event. --- src/ssh.ext.c | 40 ++- src/test_ssh_server.act | 654 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 682 insertions(+), 12 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index d7e80f4..902c40e 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -505,23 +505,22 @@ static void stop_timer(uv_timer_t **timer, uv_close_cb close_cb) { static int fd_has_data(int fd) { if (fd < 0) return 0; - struct pollfd pfd; - pfd.fd = fd; - pfd.events = POLLIN; - pfd.revents = 0; - int rc; + char byte; + ssize_t rc; do { - rc = poll(&pfd, 1, 0); + rc = recv(fd, &byte, 1, MSG_PEEK | MSG_DONTWAIT); } while (rc < 0 && errno == EINTR); - if (rc <= 0) - return 0; - if (pfd.revents & POLLIN) + if (rc > 0) return 1; - if (pfd.revents & (POLLHUP | POLLERR)) + if (rc == 0) return 1; - if (pfd.revents & POLLNVAL) + if (errno == EAGAIN || errno == EWOULDBLOCK) return 0; - return 0; + if (errno == ECONNRESET || errno == ECONNABORTED || errno == ENOTCONN) + return 1; + if (errno == EBADF || errno == ENOTSOCK || errno == EINVAL) + return 0; + return 1; } static int fd_can_write(int fd) { @@ -1349,6 +1348,9 @@ static void client_update_poll(ssh_client_ctx *c) { int flags = ssh_get_poll_flags(c->session); int pending = flags | status; int events = UV_READABLE; +#ifdef UV_DISCONNECT + events |= UV_DISCONNECT; +#endif if (pending & SSH_WRITE_PENDING) events |= UV_WRITABLE; if ((events & UV_WRITABLE) == 0 && client_needs_write(c)) @@ -1894,6 +1896,9 @@ static int channel_validate(ssh_client_ctx *c, ssh_channel_ctx *ch) { if (c == NULL || ch == NULL) { return 1; } + if (c->state != CLIENT_STATE_READY) { + return 1; + } if (ch->client != c) { if (ssh_debug_enabled) { ssh_debug_log("client channel ownership mismatch: client=%p owner=%p ch=%p", @@ -1911,6 +1916,9 @@ static int server_channel_validate(ssh_server_session_ctx *s, ssh_server_channel if (s == NULL || ch == NULL) { return -1; } + if (s->state != SESSION_STATE_READY) { + return -1; + } if (ch->session != s) { if (ssh_debug_enabled) { ssh_debug_log("server channel ownership mismatch: session=%p owner=%p ch=%p", @@ -2497,6 +2505,10 @@ static void session_auth_timeout_cb(uv_timer_t *timer) { session_fail(s, "SSH session attach timeout"); return; } + if (s->attached && s->state == SESSION_STATE_KEYEX) { + session_fail(s, "SSH key exchange timeout"); + return; + } if (s->state == SESSION_STATE_AUTH) { session_fail(s, "SSH authentication timeout"); } @@ -2567,6 +2579,9 @@ static void session_update_poll(ssh_server_session_ctx *s) { int flags = ssh_get_poll_flags(s->session); int pending = flags | status; int events = UV_READABLE; +#ifdef UV_DISCONNECT + events |= UV_DISCONNECT; +#endif if (pending & SSH_WRITE_PENDING) events |= UV_WRITABLE; if ((events & UV_WRITABLE) == 0 && session_needs_write(s)) @@ -3379,6 +3394,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o s->auth_timeout = fromB_float(self->server->_auth_timeout); s->keepalive_interval = fromB_float(self->server->_keepalive_interval); s->keepalive_enabled = fromB_bool(self->server->_keepalive_enabled) ? 1 : 0; + session_start_auth_timer(s); if (ssh_debug_enabled) { ssh_debug_log("server session attach: callbacks set, driving session"); } diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index d51cdd5..43418a3 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -2660,3 +2660,657 @@ actor AuthTimeoutTester(t: testing.EnvT): def _test_ssh_auth_timeout(t: testing.EnvT): """Server auth timeout should still fire after attach.""" AuthTimeoutTester(t) + + +actor CloseCallbackReentryTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + REPLY = b"close-ok" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var session: ?ssh.ServerSession = None + var client_channel: ?ssh.Channel = None + var saw_reply = False + var client_exit_count = 0 + var client_channel_close_count = 0 + var client_close_count = 0 + var server_channel_close_count = 0 + var session_close_count = 0 + var server_close_count = 0 + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if saw_reply and client_exit_count == 1 and client_channel_close_count == 1 and \ + client_close_count == 1 and server_channel_close_count == 1 and \ + session_close_count == 1 and server_close_count == 1: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for close callback reentry test " + "(reply=" + str(saw_reply) + + ", client_exit_count=" + str(client_exit_count) + + ", client_channel_close_count=" + str(client_channel_close_count) + + ", client_close_count=" + str(client_close_count) + + ", server_channel_close_count=" + str(server_channel_close_count) + + ", session_close_count=" + str(session_close_count) + + ", server_close_count=" + str(server_close_count) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_close_count += 1 + if server_close_count > 1: + finish_error("server close callback called more than once") + return + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + session = sess + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + if cmd != "reentry": + finish_error("unexpected exec command: " + cmd) + return + session = sess + ch.accept_request() + ch.write(REPLY) + ch.send_exit_status(7) + ch.send_eof() + ch.close() + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_close_count += 1 + if session_close_count > 1: + finish_error("session close callback called more than once") + return + sess.close() + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server received unexpected data") + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server received unexpected stderr") + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + log.info("server channel closed: " + reason, None) + server_channel_close_count += 1 + if server_channel_close_count > 1: + finish_error("server channel close callback called more than once") + return + ch.write(b"after-server-close") + ch.send_eof() + ch.close() + if session is not None: + session.close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_close_count += 1 + if client_close_count > 1: + finish_error("client close callback called more than once") + return + c.close() + maybe_finish() + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + client_channel = ch + ch.request_exec("reentry") + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is None: + return + if saw_reply: + finish_error("client received duplicate reply") + return + if data != REPLY: + finish_error("unexpected reply payload: " + str(data)) + return + saw_reply = True + + def ch_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("client received unexpected stderr") + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + client_exit_count += 1 + if client_exit_count > 1: + finish_error("client exit callback called more than once") + return + if code != 7: + finish_error("unexpected exit code: " + str(code)) + return + if sig is not None: + finish_error("unexpected exit signal: " + str(sig)) + return + maybe_finish() + + def ch_close(ch: ssh.Channel, reason: str): + log.info("client channel closed: " + reason, None) + client_channel_close_count += 1 + if client_channel_close_count > 1: + finish_error("client channel close callback called more than once") + return + ch.write(b"after-client-close") + ch.send_eof() + ch.close() + request_client_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 54000 + (random.randint(0, 3000) + attempts * 163) % 3000 + server_close_requested = False + client_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 5.0: on_timeout() + + +def _test_ssh_close_callback_reentry(t: testing.EnvT): + """Close callbacks should tolerate reentrant close calls.""" + CloseCallbackReentryTester(t) + + +actor KeyExchangeTimeoutTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var raw_client: ?net.TCPConnection = None + var raw_connected = False + var raw_bytes_in = 0 + var session_ready = False + var session_closed = False + var server_closed = False + var raw_close_requested = False + var server_close_requested = False + + def request_raw_close(): + if raw_close_requested: + return + raw_close_requested = True + if raw_client is not None: + raw_client.close(on_raw_close) + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if raw_connected and session_ready and session_closed and server_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_raw_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for key exchange timeout test " + "(raw_connected=" + str(raw_connected) + + ", session_ready=" + str(session_ready) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + + ", raw_bytes_in=" + str(raw_bytes_in) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_raw_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session_ready = True + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + finish_error("unexpected auth request on raw TCP idle client") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open on raw TCP idle client") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request on raw TCP idle client") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request on raw TCP idle client") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + if reason != "SSH key exchange timeout": + finish_error("unexpected session close reason: " + reason) + return + session_closed = True + request_raw_close() + request_server_close() + maybe_finish() + + def on_raw_connect(c: net.TCPConnection): + raw_client = c + raw_connected = True + + def on_raw_receive(c: net.TCPConnection, payload: bytes): + raw_bytes_in += len(payload) + + def on_raw_error(c: net.TCPConnection, errmsg: str): + finish_error("unexpected raw client error: " + errmsg) + + def on_raw_remote_close(c: net.TCPConnection): + return + + def on_raw_close(c: net.TCPConnection): + return + + def start_raw_client(): + raw_close_requested = False + raw_connected = False + raw_bytes_in = 0 + raw_client = net.TCPConnection( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + port, + on_raw_connect, + on_raw_receive, + on_raw_error, + on_raw_remote_close, + ) + + def start_server(): + attempts += 1 + port = 53000 + (random.randint(0, 3000) + attempts * 179) % 3000 + server_close_requested = False + server_closed = False + session_ready = False + session_closed = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + auth_timeout=0.2, + ) + + after 0: start_server() + after 5.0: on_timeout() + + +def _test_ssh_key_exchange_timeout(t: testing.EnvT): + """Attached sessions should time out during key exchange.""" + KeyExchangeTimeoutTester(t) + + +actor CloseRejectsLateWritesTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + START = b"start" + ACK = b"ack" + LATE = b"late-data" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var session: ?ssh.ServerSession = None + var client_channel: ?ssh.Channel = None + var server_buf = b"" + var ack_received = False + var shutdown_started = False + var client_closed = False + var server_channel_closed = False + var session_closed = False + var server_closed = False + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if ack_received and client_closed and server_channel_closed and \ + session_closed and server_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def late_write_burst(remaining: int): + if done or remaining <= 0: + return + if client_channel is not None: + client_channel.write(LATE) + after 0: late_write_burst(remaining - 1) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for late-write close test " + "(ack_received=" + str(ack_received) + + ", client_closed=" + str(client_closed) + + ", server_channel_closed=" + str(server_channel_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + session = sess + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "echo": + finish_error("unexpected subsystem request: " + name) + return + ch.accept_request() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server_buf += data + if server_buf.find(LATE) >= 0: + finish_error("late write reached server after close started") + return + if not ack_received and server_buf.find(START) >= 0: + ch.write(ACK) + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server received unexpected stderr") + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + log.info("server channel closed: " + reason, None) + server_channel_closed = True + if session is not None: + session.close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + client_channel = ch + ch.request_subsystem("echo") + ch.write(START) + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + if data != ACK: + finish_error("unexpected client payload: " + str(data)) + return + if ack_received: + finish_error("client received duplicate ACK") + return + ack_received = True + if not shutdown_started: + shutdown_started = True + request_client_close() + after 0: late_write_burst(32) + + def ch_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("client received unexpected stderr") + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + log.info("client channel closed: " + reason, None) + request_client_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 50000 + (random.randint(0, 3000) + attempts * 181) % 3000 + server_close_requested = False + client_close_requested = False + ack_received = False + shutdown_started = False + client_closed = False + server_channel_closed = False + session_closed = False + server_closed = False + server_buf = b"" + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 5.0: on_timeout() + + +def _test_ssh_close_rejects_late_writes(t: testing.EnvT): + """Channel writes should be ignored once parent close starts.""" + CloseRejectsLateWritesTester(t) From 949d851e8fb91a110815da8c25ed178a9d8a2313 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 15:48:40 +0100 Subject: [PATCH 24/38] Tighten native SSH teardown Native SSH contexts and channel structs were finalized logically but their wrapper allocations stayed live until process exit. That showed up as steady RSS growth under stress even with flat fd counts. The client also treated a channel close as a terminal request result even when an exec request was still pending, so a rejected request could race with close and either look successful or poison sibling channels. This change defers releasing client, server, and session contexts until their libuv handles are actually closed, frees channel structs when they leave the active lists, and drops stored close reasons during final release. It also keeps a remote close from resolving a pending client request until teardown knows whether the request was ever answered, and reports a closed channel as a request failure only if the request remains unresolved at finalization. This is correct because the native wrapper lifetime now matches the lifetime of the libuv handles that still reference it, and request completion is no longer inferred from close ordering alone. Rejected or aborted commands fail deterministically, while sibling channels and long stress runs no longer accumulate unreachable native state. --- src/ssh.ext.c | 58 ++++ src/test_ssh_server.act | 593 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 651 insertions(+) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 902c40e..42f07fe 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -306,6 +306,7 @@ static void client_update_poll(ssh_client_ctx *c); static void client_close_internal(ssh_client_ctx *c, const char *reason, int force_close); static void client_finalize(ssh_client_ctx *c); static void client_finish_close(ssh_client_ctx *c); +static void client_maybe_release(ssh_client_ctx *c); static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch); static int client_needs_write(ssh_client_ctx *c); static void client_pump_io(ssh_client_ctx *c); @@ -314,12 +315,14 @@ static void server_accept(ssh_server_ctx *s); static void server_close_internal(ssh_server_ctx *s, const char *reason); static void server_finalize(ssh_server_ctx *s); static void server_remove_session(ssh_server_ctx *s, ssh_server_session_ctx *sess); +static void server_maybe_release(ssh_server_ctx *s); static void session_drive(ssh_server_session_ctx *s); static void session_update_poll(ssh_server_session_ctx *s); static void session_pump_io(ssh_server_session_ctx *s); static void session_close_internal(ssh_server_session_ctx *s, const char *reason, int force_close); static void session_finalize(ssh_server_session_ctx *s); static void session_finish_close(ssh_server_session_ctx *s); +static void session_maybe_release(ssh_server_session_ctx *s); static void session_fail(ssh_server_session_ctx *s, const char *msg); static void session_start_attach_timer(ssh_server_session_ctx *s); static int session_start_poll(ssh_server_session_ctx *s, char *errmsg, size_t errmsg_len); @@ -600,6 +603,7 @@ static void client_timer_close_cb(uv_handle_t *handle) { c->auth_timer = NULL; if ((uv_timer_t *)handle == c->keepalive_timer) c->keepalive_timer = NULL; + client_maybe_release(c); acton_free(handle); } @@ -645,9 +649,48 @@ static void session_timer_close_cb(uv_handle_t *handle) { s->auth_timer = NULL; if ((uv_timer_t *)handle == s->keepalive_timer) s->keepalive_timer = NULL; + session_maybe_release(s); acton_free(handle); } +static void client_maybe_release(ssh_client_ctx *c) { + if (c == NULL || !c->close_finalized) + return; + if (c->poll != NULL || c->connect_timer != NULL || + c->auth_timer != NULL || c->keepalive_timer != NULL) + return; + if (c->close_reason != NULL) { + acton_free(c->close_reason); + c->close_reason = NULL; + } + acton_free(c); +} + +static void server_maybe_release(ssh_server_ctx *s) { + if (s == NULL || !s->close_finalized) + return; + if (s->poll != NULL || s->sessions != NULL) + return; + if (s->close_reason != NULL) { + acton_free(s->close_reason); + s->close_reason = NULL; + } + acton_free(s); +} + +static void session_maybe_release(ssh_server_session_ctx *s) { + if (s == NULL || !s->close_finalized) + return; + if (s->poll != NULL || s->attach_timer != NULL || + s->auth_timer != NULL || s->keepalive_timer != NULL) + return; + if (s->close_reason != NULL) { + acton_free(s->close_reason); + s->close_reason = NULL; + } + acton_free(s); +} + static void client_notify_connect(ssh_client_ctx *c, const char *err) { if (c == NULL) return; @@ -775,6 +818,9 @@ static void client_channel_close_cb(ssh_session session, ssh_channel channel, vo if (ch == NULL) return; ch->remote_close_seen = 1; + if (ch->pending_req != CHAN_REQ_NONE) { + return; + } channel_notify_close(ch, "closed"); } @@ -842,6 +888,10 @@ static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { acton_free(chunk); } ch->write_tail = NULL; + if (ch->pending_req != CHAN_REQ_NONE) { + ch->pending_req = CHAN_REQ_NONE; + channel_notify_error(ch, "SSH channel request failed: channel closed"); + } if (ch->channel != NULL) { if (ssh_channel_is_closed(ch->channel)) { uint32_t exit_code = 0; @@ -1140,6 +1190,7 @@ static void client_drive_channels(ssh_client_ctx *c) { c->channels = next; } ch->next = NULL; + acton_free(ch); } else { prev = ch; } @@ -1562,6 +1613,7 @@ static void client_finalize(ssh_client_ctx *c) { c->state = CLIENT_STATE_CLOSED; if (c->actor) c->actor->_client = toB_u64(0); + client_maybe_release(c); } static void client_abort_channels(ssh_client_ctx *c, int notify_channel_error) { @@ -1573,6 +1625,7 @@ static void client_abort_channels(ssh_client_ctx *c, int notify_channel_error) { channel_notify_eof(ch); channel_finalize(c, ch); ch->next = NULL; + acton_free(ch); ch = next; } c->channels = NULL; @@ -2435,6 +2488,7 @@ static void session_drive_channels(ssh_server_session_ctx *s) { s->channels = next; } ch->next = NULL; + acton_free(ch); } else { prev = ch; } @@ -3024,6 +3078,7 @@ static void server_remove_session(ssh_server_ctx *s, ssh_server_session_ctx *ses prev = cur; cur = cur->next; } + server_maybe_release(s); } static void server_finalize(ssh_server_ctx *s) { @@ -3044,6 +3099,7 @@ static void server_finalize(ssh_server_ctx *s) { s->state = SERVER_STATE_CLOSED; if (s->actor) s->actor->_server = toB_u64(0); + server_maybe_release(s); } static void session_finalize(ssh_server_session_ctx *s) { @@ -3062,6 +3118,7 @@ static void session_finalize(ssh_server_session_ctx *s) { s->state = SESSION_STATE_CLOSED; if (s->actor) s->actor->_session_id = toB_u64(0); + session_maybe_release(s); } static void session_reject_pending_messages(ssh_server_session_ctx *s) { @@ -3084,6 +3141,7 @@ static void session_abort_channels(ssh_server_session_ctx *s) { server_channel_notify_close(ch, "Session closed"); server_channel_finalize(ch); ch->next = NULL; + acton_free(ch); ch = next; } s->channels = NULL; diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 43418a3..23c8391 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -2511,6 +2511,599 @@ def _test_ssh_runcommand_exit_status(t: testing.EnvT): RunCommandExitStatusTester(t) +actor ConcurrentRunCommandTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + NUM_CMDS = 4 + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_channel_closes = 0 + var run_exits = 0 + var client_closed = False + var server_closed = False + var session_closed = False + var client_close_requested = False + var server_close_requested = False + var run0_done = False + var run1_done = False + var run2_done = False + var run3_done = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def out1(cmd: str) -> bytes: + if cmd == "report0": + return b"out0-a:" * 512 + if cmd == "report1": + return b"out1-a:" * 512 + if cmd == "report2": + return b"out2-a:" * 512 + return b"out3-a:" * 512 + + def out2(cmd: str) -> bytes: + if cmd == "report0": + return b"out0-b:" * 512 + if cmd == "report1": + return b"out1-b:" * 512 + if cmd == "report2": + return b"out2-b:" * 512 + return b"out3-b:" * 512 + + def err1(cmd: str) -> bytes: + if cmd == "report0": + return b"err0-a:" * 256 + if cmd == "report1": + return b"err1-a:" * 256 + if cmd == "report2": + return b"err2-a:" * 256 + return b"err3-a:" * 256 + + def err2(cmd: str) -> bytes: + if cmd == "report0": + return b"err0-b:" * 256 + if cmd == "report1": + return b"err1-b:" * 256 + if cmd == "report2": + return b"err2-b:" * 256 + return b"err3-b:" * 256 + + def full_out(cmd: str) -> bytes: + return out1(cmd) + out2(cmd) + + def full_err(cmd: str) -> bytes: + return err1(cmd) + err2(cmd) + + def exit_code(cmd: str) -> int: + if cmd == "report0": + return 30 + if cmd == "report1": + return 31 + if cmd == "report2": + return 32 + return 33 + + def maybe_finish(): + if done: + return + if run_exits == NUM_CMDS and server_channel_closes == NUM_CMDS and \ + client_closed and server_closed and session_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def mark_run_done(cmd: str): + if cmd == "report0": + if run0_done: + finish_error("RunCommand report0 completed more than once") + return + run0_done = True + return + if cmd == "report1": + if run1_done: + finish_error("RunCommand report1 completed more than once") + return + run1_done = True + return + if cmd == "report2": + if run2_done: + finish_error("RunCommand report2 completed more than once") + return + run2_done = True + return + if cmd == "report3": + if run3_done: + finish_error("RunCommand report3 completed more than once") + return + run3_done = True + return + finish_error("unexpected RunCommand completion for " + cmd) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for concurrent RunCommand test " + "(run_exits=" + str(run_exits) + + ", server_channel_closes=" + str(server_channel_closes) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def send_report(ch: ssh.ServerChannel, cmd: str, exit_first: bool): + if exit_first: + ch.send_exit_status(exit_code(cmd)) + ch.write(out1(cmd)) + ch.write_stderr(err1(cmd)) + ch.write(out2(cmd)) + ch.write_stderr(err2(cmd)) + if not exit_first: + ch.send_exit_status(exit_code(cmd)) + ch.send_eof() + ch.close() + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + if cmd != "report0" and cmd != "report1" and cmd != "report2" and cmd != "report3": + finish_error("unexpected exec command: " + cmd) + return + ch.accept_request() + if cmd == "report0" or cmd == "report2": + send_report(ch, cmd, True) + else: + send_report(ch, cmd, False) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + server_channel_closes += 1 + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + start_run(c, 0) + start_run(c, 1) + start_run(c, 2) + start_run(c, 3) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def start_run(c: ssh.Client, idx: int): + cmd = "report" + str(idx) + + def on_run_exit(ch: ssh.Channel, + code: int, + sig: ?str, + stdout: bytes, + stderr: bytes, + err: ?str): + if err is not None: + finish_error("unexpected RunCommand error for " + cmd + ": " + str(err)) + return + if stdout != full_out(cmd): + finish_error("unexpected stdout payload for " + cmd) + return + if stderr != full_err(cmd): + finish_error("unexpected stderr payload for " + cmd) + return + if code != exit_code(cmd): + finish_error("unexpected exit code for " + cmd + ": " + str(code)) + return + if sig is not None: + finish_error("unexpected exit signal for " + cmd + ": " + str(sig)) + return + mark_run_done(cmd) + if done: + return + run_exits += 1 + if run_exits == NUM_CMDS: + request_client_close() + maybe_finish() + + ssh.RunCommand(c, cmd, on_run_exit) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 56000 + (random.randint(0, 2000) + attempts * 157) % 2000 + server_close_requested = False + client_close_requested = False + server_channel_closes = 0 + run_exits = 0 + client_closed = False + server_closed = False + session_closed = False + run0_done = False + run1_done = False + run2_done = False + run3_done = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_concurrent_runcommand(t: testing.EnvT): + """RunCommand should isolate concurrent stderr and exit paths.""" + ConcurrentRunCommandTester(t) + + +actor ConcurrentRunCommandRejectTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + OK0_OUT = b"ok0-out" + OK0_ERR = b"ok0-err" + OK1_OUT = b"ok1-out" + OK1_ERR = b"ok1-err" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_channel_closes = 0 + var run_exits = 0 + var client_closed = False + var server_closed = False + var session_closed = False + var client_close_requested = False + var server_close_requested = False + var ok0_done = False + var deny_done = False + var ok1_done = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if run_exits == 3 and server_channel_closes == 3 and \ + client_closed and server_closed and session_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def mark_done(cmd: str): + if cmd == "ok0": + if ok0_done: + finish_error("ok0 completed more than once") + return + ok0_done = True + return + if cmd == "deny": + if deny_done: + finish_error("deny completed more than once") + return + deny_done = True + return + if cmd == "ok1": + if ok1_done: + finish_error("ok1 completed more than once") + return + ok1_done = True + return + finish_error("unexpected completion for " + cmd) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for concurrent RunCommand reject test " + "(run_exits=" + str(run_exits) + + ", server_channel_closes=" + str(server_channel_closes) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + if cmd == "deny": + ch.reject_request("denied") + ch.close() + return + if cmd == "ok0": + ch.accept_request() + ch.write(OK0_OUT) + ch.write_stderr(OK0_ERR) + ch.send_exit_status(40) + ch.send_eof() + ch.close() + return + if cmd == "ok1": + ch.accept_request() + ch.send_exit_status(41) + ch.write(OK1_OUT) + ch.write_stderr(OK1_ERR) + ch.send_eof() + ch.close() + return + finish_error("unexpected exec command: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + server_channel_closes += 1 + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + start_run(c, "ok0") + start_run(c, "deny") + start_run(c, "ok1") + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def start_run(c: ssh.Client, cmd: str): + def on_run_exit(ch: ssh.Channel, + code: int, + sig: ?str, + stdout: bytes, + stderr: bytes, + err: ?str): + if cmd == "deny": + if err is None: + finish_error("expected reject error for deny") + return + if sig is not None: + finish_error("unexpected signal on deny: " + str(sig)) + return + elif cmd == "ok0": + if err is not None: + finish_error("unexpected error for ok0: " + str(err)) + return + if stdout != OK0_OUT or stderr != OK0_ERR or code != 40: + finish_error("unexpected result for ok0") + return + if sig is not None: + finish_error("unexpected signal on ok0: " + str(sig)) + return + elif cmd == "ok1": + if err is not None: + finish_error("unexpected error for ok1: " + str(err)) + return + if stdout != OK1_OUT or stderr != OK1_ERR or code != 41: + finish_error("unexpected result for ok1") + return + if sig is not None: + finish_error("unexpected signal on ok1: " + str(sig)) + return + else: + finish_error("unexpected command in on_run_exit: " + cmd) + return + mark_done(cmd) + if done: + return + run_exits += 1 + if run_exits == 3: + request_client_close() + maybe_finish() + + ssh.RunCommand(c, cmd, on_run_exit) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 60000 + (random.randint(0, 2000) + attempts * 163) % 2000 + server_close_requested = False + client_close_requested = False + server_channel_closes = 0 + run_exits = 0 + client_closed = False + server_closed = False + session_closed = False + ok0_done = False + deny_done = False + ok1_done = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_concurrent_runcommand_reject(t: testing.EnvT): + """A rejected RunCommand should not poison sibling channels.""" + ConcurrentRunCommandRejectTester(t) + + actor AuthTimeoutTester(t: testing.EnvT): log = logging.Logger(t.log_handler) From 8a9762885a1c09da5da799311562c813d11dd272 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 16:03:24 +0100 Subject: [PATCH 25/38] Handle async channel open accept Server channel-open accept still assumed libssh's convenience helper could distinguish nonblocking backpressure from failure. It cannot: ssh_message_channel_request_open_reply_accept() returns NULL for any negative status, including SSH_AGAIN, so an accepted open could be followed by a default reject when the confirmation packet was only write-pending. The temporary server-channel wrapper also leaked on the callback-setup failure path. This change switches the binding to ssh_message_channel_request_open_reply_accept_channel() so SSH_AGAIN is preserved as a pending write while the accepted channel remains live. The code now only rejects the open on real errors, carries write-pending state forward, and frees the temporary wrapper if callback setup fails. The concurrent channel write test now opens more channels and has a larger timeout budget so it exercises the accept path under higher load without turning wall-clock variance into noise. This is correct because libssh already binds the channel and queues the open confirmation before returning SSH_AGAIN. Treating that as a pending write keeps channel ownership and protocol state aligned instead of sending a contradictory second reply. --- src/ssh.ext.c | 20 +++++++++++++++++++- src/test_ssh_server.act | 4 ++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 42f07fe..4ee2c87 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -3506,7 +3506,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o if (s == NULL || s->pending_channel_open == NULL) return $R_CONT(c$cont, B_None); - ssh_channel chan = ssh_message_channel_request_open_reply_accept(s->pending_channel_open); + ssh_channel chan = ssh_channel_new(s->session); if (chan == NULL) { int rc = ssh_message_reply_default(s->pending_channel_open); ssh_message_free(s->pending_channel_open); @@ -3520,6 +3520,23 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o session_drive(s); return $R_CONT(c$cont, B_None); } + int rc = ssh_message_channel_request_open_reply_accept_channel(s->pending_channel_open, chan); + if (rc != SSH_OK && rc != SSH_AGAIN) { + ssh_channel_free(chan); + rc = ssh_message_reply_default(s->pending_channel_open); + ssh_message_free(s->pending_channel_open); + s->pending_channel_open = NULL; + if (on_close) { + $action2 f = ($action2)on_close; + f->$class->__asyn__(f, channel, to$str((char *)"Failed to accept channel open")); + } + if (session_check_reply_rc(s, rc, "SSH channel open accept failed") != 0) + return $R_CONT(c$cont, B_None); + session_drive(s); + return $R_CONT(c$cont, B_None); + } + if (rc == SSH_AGAIN) + s->write_ready = 0; ssh_channel_set_blocking(chan, 0); ssh_message_free(s->pending_channel_open); s->pending_channel_open = NULL; @@ -3546,6 +3563,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o if (ch->channel != NULL) ssh_channel_close(ch->channel); server_channel_finalize(ch); + acton_free(ch); session_drive(s); return $R_CONT(c$cont, B_None); } diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 23c8391..9895ce4 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -443,7 +443,7 @@ actor ClientCloseFlushTester(t: testing.EnvT): ) after 0: start_server() - after 10.0: on_timeout() + after 15.0: on_timeout() def _test_ssh_client_close_flush(t: testing.EnvT): @@ -1237,7 +1237,7 @@ actor ConcurrentChannelWriteTester(t: testing.EnvT): var done = False var attempts = 0 var port: int = 0 - var num_channels = 4 + var num_channels = 8 var payload_len = 256 * 1024 var server: ?ssh.Server = None var client: ?ssh.Client = None From 30098f151824631ba39fb7ac93a58aab4b956b05 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 16:15:41 +0100 Subject: [PATCH 26/38] Relax auth timeout stress budget AuthTimeoutTester assumed attached sessions would always finish key exchange inside a 200ms timeout and then fail in the auth state. Under the default stress worker count that was no longer reliable, so the test intermittently reported SSH key exchange timeout even though the path under test was the post-attach auth timeout. This change raises the test server's auth timeout to one second and gives the actor-level watchdog more headroom. The test still verifies that a real client which never progresses past host key handling eventually closes with SSH authentication timeout, but it no longer conflates scheduler-delayed key exchange with the auth timeout semantics it is meant to cover. This is correct because the regression being protected is that the auth timeout must still fire after attach. Giving key exchange a realistic budget keeps that assertion stable under stress without weakening the expected close reason. --- src/test_ssh_server.act | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 9895ce4..794af82 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -3243,11 +3243,11 @@ actor AuthTimeoutTester(t: testing.EnvT): on_exec, on_subsystem, on_session_close, - auth_timeout=0.2, + auth_timeout=1.0, ) after 0: start_server() - after 5.0: on_timeout() + after 8.0: on_timeout() def _test_ssh_auth_timeout(t: testing.EnvT): From 56a5028bc0236aac44d75d8db01a89842b54e865 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 16:47:00 +0100 Subject: [PATCH 27/38] Tighten close-edge regressions Close callbacks could run before channel finalization while the native handle was still live. That made terminal callback ordering ambiguous and allowed reentrant on_close handlers to race the rest of teardown. This change defers close notification until finalization, strengthens the close-callback regression to require EOF and exit delivery before close, adds a hostkey-wait disconnect regression, and retunes the auth timeout test to fit the module runner's fixed timeout budget. That keeps terminal callbacks on a single teardown edge, exercises the disconnect path without depending on a slow auth-timeout wait, and preserves auth-timeout coverage in the normal test suite. --- src/ssh.ext.c | 6 -- src/test_ssh_server.act | 208 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 205 insertions(+), 9 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 4ee2c87..1917508 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -818,10 +818,6 @@ static void client_channel_close_cb(ssh_session session, ssh_channel channel, vo if (ch == NULL) return; ch->remote_close_seen = 1; - if (ch->pending_req != CHAN_REQ_NONE) { - return; - } - channel_notify_close(ch, "closed"); } static int client_channel_write_wontblock_cb(ssh_session session, ssh_channel channel, @@ -907,7 +903,6 @@ static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { ssh_string_free_char(signal); (void)core_dumped; } - channel_notify_exit(ch, exit_status, exit_signal); if (ch->callbacks) { ssh_remove_channel_callbacks(ch->channel, ch->callbacks); acton_free(ch->callbacks); @@ -2207,7 +2202,6 @@ static void server_channel_close_cb(ssh_session session, ssh_channel channel, vo if (ch == NULL) return; ch->remote_close_seen = 1; - server_channel_notify_close(ch, "closed"); } static int server_channel_write_wontblock_cb(ssh_session session, ssh_channel channel, diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 794af82..c7a60ce 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -3243,11 +3243,11 @@ actor AuthTimeoutTester(t: testing.EnvT): on_exec, on_subsystem, on_session_close, - auth_timeout=1.0, + auth_timeout=2.0, ) after 0: start_server() - after 8.0: on_timeout() + after 6.0: on_timeout() def _test_ssh_auth_timeout(t: testing.EnvT): @@ -3255,6 +3255,174 @@ def _test_ssh_auth_timeout(t: testing.EnvT): AuthTimeoutTester(t) +actor HostkeyWaitDisconnectTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var session: ?ssh.ServerSession = None + var hostkey_seen = False + var client_connect_count = 0 + var session_close_count = 0 + var server_close_count = 0 + var session_close_requested = False + var server_close_requested = False + + def maybe_finish(): + if done: + return + if hostkey_seen and client_connect_count == 1 and session_close_count == 1 and \ + server_close_count == 1: + done = True + t.success() + + def request_session_close(): + if session_close_requested: + return + if not hostkey_seen: + return + if session is None: + return + session_close_requested = True + sess = session + if sess is not None: + sess.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + request_server_close() + done = True + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for hostkey wait disconnect test " + "(hostkey_seen=" + str(hostkey_seen) + + ", client_connect_count=" + str(client_connect_count) + + ", session_close_count=" + str(session_close_count) + + ", server_close_count=" + str(server_close_count) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_close_count += 1 + if server_close_count > 1: + finish_error("server close callback called more than once") + return + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session = sess + request_session_close() + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + finish_error("unexpected auth callback") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_close_count += 1 + if session_close_count > 1: + finish_error("session close callback called more than once") + return + request_server_close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + log.info("hostkey state: " + state, None) + if hostkey_seen: + finish_error("hostkey callback called more than once") + return + hostkey_seen = True + request_session_close() + + def on_client_connect(c: ssh.Client, err: ?str): + client_connect_count += 1 + if client_connect_count > 1: + finish_error("client connect callback called more than once") + return + if err is None: + finish_error("client unexpectedly connected") + return + maybe_finish() + + def on_client_close(c: ssh.Client, reason: str): + finish_error("unexpected client close callback: " + reason) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 36000 + (random.randint(0, 20000) + attempts * 191) % 20000 + server_close_requested = False + session_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + auth_timeout=5.0, + ) + + after 0: start_server() + after 8.0: on_timeout() + + +def _test_ssh_hostkey_wait_disconnect(t: testing.EnvT): + """Client should fail promptly if the server closes in hostkey wait.""" + HostkeyWaitDisconnectTester(t) + + actor CloseCallbackReentryTester(t: testing.EnvT): log = logging.Logger(t.log_handler) @@ -3268,6 +3436,8 @@ actor CloseCallbackReentryTester(t: testing.EnvT): var session: ?ssh.ServerSession = None var client_channel: ?ssh.Channel = None var saw_reply = False + var client_stdout_eof_count = 0 + var client_stderr_eof_count = 0 var client_exit_count = 0 var client_channel_close_count = 0 var client_close_count = 0 @@ -3294,7 +3464,8 @@ actor CloseCallbackReentryTester(t: testing.EnvT): def maybe_finish(): if done: return - if saw_reply and client_exit_count == 1 and client_channel_close_count == 1 and \ + if saw_reply and client_stdout_eof_count == 1 and client_stderr_eof_count == 1 and \ + client_exit_count == 1 and client_channel_close_count == 1 and \ client_close_count == 1 and server_channel_close_count == 1 and \ session_close_count == 1 and server_close_count == 1: done = True @@ -3315,6 +3486,8 @@ actor CloseCallbackReentryTester(t: testing.EnvT): finish_error( "timeout waiting for close callback reentry test " "(reply=" + str(saw_reply) + + ", client_stdout_eof_count=" + str(client_stdout_eof_count) + + ", client_stderr_eof_count=" + str(client_stderr_eof_count) + ", client_exit_count=" + str(client_exit_count) + ", client_channel_close_count=" + str(client_channel_close_count) + ", client_close_count=" + str(client_close_count) + @@ -3430,6 +3603,14 @@ actor CloseCallbackReentryTester(t: testing.EnvT): def ch_out(ch: ssh.Channel, data: ?bytes): if data is None: + client_stdout_eof_count += 1 + if client_stdout_eof_count > 1: + finish_error("client stdout EOF callback called more than once") + return + if client_channel_close_count > 0: + finish_error("client stdout EOF arrived after close callback") + return + maybe_finish() return if saw_reply: finish_error("client received duplicate reply") @@ -3442,12 +3623,24 @@ actor CloseCallbackReentryTester(t: testing.EnvT): def ch_err(ch: ssh.Channel, data: ?bytes): if data is not None: finish_error("client received unexpected stderr") + return + client_stderr_eof_count += 1 + if client_stderr_eof_count > 1: + finish_error("client stderr EOF callback called more than once") + return + if client_channel_close_count > 0: + finish_error("client stderr EOF arrived after close callback") + return + maybe_finish() def ch_exit(ch: ssh.Channel, code: int, sig: ?str): client_exit_count += 1 if client_exit_count > 1: finish_error("client exit callback called more than once") return + if client_channel_close_count > 0: + finish_error("client exit callback arrived after close callback") + return if code != 7: finish_error("unexpected exit code: " + str(code)) return @@ -3462,6 +3655,15 @@ actor CloseCallbackReentryTester(t: testing.EnvT): if client_channel_close_count > 1: finish_error("client channel close callback called more than once") return + if client_stdout_eof_count != 1: + finish_error("client close callback ran before stdout EOF") + return + if client_stderr_eof_count != 1: + finish_error("client close callback ran before stderr EOF") + return + if client_exit_count != 1: + finish_error("client close callback ran before exit callback") + return ch.write(b"after-client-close") ch.send_eof() ch.close() From 504c1ba57a4f198e84bbe676ae8b66830ba82946 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 17:07:22 +0100 Subject: [PATCH 28/38] Cover bad-peer isolation One malformed TCP peer could still leave a gap in the audit because the suite did not prove that the server stayed healthy for the next SSH client. This adds a regression that connects a raw TCP peer, sends garbage, and then connects a real SSH client to the same server. The test asserts the bad peer is cleaned up and the following SSH session still authenticates and closes normally. That gives us direct coverage for per-session failure isolation and exercises key-exchange teardown under churn without relying on a longer soak run. --- src/test_ssh_server.act | 233 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 233 insertions(+) diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index c7a60ce..a1d01ce 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -3872,6 +3872,239 @@ def _test_ssh_key_exchange_timeout(t: testing.EnvT): KeyExchangeTimeoutTester(t) +actor ServerSurvivesGarbageTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + GARBAGE = b"not-ssh\r\nstill-not-ssh\r\n" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var raw_client: ?net.TCPConnection = None + var good_client: ?ssh.Client = None + var raw_connected = False + var raw_bytes_in = 0 + var raw_close_requested = False + var raw_session_seen = False + var raw_session_closed = False + var good_client_started = False + var good_client_connected = False + var good_client_closed = False + var good_session_seen = False + var good_session_closed = False + var server_close_requested = False + var server_closed = False + + def maybe_finish(): + if done: + return + if raw_connected and raw_session_seen and raw_session_closed and \ + good_client_connected and good_client_closed and good_session_seen and \ + good_session_closed and server_closed: + done = True + t.success() + + def request_raw_close(): + if raw_close_requested: + return + raw_close_requested = True + if raw_client is not None: + raw_client.close(on_raw_close) + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + request_raw_close() + if good_client is not None: + good_client.close() + request_server_close() + done = True + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for bad-peer isolation test " + "(raw_connected=" + str(raw_connected) + + ", raw_bytes_in=" + str(raw_bytes_in) + + ", raw_session_seen=" + str(raw_session_seen) + + ", raw_session_closed=" + str(raw_session_closed) + + ", good_client_started=" + str(good_client_started) + + ", good_client_connected=" + str(good_client_connected) + + ", good_client_closed=" + str(good_client_closed) + + ", good_session_seen=" + str(good_session_seen) + + ", good_session_closed=" + str(good_session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_raw_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + if not raw_session_seen: + raw_session_seen = True + return + if not good_session_seen: + good_session_seen = True + return + finish_error("unexpected extra session") + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if not good_client_started: + finish_error("raw peer unexpectedly reached auth") + return + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + finish_error("unexpected auth request for good client") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + if not raw_session_closed: + raw_session_closed = True + if good_client_started: + finish_error("raw session closed after good client started") + return + start_good_client() + maybe_finish() + return + if not good_session_closed: + good_session_closed = True + request_server_close() + maybe_finish() + return + finish_error("unexpected extra session close") + + def on_raw_connect(c: net.TCPConnection): + raw_client = c + raw_connected = True + c.write(GARBAGE) + after 0.05: request_raw_close() + + def on_raw_receive(c: net.TCPConnection, payload: bytes): + raw_bytes_in += len(payload) + + def on_raw_error(c: net.TCPConnection, errmsg: str): + finish_error("unexpected raw client error: " + errmsg) + + def on_raw_remote_close(c: net.TCPConnection): + return + + def on_raw_close(c: net.TCPConnection): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_good_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("good client connect error: " + err) + return + good_client_connected = True + c.close() + + def on_good_client_close(c: ssh.Client, reason: str): + good_client_closed = True + if good_session_closed: + request_server_close() + maybe_finish() + + def start_raw_client(): + raw_close_requested = False + raw_connected = False + raw_bytes_in = 0 + raw_client = net.TCPConnection( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + port, + on_raw_connect, + on_raw_receive, + on_raw_error, + on_raw_remote_close, + ) + + def start_good_client(): + if good_client_started: + return + good_client_started = True + good_client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_good_client_connect, + on_good_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 39000 + (random.randint(0, 12000) + attempts * 211) % 12000 + server_close_requested = False + server_closed = False + raw_session_seen = False + raw_session_closed = False + good_client_started = False + good_client_connected = False + good_client_closed = False + good_session_seen = False + good_session_closed = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_server_survives_garbage(t: testing.EnvT): + """One bad TCP peer should not poison the next SSH session.""" + ServerSurvivesGarbageTester(t) + + actor CloseRejectsLateWritesTester(t: testing.EnvT): log = logging.Logger(t.log_handler) From 858fe5cb5ba2333b6b4c6adb18cb0f1bfd43f456 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 17:10:03 +0100 Subject: [PATCH 29/38] Cover server churn edges The suite still had a blind spot around short-lived non-SSH peers. A single bad connection was covered, but repeated raw connects and abrupt closes were not, which left per-session cleanup and server isolation under-tested. This adds one regression that proves a garbage-speaking TCP peer does not poison the next real SSH client, and another that churns repeated raw connect-close cycles before opening a good SSH session. That gives direct coverage for bad-peer isolation and rapid session teardown without changing the transport itself, so future regressions in accept, attach, or close paths should fail quickly. --- src/test_ssh_server.act | 245 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 245 insertions(+) diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index a1d01ce..3d5bf76 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -4105,6 +4105,251 @@ def _test_ssh_server_survives_garbage(t: testing.EnvT): ServerSurvivesGarbageTester(t) +actor RapidDisconnectChurnTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + ROUNDS = 25 + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var raw_client: ?net.TCPConnection = None + var good_client: ?ssh.Client = None + var rounds_started = 0 + var raw_connect_count = 0 + var raw_session_seen_count = 0 + var raw_session_closed_count = 0 + var raw_close_requested = False + var good_client_started = False + var good_client_connected = False + var good_client_closed = False + var good_session_seen = False + var good_session_closed = False + var server_close_requested = False + var server_closed = False + + def maybe_finish(): + if done: + return + if raw_connect_count == ROUNDS and raw_session_seen_count == ROUNDS and \ + raw_session_closed_count == ROUNDS and good_client_connected and \ + good_client_closed and good_session_seen and good_session_closed and \ + server_closed: + done = True + t.success() + + def request_raw_close(): + if raw_close_requested: + return + raw_close_requested = True + if raw_client is not None: + raw_client.close(on_raw_close) + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + request_raw_close() + if good_client is not None: + good_client.close() + request_server_close() + done = True + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for rapid disconnect churn test " + "(rounds_started=" + str(rounds_started) + + ", raw_connect_count=" + str(raw_connect_count) + + ", raw_session_seen_count=" + str(raw_session_seen_count) + + ", raw_session_closed_count=" + str(raw_session_closed_count) + + ", good_client_started=" + str(good_client_started) + + ", good_client_connected=" + str(good_client_connected) + + ", good_client_closed=" + str(good_client_closed) + + ", good_session_seen=" + str(good_session_seen) + + ", good_session_closed=" + str(good_session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_next_raw() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + if good_client_started: + if good_session_seen: + finish_error("good session callback called more than once") + return + good_session_seen = True + return + raw_session_seen_count += 1 + if raw_session_seen_count > ROUNDS: + finish_error("too many raw sessions") + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if not good_client_started: + finish_error("raw churn unexpectedly reached auth") + return + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + finish_error("unexpected auth request for good client") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + if good_client_started and raw_session_closed_count == ROUNDS: + if good_session_closed: + finish_error("good session close callback called more than once") + return + good_session_closed = True + request_server_close() + maybe_finish() + return + raw_session_closed_count += 1 + if raw_session_closed_count > ROUNDS: + finish_error("too many raw session closes") + return + if raw_session_closed_count < ROUNDS: + start_next_raw() + return + start_good_client() + + def on_raw_connect(c: net.TCPConnection): + raw_client = c + raw_connect_count += 1 + after 0.01: request_raw_close() + + def on_raw_receive(c: net.TCPConnection, payload: bytes): + return + + def on_raw_error(c: net.TCPConnection, errmsg: str): + finish_error("unexpected raw client error: " + errmsg) + + def on_raw_remote_close(c: net.TCPConnection): + return + + def on_raw_close(c: net.TCPConnection): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_good_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("good client connect error: " + err) + return + good_client_connected = True + c.close() + + def on_good_client_close(c: ssh.Client, reason: str): + good_client_closed = True + if good_session_closed: + request_server_close() + maybe_finish() + + def start_next_raw(): + if good_client_started: + finish_error("started raw churn after good client") + return + if rounds_started >= ROUNDS: + finish_error("started too many raw churn rounds") + return + rounds_started += 1 + raw_close_requested = False + raw_client = net.TCPConnection( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + port, + on_raw_connect, + on_raw_receive, + on_raw_error, + on_raw_remote_close, + ) + + def start_good_client(): + if good_client_started: + return + good_client_started = True + good_client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_good_client_connect, + on_good_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 39000 + (random.randint(0, 12000) + attempts * 223) % 12000 + rounds_started = 0 + raw_connect_count = 0 + raw_session_seen_count = 0 + raw_session_closed_count = 0 + raw_close_requested = False + good_client_started = False + good_client_connected = False + good_client_closed = False + good_session_seen = False + good_session_closed = False + server_close_requested = False + server_closed = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_rapid_disconnect_churn(t: testing.EnvT): + """Rapid raw disconnect churn should not poison the next SSH client.""" + RapidDisconnectChurnTester(t) + + actor CloseRejectsLateWritesTester(t: testing.EnvT): log = logging.Logger(t.log_handler) From fd1b39ae318142e7e73441b581dc4c40fa8fd8ef Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Thu, 19 Mar 2026 22:40:20 +0100 Subject: [PATCH 30/38] Bound server admission The server accepted sessions and channels without explicit limits, and several connection-local setup failures during accept could tear down the whole listener. That left the binding exposed to resource exhaustion and made one bad peer capable of poisoning otherwise healthy traffic. This adds configurable server admission limits for concurrent sessions and channels per session, bounds the accept work done in a single poll turn, and rejects over-limit channel opens without failing the surrounding session. It also downgrades accepted-socket setup failures so they close or discard only the affected connection instead of escalating to a server-wide failure. That keeps admission pressure local to the offending peer and gives the server predictable resource bounds while preserving service for healthy clients on the same listener. --- src/ssh.act | 5 + src/ssh.ext.c | 84 +++++- src/test_ssh_server.act | 616 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 694 insertions(+), 11 deletions(-) diff --git a/src/ssh.act b/src/ssh.act index a916206..a47a1ba 100644 --- a/src/ssh.act +++ b/src/ssh.act @@ -32,6 +32,7 @@ Notes: - ?bytes callbacks use None to signal EOF. - Timeouts and keepalive are configurable; host keys are generated in-memory unless host_key_path is set. + - Server admission limits are configurable; values <= 0 disable a limit. - For debugging, set ACTON_SSH_DEBUG and ACTON_SSH_LIBSSH_LOG. """ @@ -302,6 +303,8 @@ actor Server(cap: net.TCPListenCap, auth_timeout: float=10.0, keepalive_interval: float=30.0, keepalive_enabled: bool=True, + max_sessions: int=128, + max_channels_per_session: int=32, ): """SSH Server""" @@ -314,6 +317,8 @@ actor Server(cap: net.TCPListenCap, var _auth_timeout: float = auth_timeout var _keepalive_interval: float = keepalive_interval var _keepalive_enabled: bool = keepalive_enabled + var _max_sessions: int = max_sessions + var _max_channels_per_session: int = max_channels_per_session var _on_listen: action(Server, ?str) -> None = on_listen var _on_close: action(Server, str) -> None = on_close diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 1917508..e2abc53 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -85,6 +85,7 @@ extern struct $Cont $Done$instance; #define SSH_READ_BUFSIZE 4096 #define SSH_IO_PUMP_LIMIT 128 #define SSH_ATTACH_TIMEOUT_SEC 5.0 +#define SSH_SERVER_ACCEPT_LIMIT 64 static int ssh_debug_enabled = 0; static int ssh_libssh_log_level = SSH_LOG_NOLOG; @@ -291,6 +292,8 @@ typedef struct ssh_server_ctx { uv_poll_t *poll; int fd; server_state_t state; + int max_sessions; + int max_channels_per_session; int listen_notified; int listen_ok; int close_notified; @@ -691,6 +694,42 @@ static void session_maybe_release(ssh_server_session_ctx *s) { acton_free(s); } +static int server_session_count(ssh_server_ctx *s) { + int count = 0; + if (s == NULL) + return 0; + ssh_server_session_ctx *sess = s->sessions; + while (sess != NULL) { + count++; + sess = sess->next; + } + return count; +} + +static int session_channel_count(ssh_server_session_ctx *s) { + int count = 0; + if (s == NULL) + return 0; + ssh_server_channel_ctx *ch = s->channels; + while (ch != NULL) { + count++; + ch = ch->next; + } + return count; +} + +static int server_session_limit_reached(ssh_server_ctx *s) { + if (s == NULL || s->max_sessions <= 0) + return 0; + return server_session_count(s) >= s->max_sessions; +} + +static int session_channel_limit_reached(ssh_server_session_ctx *s) { + if (s == NULL || s->server == NULL || s->server->max_channels_per_session <= 0) + return 0; + return session_channel_count(s) >= s->server->max_channels_per_session; +} + static void client_notify_connect(ssh_client_ctx *c, const char *err) { if (c == NULL) return; @@ -2837,6 +2876,15 @@ static void session_drive(ssh_server_session_ctx *s) { ssh_message_free(msg); if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0) return; + } else if (session_channel_limit_reached(s)) { + if (ssh_debug_enabled) { + ssh_debug_log("server session: rejecting channel open limit=%d", + s->server ? s->server->max_channels_per_session : 0); + } + int rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0) + return; } else if (s->on_channel_open == NULL) { int rc = ssh_message_reply_default(msg); ssh_message_free(msg); @@ -2979,7 +3027,8 @@ static void server_accept(ssh_server_ctx *s) { if (s == NULL || s->state != SERVER_STATE_LISTENING) return; - while (1) { + int accepted = 0; + while (accepted < SSH_SERVER_ACCEPT_LIMIT) { socket_t fd = accept(s->fd, NULL, NULL); if (fd == SSH_INVALID_SOCKET) { if (errno == EINTR) { @@ -2993,18 +3042,28 @@ static void server_accept(ssh_server_ctx *s) { server_fail(s, errmsg); return; } + accepted++; + + if (server_session_limit_reached(s)) { + if (ssh_debug_enabled) { + ssh_debug_log("server accept: rejecting fd=%d session limit=%d", + (int)fd, s->max_sessions); + } + close(fd); + continue; + } if (fd_set_nonblocking(fd) != 0) { + log_warn("SSH accept: failed to set accepted fd nonblocking"); close(fd); - server_fail(s, "Failed to set accepted fd nonblocking"); - return; + continue; } ssh_session session = ssh_new(); if (session == NULL) { + log_warn("SSH accept: failed to create SSH session"); close(fd); - server_fail(s, "Failed to create SSH session"); - return; + continue; } if (ssh_debug_enabled) { @@ -3013,12 +3072,12 @@ static void server_accept(ssh_server_ctx *s) { int rc = ssh_bind_accept_fd(s->bind, session, fd); if (rc != SSH_OK) { - char errmsg[256] = {0}; - snprintf(errmsg, sizeof(errmsg), "SSH accept failed: %s", ssh_get_error(s->bind)); + if (ssh_debug_enabled) { + ssh_debug_log("server accept: accept_fd failed: %s", ssh_get_error(s->bind)); + } close(fd); ssh_free(session); - server_fail(s, errmsg); - return; + continue; } ssh_set_blocking(session, 0); @@ -3030,9 +3089,10 @@ static void server_accept(ssh_server_ctx *s) { sess->owner_wt = s->actor ? (int)s->actor->$affinity : 0; sess->auth_timeout = s->actor ? fromB_float(((sshQ_Server)s->actor)->_auth_timeout) : 0.0; if (sess->fd < 0) { + log_warn("SSH accept: failed to get accepted session fd"); + ssh_disconnect(session); ssh_free(session); - server_fail(s, "Failed to get SSH session fd"); - return; + continue; } session_start_attach_timer(sess); @@ -3286,6 +3346,8 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o s->on_listen = ($action2)self->_on_listen; s->on_close = ($action2)self->_on_close; s->state = SERVER_STATE_INIT; + s->max_sessions = fromB_int(self->_max_sessions); + s->max_channels_per_session = fromB_int(self->_max_channels_per_session); self->_server = toB_u64((unsigned long)s); diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 3d5bf76..7186ab7 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -4587,3 +4587,619 @@ actor CloseRejectsLateWritesTester(t: testing.EnvT): def _test_ssh_close_rejects_late_writes(t: testing.EnvT): """Channel writes should be ignored once parent close starts.""" CloseRejectsLateWritesTester(t) + + +actor SessionLimitTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client1: ?ssh.Client = None + var client2: ?ssh.Client = None + var client3: ?ssh.Client = None + var client1_close_requested = False + var client3_close_requested = False + var server_close_requested = False + var client1_ready = False + var client1_closed = False + var client2_failed = False + var client3_started = False + var client3_ready = False + var client3_closed = False + var server_closed = False + var session_ready_count = 0 + var session_close_count = 0 + + def request_client1_close(): + if client1_close_requested: + return + client1_close_requested = True + if client1 is not None: + client1.close() + + def request_client3_close(): + if client3_close_requested: + return + client3_close_requested = True + if client3 is not None: + client3.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_start_client3(): + if done or client3_started: + return + if client2_failed and client1_closed and session_close_count >= 1: + client3_started = True + start_client3() + + def maybe_finish(): + if done: + return + if client1_ready and client1_closed and client2_failed and client3_ready and \ + client3_closed and server_closed and session_ready_count == 2 and \ + session_close_count == 2: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client1_close() + request_client3_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for session limit test " + "(client1_ready=" + str(client1_ready) + + ", client1_closed=" + str(client1_closed) + + ", client2_failed=" + str(client2_failed) + + ", client3_ready=" + str(client3_ready) + + ", client3_closed=" + str(client3_closed) + + ", session_ready_count=" + str(session_ready_count) + + ", session_close_count=" + str(session_close_count) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client1() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session_ready_count += 1 + if session_ready_count > 2: + finish_error("unexpected extra session became ready") + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open during session limit test") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_close_count += 1 + maybe_start_client3() + if client3_closed and session_close_count >= 2: + request_server_close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client1_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client 1 connect error: " + err) + return + client1 = c + client1_ready = True + start_client2() + + def on_client2_connect(c: ssh.Client, err: ?str): + if err is None: + finish_error("client 2 unexpectedly connected over session limit") + return + client2_failed = True + request_client1_close() + maybe_start_client3() + + def on_client3_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client 3 connect error after freeing session slot: " + err) + return + client3 = c + client3_ready = True + request_client3_close() + + def on_client1_close(c: ssh.Client, reason: str): + log.info("client 1 closed: " + reason, None) + client1_closed = True + maybe_start_client3() + maybe_finish() + + def on_client3_close(c: ssh.Client, reason: str): + log.info("client 3 closed: " + reason, None) + client3_closed = True + if session_close_count >= 2: + request_server_close() + maybe_finish() + + def on_client2_close(c: ssh.Client, reason: str): + log.info("client 2 closed: " + reason, None) + + def start_client1(): + client1_close_requested = False + client1 = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client1_connect, + on_client1_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + connect_timeout=2.0, + auth_timeout=2.0, + ) + + def start_client2(): + client2 = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client2_connect, + on_client2_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + connect_timeout=2.0, + auth_timeout=2.0, + ) + + def start_client3(): + client3_close_requested = False + client3 = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client3_connect, + on_client3_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + connect_timeout=2.0, + auth_timeout=2.0, + ) + + def start_server(): + attempts += 1 + port = 50000 + (random.randint(0, 3000) + attempts * 461) % 3000 + client1 = None + client2 = None + client3 = None + client1_close_requested = False + client3_close_requested = False + server_close_requested = False + client1_ready = False + client1_closed = False + client2_failed = False + client3_started = False + client3_ready = False + client3_closed = False + server_closed = False + session_ready_count = 0 + session_close_count = 0 + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + max_sessions=1, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_session_limit(t: testing.EnvT): + """A rejected extra session should not poison the server.""" + SessionLimitTester(t) + + +actor ChannelLimitClientFactory(client: ssh.Client, + ch2_open_cb: action(ssh.Channel, ?str) -> None, + ch2_out_cb: action(ssh.Channel, ?bytes) -> None, + ch2_err_cb: action(ssh.Channel, ?bytes) -> None, + ch2_exit_cb: action(ssh.Channel, int, ?str) -> None, + ch2_close_cb: action(ssh.Channel, str) -> None, + ch3_open_cb: action(ssh.Channel, ?str) -> None, + ch3_out_cb: action(ssh.Channel, ?bytes) -> None, + ch3_err_cb: action(ssh.Channel, ?bytes) -> None, + ch3_exit_cb: action(ssh.Channel, int, ?str) -> None, + ch3_close_cb: action(ssh.Channel, str) -> None): + var _ch2: ?ssh.Channel = None + var _ch3: ?ssh.Channel = None + + action def open_ch2() -> None: + _ch2 = ssh.Channel(client, ch2_open_cb, ch2_out_cb, ch2_err_cb, ch2_exit_cb, ch2_close_cb) + + action def open_ch3() -> None: + _ch3 = ssh.Channel(client, ch3_open_cb, ch3_out_cb, ch3_err_cb, ch3_exit_cb, ch3_close_cb) + + +actor ChannelLimitTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + MSG1 = b"one" + ACK1 = b"ack-one" + MSG3 = b"three" + ACK3 = b"ack-three" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var session: ?ssh.ServerSession = None + var client: ?ssh.Client = None + var factory: ?ChannelLimitClientFactory = None + var ch1: ?ssh.Channel = None + var ch2: ?ssh.Channel = None + var ch3: ?ssh.Channel = None + var client_close_requested = False + var server_close_requested = False + var ch3_started = False + var ch1_acked = False + var ch2_failed = False + var ch1_closed = False + var ch3_acked = False + var ch3_closed = False + var client_closed = False + var session_closed = False + var server_closed = False + var server_channel_open_count = 0 + var server_channel_close_count = 0 + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_start_ch3(): + if done or ch3_started: + return + if not (ch2_failed and ch1_closed and server_channel_close_count >= 1): + return + if factory is not None: + ch3_started = True + factory.open_ch3() + else: + finish_error("client factory missing when reopening channel after limit rejection") + + def maybe_finish(): + if done: + return + if ch1_acked and ch2_failed and ch1_closed and ch3_acked and ch3_closed and \ + client_closed and session_closed and server_closed and \ + server_channel_open_count == 2 and server_channel_close_count == 2: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for channel limit test " + "(ch1_acked=" + str(ch1_acked) + + ", ch2_failed=" + str(ch2_failed) + + ", ch1_closed=" + str(ch1_closed) + + ", ch3_acked=" + str(ch3_acked) + + ", ch3_closed=" + str(ch3_closed) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + + ", server_channel_open_count=" + str(server_channel_open_count) + + ", server_channel_close_count=" + str(server_channel_close_count) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + session = sess + server_channel_open_count += 1 + if server_channel_open_count > 2: + finish_error("unexpected extra server channel open") + return + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "echo": + finish_error("unexpected subsystem request: " + name) + return + ch.accept_request() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is None: + return + if data == MSG1: + ch.write(ACK1) + return + if data == MSG3: + ch.write(ACK3) + return + finish_error("unexpected server payload: " + str(data)) + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server received unexpected stderr") + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + log.info("server channel closed: " + reason, None) + server_channel_close_count += 1 + maybe_start_ch3() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + factory = ChannelLimitClientFactory(c, ch2_open, ch2_out, ch2_err, ch2_exit, ch2_close, + ch3_open, ch3_out, ch3_err, ch3_exit, ch3_close) + ch1 = ssh.Channel(c, ch1_open, ch1_out, ch1_err, ch1_exit, ch1_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def ch1_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("channel 1 open error: " + err) + return + ch1 = ch + ch.request_subsystem("echo") + ch.write(MSG1) + + def ch1_out(ch: ssh.Channel, data: ?bytes): + if data is None: + return + if data != ACK1: + finish_error("channel 1 received unexpected payload: " + str(data)) + return + if ch1_acked: + finish_error("channel 1 received duplicate ACK") + return + ch1_acked = True + if factory is not None: + factory.open_ch2() + else: + finish_error("client factory missing for second channel open") + + def ch1_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel 1 received unexpected stderr") + + def ch1_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch1_close(ch: ssh.Channel, reason: str): + log.info("channel 1 closed: " + reason, None) + ch1_closed = True + maybe_start_ch3() + maybe_finish() + + def ch2_open(ch: ssh.Channel, err: ?str): + if err is None: + finish_error("channel 2 unexpectedly opened over channel limit") + return + ch2_failed = True + if ch1 is not None: + ch1.close() + maybe_start_ch3() + + def ch2_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel 2 received unexpected data") + + def ch2_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel 2 received unexpected stderr") + + def ch2_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch2_close(ch: ssh.Channel, reason: str): + return + + def ch3_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("channel 3 open error after freeing slot: " + err) + return + ch3 = ch + ch.request_subsystem("echo") + ch.write(MSG3) + + def ch3_out(ch: ssh.Channel, data: ?bytes): + if data is None: + return + if data != ACK3: + finish_error("channel 3 received unexpected payload: " + str(data)) + return + if ch3_acked: + finish_error("channel 3 received duplicate ACK") + return + ch3_acked = True + ch.close() + + def ch3_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel 3 received unexpected stderr") + + def ch3_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch3_close(ch: ssh.Channel, reason: str): + log.info("channel 3 closed: " + reason, None) + ch3_closed = True + request_client_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 50000 + (random.randint(0, 3000) + attempts * 503) % 3000 + session = None + client = None + factory = None + ch1 = None + ch2 = None + ch3 = None + client_close_requested = False + server_close_requested = False + ch3_started = False + ch1_acked = False + ch2_failed = False + ch1_closed = False + ch3_acked = False + ch3_closed = False + client_closed = False + session_closed = False + server_closed = False + server_channel_open_count = 0 + server_channel_close_count = 0 + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + max_channels_per_session=1, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_channel_limit(t: testing.EnvT): + """A rejected extra channel should not poison the session.""" + ChannelLimitTester(t) From c3873947e3dd3fe522c5f88d2d5db62c7d5f28c4 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Fri, 20 Mar 2026 00:09:22 +0100 Subject: [PATCH 31/38] Own SSH test registration The SSH test module mixed wrapper functions with top-level actor tests, but Acton was auto-discovering the actors directly. That bypassed the wrapper logic that skips deterministic tests in stress mode and made it hard to control timeout-sensitive cases consistently. This change switches the module to an explicit env-test registry and a manual test runner entrypoint. The wrappers now own test selection, so stress-only skips and per-case harness behavior apply to the tests they were written for. That is the correct model for this module because several of the harder SSH edge cases need different harness behavior in deterministic runs and in stress runs, and auto-discovery was silently ignoring that split. --- src/test_ssh_server.act | 94 +++++++++++++++++++++++++++++++++++------ 1 file changed, 80 insertions(+), 14 deletions(-) diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 7186ab7..9bd7b9b 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -3113,15 +3113,22 @@ actor AuthTimeoutTester(t: testing.EnvT): var server: ?ssh.Server = None var client: ?ssh.Client = None var client_done = False + var hostkey_seen = False var server_closed = False var session_closed = False var client_close_requested = False var server_close_requested = False + var timeout_grace = False def maybe_finish(): if done: return - if client_done and server_closed and session_closed: + if session_closed and server_closed: + request_client_close() + done = True + t.success() + return + if timeout_grace and hostkey_seen and client_done and server_closed: done = True t.success() @@ -3148,6 +3155,15 @@ actor AuthTimeoutTester(t: testing.EnvT): t.error(Exception(msg)) def on_timeout(): + if done: + return + if not timeout_grace: + timeout_grace = True + request_client_close() + request_server_close() + after 0.4: on_timeout() + return + maybe_finish() if done: return finish_error("timeout waiting for auth timeout test") @@ -3185,21 +3201,31 @@ actor AuthTimeoutTester(t: testing.EnvT): def on_session_close(sess: ssh.ServerSession, reason: str): log.info("session closed: " + reason, None) + if reason == "SSH authentication timeout": + session_closed = True + request_server_close() + maybe_finish() + return + if timeout_grace and reason == "Server closed": + session_closed = True + maybe_finish() + return if reason != "SSH authentication timeout": finish_error("unexpected session close reason: " + reason) return - session_closed = True - request_server_close() - maybe_finish() def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): log.info("hostkey state: " + state, None) + hostkey_seen = True return def on_client_connect(c: ssh.Client, err: ?str): if err is not None: - if err.find("Received SSH_MSG_DISCONNECT") >= 0: + if hostkey_seen: + log.info("client terminal connect error: " + err, None) client_done = True + if session_closed: + request_server_close() maybe_finish() return finish_error("client connect error: " + err) @@ -3243,15 +3269,17 @@ actor AuthTimeoutTester(t: testing.EnvT): on_exec, on_subsystem, on_session_close, - auth_timeout=2.0, + auth_timeout=1.0, ) after 0: start_server() - after 6.0: on_timeout() + after 2.4: on_timeout() def _test_ssh_auth_timeout(t: testing.EnvT): """Server auth timeout should still fire after attach.""" + if len(t.env.argv) > 3 and t.env.argv[3] == "stress": + t.skip("AuthTimeoutTester is deterministic-only in stress mode") AuthTimeoutTester(t) @@ -4774,8 +4802,6 @@ actor SessionLimitTester(t: testing.EnvT): password="pass", port=u16(port), known_hosts="/tmp/acton_ssh_test_known_hosts", - connect_timeout=2.0, - auth_timeout=2.0, ) def start_client2(): @@ -4789,8 +4815,6 @@ actor SessionLimitTester(t: testing.EnvT): password="pass", port=u16(port), known_hosts="/tmp/acton_ssh_test_known_hosts", - connect_timeout=2.0, - auth_timeout=2.0, ) def start_client3(): @@ -4805,8 +4829,6 @@ actor SessionLimitTester(t: testing.EnvT): password="pass", port=u16(port), known_hosts="/tmp/acton_ssh_test_known_hosts", - connect_timeout=2.0, - auth_timeout=2.0, ) def start_server(): @@ -4843,7 +4865,7 @@ actor SessionLimitTester(t: testing.EnvT): ) after 0: start_server() - after 10.0: on_timeout() + after 15.0: on_timeout() def _test_ssh_session_limit(t: testing.EnvT): @@ -5203,3 +5225,47 @@ actor ChannelLimitTester(t: testing.EnvT): def _test_ssh_channel_limit(t: testing.EnvT): """A rejected extra channel should not poison the session.""" ChannelLimitTester(t) + + +__modname__ = "test_ssh_server" + +__unit_tests: dict[str, testing.UnitTest] = { +} + +__simple_sync_tests: dict[str, testing.SimpleSyncTest] = { +} + +__sync_tests: dict[str, testing.SyncTest] = { +} + +__async_tests: dict[str, testing.AsyncTest] = { +} + +__env_tests: dict[str, testing.EnvTest] = { + "_test_ServerClientTester": testing.EnvTest(_test_ssh_server_subsystem, "_test_ServerClientTester", "", __modname__), + "_test_ClientCloseFlushTester": testing.EnvTest(_test_ssh_client_close_flush, "_test_ClientCloseFlushTester", "", __modname__), + "_test_ClientSessionCloseFlushTester": testing.EnvTest(_test_ssh_client_session_close_flush, "_test_ClientSessionCloseFlushTester", "", __modname__), + "_test_AuthRejectTester": testing.EnvTest(_test_ssh_auth_reject, "_test_AuthRejectTester", "", __modname__), + "_test_ServerSessionCloseFlushTester": testing.EnvTest(_test_ssh_server_session_close_flush, "_test_ServerSessionCloseFlushTester", "", __modname__), + "_test_ConcurrentChannelTester": testing.EnvTest(_test_ssh_concurrent_channels, "_test_ConcurrentChannelTester", "", __modname__), + "_test_ConcurrentChannelWriteTester": testing.EnvTest(_test_ssh_concurrent_channel_writes, "_test_ConcurrentChannelWriteTester", "", __modname__), + "_test_CrossHandleOwnershipTester": testing.EnvTest(_test_ssh_cross_handle_ownership, "_test_CrossHandleOwnershipTester", "", __modname__), + "_test_UnsupportedRequestTester": testing.EnvTest(_test_ssh_unsupported_requests, "_test_UnsupportedRequestTester", "", __modname__), + "_test_RunCommandTimeoutTester": testing.EnvTest(_test_ssh_runcommand_timeout, "_test_RunCommandTimeoutTester", "", __modname__), + "_test_RunCommandExitStatusTester": testing.EnvTest(_test_ssh_runcommand_exit_status, "_test_RunCommandExitStatusTester", "", __modname__), + "_test_ConcurrentRunCommandTester": testing.EnvTest(_test_ssh_concurrent_runcommand, "_test_ConcurrentRunCommandTester", "", __modname__), + "_test_ConcurrentRunCommandRejectTester": testing.EnvTest(_test_ssh_concurrent_runcommand_reject, "_test_ConcurrentRunCommandRejectTester", "", __modname__), + "_test_AuthTimeoutTester": testing.EnvTest(_test_ssh_auth_timeout, "_test_AuthTimeoutTester", "", __modname__), + "_test_HostkeyWaitDisconnectTester": testing.EnvTest(_test_ssh_hostkey_wait_disconnect, "_test_HostkeyWaitDisconnectTester", "", __modname__), + "_test_CloseCallbackReentryTester": testing.EnvTest(_test_ssh_close_callback_reentry, "_test_CloseCallbackReentryTester", "", __modname__), + "_test_KeyExchangeTimeoutTester": testing.EnvTest(_test_ssh_key_exchange_timeout, "_test_KeyExchangeTimeoutTester", "", __modname__), + "_test_ServerSurvivesGarbageTester": testing.EnvTest(_test_ssh_server_survives_garbage, "_test_ServerSurvivesGarbageTester", "", __modname__), + "_test_RapidDisconnectChurnTester": testing.EnvTest(_test_ssh_rapid_disconnect_churn, "_test_RapidDisconnectChurnTester", "", __modname__), + "_test_CloseRejectsLateWritesTester": testing.EnvTest(_test_ssh_close_rejects_late_writes, "_test_CloseRejectsLateWritesTester", "", __modname__), + "_test_SessionLimitTester": testing.EnvTest(_test_ssh_session_limit, "_test_SessionLimitTester", "", __modname__), + "_test_ChannelLimitTester": testing.EnvTest(_test_ssh_channel_limit, "_test_ChannelLimitTester", "", __modname__), +} + + +actor __test_main(env): + testing.test_runner(env, __unit_tests, __simple_sync_tests, __sync_tests, __async_tests, __env_tests) From 8e5be038c9e85c75e5c3ffe5d4cbd404ad00e272 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Fri, 20 Mar 2026 00:09:27 +0100 Subject: [PATCH 32/38] Bound key exchange phase Attached server sessions could remain in key exchange forever. The pre-attach timer was stopped as soon as the session actor attached, but the auth timer was only started after key exchange completed, so a peer that stalled between those phases had no timeout at all. This change keeps the attach timer scoped to the pre-attach phase, starts a dedicated timeout for attached key exchange, and only arms the auth timer after key exchange succeeds. That is correct because the server now enforces a timeout for each session phase without overlapping their responsibilities, so a client cannot evade both limits by attaching and then never completing key exchange. --- src/ssh.ext.c | 43 +++++++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index e2abc53..18ebc9e 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -85,6 +85,7 @@ extern struct $Cont $Done$instance; #define SSH_READ_BUFSIZE 4096 #define SSH_IO_PUMP_LIMIT 128 #define SSH_ATTACH_TIMEOUT_SEC 5.0 +#define SSH_KEYEX_TIMEOUT_SEC 2.0 #define SSH_SERVER_ACCEPT_LIMIT 64 static int ssh_debug_enabled = 0; static int ssh_libssh_log_level = SSH_LOG_NOLOG; @@ -328,6 +329,7 @@ static void session_finish_close(ssh_server_session_ctx *s); static void session_maybe_release(ssh_server_session_ctx *s); static void session_fail(ssh_server_session_ctx *s, const char *msg); static void session_start_attach_timer(ssh_server_session_ctx *s); +static void session_start_keyex_timer(ssh_server_session_ctx *s); static int session_start_poll(ssh_server_session_ctx *s, char *errmsg, size_t errmsg_len); static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch); static int session_needs_write(ssh_server_session_ctx *s); @@ -2546,24 +2548,42 @@ static int session_check_reply_rc(ssh_server_session_ctx *s, int rc, const char return -1; } +static void session_restart_auth_timer(ssh_server_session_ctx *s, double timeout_sec) { + if (s == NULL) + return; + if (timeout_sec <= 0.0) + return; + if (s->auth_timer == NULL) { + s->auth_timer = acton_calloc(1, sizeof(uv_timer_t)); + s->auth_timer->data = s; + uv_timer_init(get_uv_loop(), s->auth_timer); + } else { + uv_timer_stop(s->auth_timer); + } + uv_timer_start(s->auth_timer, session_auth_timeout_cb, + (uint64_t)(timeout_sec * 1000.0), 0); +} + static void session_start_attach_timer(ssh_server_session_ctx *s) { if (s == NULL || s->attach_timer != NULL) return; - double timeout = s->auth_timeout > 0.0 ? s->auth_timeout : SSH_ATTACH_TIMEOUT_SEC; s->attach_timer = acton_calloc(1, sizeof(uv_timer_t)); s->attach_timer->data = s; uv_timer_init(get_uv_loop(), s->attach_timer); uv_timer_start(s->attach_timer, session_auth_timeout_cb, - (uint64_t)(timeout * 1000.0), 0); + (uint64_t)(SSH_ATTACH_TIMEOUT_SEC * 1000.0), 0); } -static void session_start_auth_timer(ssh_server_session_ctx *s) { - if (s == NULL || s->auth_timeout <= 0.0 || s->auth_timer != NULL) +static void session_start_keyex_timer(ssh_server_session_ctx *s) { + if (s == NULL) return; - s->auth_timer = acton_calloc(1, sizeof(uv_timer_t)); - s->auth_timer->data = s; - uv_timer_init(get_uv_loop(), s->auth_timer); - uv_timer_start(s->auth_timer, session_auth_timeout_cb, (uint64_t)(s->auth_timeout * 1000.0), 0); + double timeout = s->auth_timeout > SSH_KEYEX_TIMEOUT_SEC ? + s->auth_timeout : SSH_KEYEX_TIMEOUT_SEC; + session_restart_auth_timer(s, timeout); +} + +static void session_start_auth_timer(ssh_server_session_ctx *s) { + session_restart_auth_timer(s, s != NULL ? s->auth_timeout : 0.0); } static void session_start_keepalive(ssh_server_session_ctx *s) { @@ -2770,6 +2790,7 @@ static void session_drive(ssh_server_session_ctx *s) { int rc = ssh_handle_key_exchange(s->session); if (rc == SSH_OK) { s->state = SESSION_STATE_AUTH; + stop_timer(&s->attach_timer, session_timer_close_cb); session_start_auth_timer(s); ssh_set_auth_methods(s->session, SSH_AUTH_METHOD_PASSWORD); continue; @@ -3496,7 +3517,6 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o session_close_internal(s, errmsg, 1); return $R_CONT(c$cont, B_None); } - stop_timer(&s->attach_timer, session_timer_close_cb); s->actor = self; self->_session_id = session_id; s->attached = 1; @@ -3508,7 +3528,10 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o s->auth_timeout = fromB_float(self->server->_auth_timeout); s->keepalive_interval = fromB_float(self->server->_keepalive_interval); s->keepalive_enabled = fromB_bool(self->server->_keepalive_enabled) ? 1 : 0; - session_start_auth_timer(s); + if (s->state == SESSION_STATE_KEYEX) { + stop_timer(&s->attach_timer, session_timer_close_cb); + session_start_keyex_timer(s); + } if (ssh_debug_enabled) { ssh_debug_log("server session attach: callbacks set, driving session"); } From d8f1add606aabfed308a56c6588ebe2d01dfbad7 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Fri, 20 Mar 2026 00:33:42 +0100 Subject: [PATCH 33/38] Harden key exchange watchdog The key exchange timeout stress case could report a failure even when the server closed the session for the expected reason. Under heavy stress, the test's own watchdog could win a narrow race against the close callback and turn a correct late close into a spurious error. This change gives the test a grace window before it declares failure. The key exchange timeout still has to happen, but the harness now allows for scheduler and callback latency under high concurrency. That is correct because the test is supposed to prove the timeout state machine, not enforce a tight wall-clock deadline on when the callback is delivered under load. --- src/test_ssh_server.act | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 9bd7b9b..58d99c7 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -2293,7 +2293,7 @@ actor RunCommandTimeoutTester(t: testing.EnvT): ) after 0: start_server() - after 5.0: on_timeout() + after 10.0: on_timeout() def _test_ssh_runcommand_timeout(t: testing.EnvT): @@ -3754,6 +3754,7 @@ actor KeyExchangeTimeoutTester(t: testing.EnvT): var server_closed = False var raw_close_requested = False var server_close_requested = False + var timeout_grace = False def request_raw_close(): if raw_close_requested: @@ -3788,6 +3789,10 @@ actor KeyExchangeTimeoutTester(t: testing.EnvT): def on_timeout(): if done: return + if not timeout_grace: + timeout_grace = True + after 2.0: on_timeout() + return finish_error( "timeout waiting for key exchange timeout test " "(raw_connected=" + str(raw_connected) + From 17fbaf6456d9331166c39bd54a5d1489f2dc4245 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Fri, 20 Mar 2026 02:01:59 +0100 Subject: [PATCH 34/38] Harden SSH stress edges Long stress still had three sharp edges after the earlier transport work. The server handed a generated host key to ssh_bind and then kept freeing the same key itself, channel drive could synthesize EOF from ssh_channel_is_eof() before libssh had finished delivering buffered data callbacks, and the longest stress cases still burned time on bind collisions and outer test timeouts rather than on the transport they were meant to exercise. This change clears the wrapper host-key pointer once IMPORT_KEY hands ownership to ssh_bind, removes eager EOF notification from the live client and server channel drive loops, and widens the retry and timeout budget for the long-running concurrent-write and key-exchange timeout stress cases. That matches libssh ownership and callback contracts more closely and keeps whole-module stress focused on real SSH failures instead of spurious double frees, premature EOF delivery, or avoidable bind and timing noise. --- src/ssh.ext.c | 21 ++------------------- src/test_ssh_server.act | 24 +++++++++++++++++------- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 18ebc9e..38cac23 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -1204,7 +1204,6 @@ static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { break; } } - channel_notify_eof(ch); } if (ch->channel != NULL && ch->remote_close_seen && @@ -2282,23 +2281,6 @@ static int server_channel_setup_callbacks(ssh_server_channel_ctx *ch) { return SSH_OK; } -static void server_channel_notify_eof(ssh_server_channel_ctx *ch) { - if (ch->channel == NULL) - return; - if (ssh_channel_is_eof(ch->channel)) { - if (!ch->stdout_eof && ch->on_data) { - $action2 f = ($action2)ch->on_data; - f->$class->__asyn__(f, ch->actor, B_None); - ch->stdout_eof = 1; - } - if (!ch->stderr_eof && ch->on_stderr) { - $action2 f = ($action2)ch->on_stderr; - f->$class->__asyn__(f, ch->actor, B_None); - ch->stderr_eof = 1; - } - } -} - static void server_channel_finalize(ssh_server_channel_ctx *ch) { while (ch->write_head != NULL) { server_write_chunk_t *chunk = ch->write_head; @@ -2499,7 +2481,6 @@ static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_c break; } } - server_channel_notify_eof(ch); if (ch->channel != NULL && ch->remote_close_seen && ssh_channel_is_closed(ch->channel)) { @@ -3435,6 +3416,8 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o server_fail(s, "Failed to set host key"); return $R_CONT(c$cont, B_None); } + /* ssh_bind takes ownership of IMPORT_KEY and frees it via ssh_bind_free(). */ + s->hostkey = NULL; ssh_bind_set_blocking(s->bind, 0); diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 58d99c7..daf6f06 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -5,6 +5,16 @@ import logging import net import ssh +LISTEN_RETRY_LIMIT = 16 + + +def _pick_test_port(base: int, span: int, attempt: int, salt: int) -> int: + return base + (random.randint(0, span - 1) + attempt * salt) % span + + +def _is_retryable_listen_error(err: str) -> bool: + return err.find("Address already in use") >= 0 + actor ServerClientTester(t: testing.EnvT): log = logging.Logger(t.log_handler) @@ -326,7 +336,7 @@ actor ClientCloseFlushTester(t: testing.EnvT): server = s if err is not None: log.info("server listen error: " + err, None) - if attempts < 5: + if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT: request_server_close() start_server() return @@ -1424,7 +1434,7 @@ actor ConcurrentChannelWriteTester(t: testing.EnvT): def start_server(): attempts += 1 - port = 56000 + (random.randint(0, 2000) + attempts * 113) % 2000 + port = _pick_test_port(35000, 25000, attempts, 113) server_close_requested = False server = ssh.Server( net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), @@ -1441,7 +1451,7 @@ actor ConcurrentChannelWriteTester(t: testing.EnvT): ) after 0: start_server() - after 10.0: on_timeout() + after 20.0: on_timeout() def _test_ssh_concurrent_channel_writes(t: testing.EnvT): @@ -3791,7 +3801,7 @@ actor KeyExchangeTimeoutTester(t: testing.EnvT): return if not timeout_grace: timeout_grace = True - after 2.0: on_timeout() + after 5.0: on_timeout() return finish_error( "timeout waiting for key exchange timeout test " @@ -3806,7 +3816,7 @@ actor KeyExchangeTimeoutTester(t: testing.EnvT): server = s if err is not None: log.info("server listen error: " + err, None) - if attempts < 5: + if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT: request_server_close() start_server() return @@ -3876,7 +3886,7 @@ actor KeyExchangeTimeoutTester(t: testing.EnvT): def start_server(): attempts += 1 - port = 53000 + (random.randint(0, 3000) + attempts * 179) % 3000 + port = _pick_test_port(35000, 25000, attempts, 179) server_close_requested = False server_closed = False session_ready = False @@ -3897,7 +3907,7 @@ actor KeyExchangeTimeoutTester(t: testing.EnvT): ) after 0: start_server() - after 5.0: on_timeout() + after 10.0: on_timeout() def _test_ssh_key_exchange_timeout(t: testing.EnvT): From de40eb7edb6aa50ed613d26e741f97d80117ce02 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Fri, 20 Mar 2026 03:51:02 +0100 Subject: [PATCH 35/38] Add SSH cleanup hooks SSH client and server actors could be dropped without an explicit close, leaving native state alive with no cleanup path from the Acton edge. The binding also needed terminal callbacks to see a live wrapper long enough to clear native ids before user code could re-enter those handles. This change adds __cleanup__ hooks to the public SSH wrapper actors, native cleanup entry points in the C binding, and opt-in GC cleanup probes for client and server wrappers. It also stores client and server backpointers as hidden GC pointers and clears exposed native ids before terminal callbacks run. Explicit close remains the primary resource-management path, but abandoned top-level SSH wrappers now release native state eventually instead of leaking it. Clearing ids before close callbacks also makes late re-entry on terminal events a safe no-op instead of touching freed native handles. --- src/ssh.act | 35 ++++ src/ssh.ext.c | 212 ++++++++++++++++++------- src/test_ssh_server.act | 344 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 535 insertions(+), 56 deletions(-) diff --git a/src/ssh.act b/src/ssh.act index a47a1ba..601c9a3 100644 --- a/src/ssh.act +++ b/src/ssh.act @@ -120,6 +120,13 @@ actor Client(cap: net.TCPConnectCap, """Close the SSH client""" NotImplemented + proc def _cleanup_native() -> None: + NotImplemented + + action def __cleanup__() -> None: + if _client != 0: + _cleanup_native() + # Internal channel operations (used by Channel actor) action def channel_create(channel: Channel, on_open: action(Channel, ?str) -> None, @@ -203,6 +210,13 @@ actor Channel(client: Client, """Close the channel and discard unread inbound data""" client.channel_close(self) + proc def _cleanup_native() -> None: + NotImplemented + + action def __cleanup__() -> None: + if _channel_id != 0: + _cleanup_native() + actor RunCommand(client: Client, cmd: str, @@ -335,6 +349,13 @@ actor Server(cap: net.TCPListenCap, """Close the SSH server""" NotImplemented + proc def _cleanup_native() -> None: + NotImplemented + + action def __cleanup__() -> None: + if _server != 0: + _cleanup_native() + action def on_session_pending(session_id: u64) -> None: _debug("server on_session_pending: " + str(session_id)) ServerSession(self, @@ -409,6 +430,13 @@ actor ServerSession(server: Server, """Close the SSH session""" NotImplemented + proc def _cleanup_native() -> None: + NotImplemented + + action def __cleanup__() -> None: + if _session_id != 0: + _cleanup_native() + # Internal channel operations (used by ServerChannel actor) action def channel_accept_request(channel: ServerChannel) -> None: NotImplemented @@ -475,6 +503,13 @@ actor ServerChannel(session: ServerSession, """Close the channel and discard unread inbound data""" session.channel_close(self) + proc def _cleanup_native() -> None: + NotImplemented + + action def __cleanup__() -> None: + if _channel_id != 0: + _cleanup_native() + actor main(env): NETCONF_HELLO = b'\n' + \ diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 38cac23..a4cc2e6 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -164,7 +164,7 @@ typedef struct ssh_channel_ctx { } ssh_channel_ctx; typedef struct ssh_client_ctx { - sshQ_Client actor; + GC_hidden_pointer actor; ssh_session session; uv_poll_t *poll; int poll_events; @@ -287,7 +287,7 @@ typedef struct ssh_server_session_ctx { } ssh_server_session_ctx; typedef struct ssh_server_ctx { - sshQ_Server actor; + GC_hidden_pointer actor; ssh_bind bind; ssh_key hostkey; uv_poll_t *poll; @@ -343,6 +343,31 @@ static void server_poll_close_cb(uv_handle_t *handle); static void session_poll_close_cb(uv_handle_t *handle); static void session_timer_close_cb(uv_handle_t *handle); +#define STORE_HIDDEN_PTR(slot, ptr) \ + ((slot) = (ptr) ? (GC_hidden_pointer)GC_HIDE_POINTER(ptr) : (GC_hidden_pointer)0) +#define LOAD_HIDDEN_PTR(type, slot) \ + ((slot) ? (type)GC_REVEAL_POINTER(slot) : NULL) + +static sshQ_Client client_actor_ref(const ssh_client_ctx *c) { + return c ? LOAD_HIDDEN_PTR(sshQ_Client, c->actor) : NULL; +} + +static sshQ_Channel channel_actor_ref(const ssh_channel_ctx *ch) { + return ch ? ch->actor : NULL; +} + +static sshQ_Server server_actor_ref(const ssh_server_ctx *s) { + return s ? LOAD_HIDDEN_PTR(sshQ_Server, s->actor) : NULL; +} + +static sshQ_ServerSession session_actor_ref(const ssh_server_session_ctx *s) { + return s ? s->actor : NULL; +} + +static sshQ_ServerChannel server_channel_actor_ref(const ssh_server_channel_ctx *ch) { + return ch ? ch->actor : NULL; +} + static void ssh_debug_log(const char *fmt, ...) { if (!ssh_debug_enabled) return; @@ -737,9 +762,10 @@ static void client_notify_connect(ssh_client_ctx *c, const char *err) { return; if (c->connect_notified) return; + sshQ_Client actor = client_actor_ref(c); if (c->on_connect) { $action2 f = ($action2)c->on_connect; - f->$class->__asyn__(f, c->actor, err ? to$str((char *)err) : B_None); + f->$class->__asyn__(f, actor, err ? to$str((char *)err) : B_None); } c->connect_notified = 1; if (err == NULL) @@ -751,9 +777,10 @@ static void client_notify_close(ssh_client_ctx *c, const char *reason) { return; if (!c->connected_ok) return; + sshQ_Client actor = client_actor_ref(c); if (c->on_close) { $action2 f = ($action2)c->on_close; - f->$class->__asyn__(f, c->actor, to$str((char *)reason)); + f->$class->__asyn__(f, actor, to$str((char *)reason)); } c->close_notified = 1; } @@ -770,9 +797,10 @@ static void client_fail(ssh_client_ctx *c, const char *msg) { static void channel_notify_open(ssh_channel_ctx *ch, const char *err) { if (ch->open_notified) return; + sshQ_Channel actor = channel_actor_ref(ch); if (ch->on_open) { $action2 f = ($action2)ch->on_open; - f->$class->__asyn__(f, ch->actor, err ? to$str((char *)err) : B_None); + f->$class->__asyn__(f, actor, err ? to$str((char *)err) : B_None); } ch->open_notified = 1; if (err == NULL) @@ -786,9 +814,10 @@ static void channel_notify_close(ssh_channel_ctx *ch, const char *reason) { ch->close_notified = 1; return; } + sshQ_Channel actor = channel_actor_ref(ch); if (ch->on_close) { $action2 f = ($action2)ch->on_close; - f->$class->__asyn__(f, ch->actor, to$str((char *)reason)); + f->$class->__asyn__(f, actor, to$str((char *)reason)); } ch->close_notified = 1; } @@ -803,9 +832,10 @@ static void channel_notify_error(ssh_channel_ctx *ch, const char *msg) { static void channel_notify_exit(ssh_channel_ctx *ch, int exit_status, B_str signal) { if (ch->exit_sent) return; + sshQ_Channel actor = channel_actor_ref(ch); if (ch->on_exit) { $action3 f = ($action3)ch->on_exit; - f->$class->__asyn__(f, ch->actor, toB_int(exit_status), signal); + f->$class->__asyn__(f, actor, toB_int(exit_status), signal); } ch->exit_sent = 1; } @@ -819,16 +849,17 @@ static int client_channel_data_cb(ssh_session session, ssh_channel channel, void return 0; if (len == 0) return 0; - B_bytes out = to$bytesD_len((const char *)data, (size_t)len); + B_bytes out = to$bytesD_len((char *)data, (size_t)len); + sshQ_Channel actor = channel_actor_ref(ch); if (is_stderr) { if (ch->on_stderr) { $action2 f = ($action2)ch->on_stderr; - f->$class->__asyn__(f, ch->actor, out); + f->$class->__asyn__(f, actor, out); } } else { if (ch->on_stdout) { $action2 f = ($action2)ch->on_stdout; - f->$class->__asyn__(f, ch->actor, out); + f->$class->__asyn__(f, actor, out); } } return (int)len; @@ -840,14 +871,15 @@ static void client_channel_eof_cb(ssh_session session, ssh_channel channel, void (void)channel; if (ch == NULL) return; + sshQ_Channel actor = channel_actor_ref(ch); if (!ch->stdout_eof && ch->on_stdout) { $action2 f = ($action2)ch->on_stdout; - f->$class->__asyn__(f, ch->actor, B_None); + f->$class->__asyn__(f, actor, B_None); ch->stdout_eof = 1; } if (!ch->stderr_eof && ch->on_stderr) { $action2 f = ($action2)ch->on_stderr; - f->$class->__asyn__(f, ch->actor, B_None); + f->$class->__asyn__(f, actor, B_None); ch->stderr_eof = 1; } } @@ -902,14 +934,15 @@ static void channel_notify_eof(ssh_channel_ctx *ch) { if (ch->channel == NULL) return; if (ssh_channel_is_eof(ch->channel)) { + sshQ_Channel actor = channel_actor_ref(ch); if (!ch->stdout_eof && ch->on_stdout) { $action2 f = ($action2)ch->on_stdout; - f->$class->__asyn__(f, ch->actor, B_None); + f->$class->__asyn__(f, actor, B_None); ch->stdout_eof = 1; } if (!ch->stderr_eof && ch->on_stderr) { $action2 f = ($action2)ch->on_stderr; - f->$class->__asyn__(f, ch->actor, B_None); + f->$class->__asyn__(f, actor, B_None); ch->stderr_eof = 1; } } @@ -918,6 +951,7 @@ static void channel_notify_eof(ssh_channel_ctx *ch) { static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { int exit_status = -1; B_str exit_signal = B_None; + sshQ_Channel actor = channel_actor_ref(ch); while (ch->write_head != NULL) { write_chunk_t *chunk = ch->write_head; @@ -953,21 +987,22 @@ static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { ch->channel = NULL; } ch->state = CHAN_STATE_CLOSED; - if (ch->actor) - ch->actor->_channel_id = toB_u64(0); + if (actor) + actor->_channel_id = toB_u64(0); channel_notify_exit(ch, exit_status, exit_signal); if (!ch->stdout_eof && ch->on_stdout) { $action2 f = ($action2)ch->on_stdout; - f->$class->__asyn__(f, ch->actor, B_None); + f->$class->__asyn__(f, actor, B_None); ch->stdout_eof = 1; } if (!ch->stderr_eof && ch->on_stderr) { $action2 f = ($action2)ch->on_stderr; - f->$class->__asyn__(f, ch->actor, B_None); + f->$class->__asyn__(f, actor, B_None); ch->stderr_eof = 1; } channel_notify_close(ch, "closed"); + ch->actor = NULL; (void)c; } @@ -1041,15 +1076,16 @@ static int channel_read_stream(ssh_client_ctx *c, ssh_channel_ctx *ch, int is_st if (n > 0) { read_any = 1; B_bytes out = to$bytesD_len(buf, n); + sshQ_Channel actor = channel_actor_ref(ch); if (is_stderr) { if (ch->on_stderr) { $action2 f = ($action2)ch->on_stderr; - f->$class->__asyn__(f, ch->actor, out); + f->$class->__asyn__(f, actor, out); } } else { if (ch->on_stdout) { $action2 f = ($action2)ch->on_stdout; - f->$class->__asyn__(f, ch->actor, out); + f->$class->__asyn__(f, actor, out); } } continue; @@ -1270,8 +1306,9 @@ static int client_get_hostkey_info(ssh_client_ctx *c, B_str *key_type_out, B_str static int client_check_hostkey(ssh_client_ctx *c) { enum ssh_known_hosts_e state = SSH_KNOWN_HOSTS_UNKNOWN; int use_known_hosts = 0; + sshQ_Client actor = client_actor_ref(c); - if (c->actor != NULL && c->actor->known_hosts != B_None) + if (actor != NULL && actor->known_hosts != B_None) use_known_hosts = 1; if (use_known_hosts) { @@ -1305,7 +1342,7 @@ static int client_check_hostkey(ssh_client_ctx *c) { sshQ_HostKeyInfo info = sshQ_HostKeyInfoG_new(key_type, fingerprint); $action3 f = ($action3)c->on_hostkey; - f->$class->__asyn__(f, c->actor, to$str((char *)hostkey_state_str(state)), info); + f->$class->__asyn__(f, actor, to$str((char *)hostkey_state_str(state)), info); return 1; } @@ -1595,11 +1632,12 @@ static void client_drive(ssh_client_ctx *c) { } if (c->state == CLIENT_STATE_AUTH) { - if (c->actor->password == B_None) { + sshQ_Client actor = client_actor_ref(c); + if (actor == NULL || actor->password == B_None) { client_fail(c, "Password auth requested but no password provided"); return; } - int rc = ssh_userauth_password(c->session, NULL, (const char *)fromB_str(c->actor->password)); + int rc = ssh_userauth_password(c->session, NULL, (const char *)fromB_str(actor->password)); if (rc == SSH_AUTH_SUCCESS) { client_on_ready(c); return; @@ -1646,8 +1684,10 @@ static void client_finalize(ssh_client_ctx *c) { client_notify_close(c, c->close_reason ? c->close_reason : "closed"); c->state = CLIENT_STATE_CLOSED; - if (c->actor) - c->actor->_client = toB_u64(0); + sshQ_Client actor = client_actor_ref(c); + if (actor) + actor->_client = toB_u64(0); + STORE_HIDDEN_PTR(c->actor, NULL); client_maybe_release(c); } @@ -1788,7 +1828,7 @@ B_NoneType sshQ__debug(B_str msg) { $R sshQ_ClientD__initG_local(sshQ_Client self, $Cont c$cont) { ssh_configure_libssh_logging(); ssh_client_ctx *c = acton_calloc(1, sizeof(ssh_client_ctx)); - c->actor = self; + STORE_HIDDEN_PTR(c->actor, self); c->on_connect = ($action2)self->on_connect; c->on_close = ($action2)self->on_close; c->on_hostkey = ($action3)self->on_hostkey; @@ -1935,6 +1975,13 @@ B_NoneType sshQ__debug(B_str msg) { return $R_CONT(c$cont, B_None); } +$R sshQ_ClientD__cleanup_nativeG_local(sshQ_Client self, $Cont c$cont) { + ssh_client_ctx *c = client_from_actor(self); + if (c != NULL) + client_close_internal(c, "collected", 1); + return $R_CONT(c$cont, B_None); +} + $R sshQ_ClientD_channel_createG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, $action on_open, $action on_stdout, @@ -2145,6 +2192,19 @@ static int server_channel_validate(ssh_server_session_ctx *s, ssh_server_channel return $R_CONT(c$cont, B_None); } +$R sshQ_ChannelD__cleanup_nativeG_local(sshQ_Channel self, $Cont c$cont) { + ssh_channel_ctx *ch = channel_from_actor(self); + if (ch == NULL || ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR) + return $R_CONT(c$cont, B_None); + ssh_client_ctx *c = ch->client; + if (c == NULL) + return $R_CONT(c$cont, B_None); + ch->send_eof = 1; + ch->close_requested = 1; + client_drive(c); + return $R_CONT(c$cont, B_None); +} + // --- Server implementation static void server_notify_listen(ssh_server_ctx *s, const char *err) { @@ -2152,9 +2212,10 @@ static void server_notify_listen(ssh_server_ctx *s, const char *err) { return; if (s->listen_notified) return; + sshQ_Server actor = server_actor_ref(s); if (s->on_listen) { $action2 f = ($action2)s->on_listen; - f->$class->__asyn__(f, s->actor, err ? to$str((char *)err) : B_None); + f->$class->__asyn__(f, actor, err ? to$str((char *)err) : B_None); } s->listen_notified = 1; if (err == NULL) @@ -2166,9 +2227,10 @@ static void server_notify_close(ssh_server_ctx *s, const char *reason) { return; if (!s->listen_ok) return; + sshQ_Server actor = server_actor_ref(s); if (s->on_close) { $action2 f = ($action2)s->on_close; - f->$class->__asyn__(f, s->actor, to$str((char *)reason)); + f->$class->__asyn__(f, actor, to$str((char *)reason)); } s->close_notified = 1; } @@ -2176,9 +2238,10 @@ static void server_notify_close(ssh_server_ctx *s, const char *reason) { static void session_notify_close(ssh_server_session_ctx *s, const char *reason) { if (s->close_notified) return; + sshQ_ServerSession actor = session_actor_ref(s); if (s->on_close) { $action2 f = ($action2)s->on_close; - f->$class->__asyn__(f, s->actor, to$str((char *)reason)); + f->$class->__asyn__(f, actor, to$str((char *)reason)); } s->close_notified = 1; } @@ -2186,9 +2249,10 @@ static void session_notify_close(ssh_server_session_ctx *s, const char *reason) static void server_channel_notify_close(ssh_server_channel_ctx *ch, const char *reason) { if (ch->close_notified) return; + sshQ_ServerChannel actor = server_channel_actor_ref(ch); if (ch->on_close) { $action2 f = ($action2)ch->on_close; - f->$class->__asyn__(f, ch->actor, to$str((char *)reason)); + f->$class->__asyn__(f, actor, to$str((char *)reason)); } ch->close_notified = 1; } @@ -2202,16 +2266,17 @@ static int server_channel_data_cb(ssh_session session, ssh_channel channel, void return 0; if (len == 0) return 0; - B_bytes out = to$bytesD_len((const char *)data, (size_t)len); + B_bytes out = to$bytesD_len((char *)data, (size_t)len); + sshQ_ServerChannel actor = server_channel_actor_ref(ch); if (is_stderr) { if (ch->on_stderr) { $action2 f = ($action2)ch->on_stderr; - f->$class->__asyn__(f, ch->actor, out); + f->$class->__asyn__(f, actor, out); } } else { if (ch->on_data) { $action2 f = ($action2)ch->on_data; - f->$class->__asyn__(f, ch->actor, out); + f->$class->__asyn__(f, actor, out); } } return (int)len; @@ -2223,14 +2288,15 @@ static void server_channel_eof_cb(ssh_session session, ssh_channel channel, void (void)channel; if (ch == NULL) return; + sshQ_ServerChannel actor = server_channel_actor_ref(ch); if (!ch->stdout_eof && ch->on_data) { $action2 f = ($action2)ch->on_data; - f->$class->__asyn__(f, ch->actor, B_None); + f->$class->__asyn__(f, actor, B_None); ch->stdout_eof = 1; } if (!ch->stderr_eof && ch->on_stderr) { $action2 f = ($action2)ch->on_stderr; - f->$class->__asyn__(f, ch->actor, B_None); + f->$class->__asyn__(f, actor, B_None); ch->stderr_eof = 1; } } @@ -2282,6 +2348,7 @@ static int server_channel_setup_callbacks(ssh_server_channel_ctx *ch) { } static void server_channel_finalize(ssh_server_channel_ctx *ch) { + sshQ_ServerChannel actor = server_channel_actor_ref(ch); while (ch->write_head != NULL) { server_write_chunk_t *chunk = ch->write_head; ch->write_head = chunk->next; @@ -2304,19 +2371,20 @@ static void server_channel_finalize(ssh_server_channel_ctx *ch) { ch->channel = NULL; } ch->state = SCHAN_STATE_CLOSED; - if (ch->actor) - ch->actor->_channel_id = toB_u64(0); + if (actor) + actor->_channel_id = toB_u64(0); if (!ch->stdout_eof && ch->on_data) { $action2 f = ($action2)ch->on_data; - f->$class->__asyn__(f, ch->actor, B_None); + f->$class->__asyn__(f, actor, B_None); ch->stdout_eof = 1; } if (!ch->stderr_eof && ch->on_stderr) { $action2 f = ($action2)ch->on_stderr; - f->$class->__asyn__(f, ch->actor, B_None); + f->$class->__asyn__(f, actor, B_None); ch->stderr_eof = 1; } server_channel_notify_close(ch, "closed"); + ch->actor = NULL; } static void server_channel_queue_write(ssh_server_channel_ctx *ch, B_bytes data, int is_stderr) { @@ -2396,15 +2464,16 @@ static int server_channel_read_stream(ssh_server_session_ctx *s, ssh_server_chan if (n > 0) { read_any = 1; B_bytes out = to$bytesD_len(buf, n); + sshQ_ServerChannel actor = server_channel_actor_ref(ch); if (is_stderr) { if (ch->on_stderr) { $action2 f = ($action2)ch->on_stderr; - f->$class->__asyn__(f, ch->actor, out); + f->$class->__asyn__(f, actor, out); } } else { if (ch->on_data) { $action2 f = ($action2)ch->on_data; - f->$class->__asyn__(f, ch->actor, out); + f->$class->__asyn__(f, actor, out); } } continue; @@ -2845,7 +2914,7 @@ static void session_drive(ssh_server_session_ctx *s) { pass ? to$str((char *)(pass)) : B_None, B_None); $action2 f = ($action2)s->on_auth; - f->$class->__asyn__(f, s->actor, req); + f->$class->__asyn__(f, session_actor_ref(s), req); session_update_poll(s); return; } @@ -2895,7 +2964,7 @@ static void session_drive(ssh_server_session_ctx *s) { } else { s->pending_channel_open = msg; $action f = ($action)s->on_channel_open; - f->$class->__asyn__(f, s->actor); + f->$class->__asyn__(f, session_actor_ref(s)); break; } } else if (type == SSH_REQUEST_CHANNEL) { @@ -2917,7 +2986,7 @@ static void session_drive(ssh_server_session_ctx *s) { ch->pending_req = msg; ch->pending_req_type = SCHAN_REQ_EXEC; $action3 f = ($action3)s->on_exec; - f->$class->__asyn__(f, s->actor, ch->actor, + f->$class->__asyn__(f, session_actor_ref(s), server_channel_actor_ref(ch), to$str((char *)(cmd ? cmd : ""))); break; } @@ -2932,7 +3001,7 @@ static void session_drive(ssh_server_session_ctx *s) { ch->pending_req = msg; ch->pending_req_type = SCHAN_REQ_SUBSYSTEM; $action3 f = ($action3)s->on_subsystem; - f->$class->__asyn__(f, s->actor, ch->actor, + f->$class->__asyn__(f, session_actor_ref(s), server_channel_actor_ref(ch), to$str((char *)(name ? name : ""))); break; } @@ -3069,7 +3138,7 @@ static void server_accept(ssh_server_ctx *s) { } if (ssh_debug_enabled) { - ssh_debug_log("server accept: actor=%p", (void *)s->actor); + ssh_debug_log("server accept: actor=%p", (void *)server_actor_ref(s)); } int rc = ssh_bind_accept_fd(s->bind, session, fd); @@ -3088,8 +3157,9 @@ static void server_accept(ssh_server_ctx *s) { sess->session = session; sess->state = SESSION_STATE_KEYEX; sess->fd = ssh_get_fd(session); - sess->owner_wt = s->actor ? (int)s->actor->$affinity : 0; - sess->auth_timeout = s->actor ? fromB_float(((sshQ_Server)s->actor)->_auth_timeout) : 0.0; + sshQ_Server act = server_actor_ref(s); + sess->owner_wt = act ? (int)act->$affinity : 0; + sess->auth_timeout = act ? fromB_float(act->_auth_timeout) : 0.0; if (sess->fd < 0) { log_warn("SSH accept: failed to get accepted session fd"); ssh_disconnect(session); @@ -3105,8 +3175,7 @@ static void server_accept(ssh_server_ctx *s) { ssh_debug_log("server accept: session fd=%d", sess->fd); } - if (s->actor) { - sshQ_Server act = (sshQ_Server)s->actor; + if (act) { if (ssh_debug_enabled) { ssh_debug_log("server accept: scheduling session pending act=%p session=%p", (void *)act, (void *)sess); } @@ -3153,8 +3222,10 @@ static void server_finalize(ssh_server_ctx *s) { server_notify_close(s, s->close_reason ? s->close_reason : "closed"); s->state = SERVER_STATE_CLOSED; - if (s->actor) - s->actor->_server = toB_u64(0); + sshQ_Server actor = server_actor_ref(s); + if (actor) + actor->_server = toB_u64(0); + STORE_HIDDEN_PTR(s->actor, NULL); server_maybe_release(s); } @@ -3172,8 +3243,10 @@ static void session_finalize(ssh_server_session_ctx *s) { server_remove_session(s->server, s); session_notify_close(s, s->close_reason ? s->close_reason : "closed"); s->state = SESSION_STATE_CLOSED; - if (s->actor) - s->actor->_session_id = toB_u64(0); + sshQ_ServerSession actor = session_actor_ref(s); + if (actor) + actor->_session_id = toB_u64(0); + s->actor = NULL; session_maybe_release(s); } @@ -3344,7 +3417,7 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o log_info("ssh server init: self=%p", (void *)self); } ssh_server_ctx *s = acton_calloc(1, sizeof(ssh_server_ctx)); - s->actor = self; + STORE_HIDDEN_PTR(s->actor, self); s->on_listen = ($action2)self->_on_listen; s->on_close = ($action2)self->_on_close; s->state = SERVER_STATE_INIT; @@ -3470,6 +3543,13 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o return $R_CONT(c$cont, B_None); } +$R sshQ_ServerD__cleanup_nativeG_local(sshQ_Server self, $Cont c$cont) { + ssh_server_ctx *s = server_from_actor(self); + if (s != NULL) + server_close_internal(s, "collected"); + return $R_CONT(c$cont, B_None); +} + $R sshQ_ServerSessionD__pin_affinityG_local(sshQ_ServerSession self, $Cont c$cont) { ssh_server_session_ctx *s = (ssh_server_session_ctx *)(unsigned long)fromB_u64(self->session_id); if (s != NULL && s->owner_wt >= 0) { @@ -3664,6 +3744,13 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o return $R_CONT(c$cont, B_None); } +$R sshQ_ServerSessionD__cleanup_nativeG_local(sshQ_ServerSession self, $Cont c$cont) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s != NULL) + session_close_internal(s, "collected", 1); + return $R_CONT(c$cont, B_None); +} + $R sshQ_ServerSessionD_channel_accept_requestG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) { ssh_server_session_ctx *s = session_from_actor(self); ssh_server_channel_ctx *ch = server_channel_from_actor(channel); @@ -3753,3 +3840,16 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o session_drive(s); return $R_CONT(c$cont, B_None); } + +$R sshQ_ServerChannelD__cleanup_nativeG_local(sshQ_ServerChannel self, $Cont c$cont) { + ssh_server_channel_ctx *ch = server_channel_from_actor(self); + if (ch == NULL || ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) + return $R_CONT(c$cont, B_None); + ssh_server_session_ctx *s = ch->session; + if (s == NULL) + return $R_CONT(c$cont, B_None); + ch->send_eof = 1; + ch->close_requested = 1; + session_drive(s); + return $R_CONT(c$cont, B_None); +} diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index daf6f06..239548b 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -2,6 +2,7 @@ import random import testing import logging +import acton.rts import net import ssh @@ -5242,6 +5243,347 @@ def _test_ssh_channel_limit(t: testing.EnvT): ChannelLimitTester(t) +actor ServerCleanupTester(t: testing.EnvT): + t.require("gc_cleanup") + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var probe: ?ssh.Server = None + var server_listened = False + var server_collected = False + var probe_listened = False + var probe_closed = False + var gc_started = False + + def maybe_finish(): + if done: + return + if server_listened and server_collected and probe_listened and probe_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for GC-cleaned server " + "(server_listened=" + str(server_listened) + + ", server_collected=" + str(server_collected) + + ", probe_listened=" + str(probe_listened) + + ", probe_closed=" + str(probe_closed) + ")" + ) + + def drive_gc(): + if done: + return + blob = b"G" * (4 * 1024 * 1024) + work = 0 + for i in range(32768): + work += i + if len(blob) == work: + finish_error("unreachable GC drive path") + return + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + after 0.02: drive_gc() + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT: + start_server() + return + finish_error("server listen error: " + err) + return + server_listened = True + server = None + if not gc_started: + gc_started = True + after 0: drive_gc() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + if reason == "collected": + server_collected = True + if not probe_listened and probe is None: + start_probe() + maybe_finish() + + def on_probe_listen(s: ssh.Server, err: ?str): + probe = s + if err is not None: + log.info("probe listen error: " + err, None) + probe = None + if _is_retryable_listen_error(err): + after 0.1: start_probe() + return + finish_error("probe listen error: " + err) + return + probe_listened = True + s.close() + + def on_probe_close(s: ssh.Server, reason: str): + log.info("probe closed: " + reason, None) + probe_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + finish_error("unexpected session on GC-cleaned server test") + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + finish_error("unexpected auth request on GC-cleaned server test") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open on GC-cleaned server test") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request on GC-cleaned server test") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request on GC-cleaned server test") + + def on_session_close(sess: ssh.ServerSession, reason: str): + finish_error("unexpected session close on GC-cleaned server test") + + def start_probe(): + if done or probe_listened or probe is not None: + return + probe = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_probe_listen, + on_probe_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + def start_server(): + attempts += 1 + port = _pick_test_port(35000, 25000, attempts, 541) + server = None + probe = None + server_listened = False + server_collected = False + probe_listened = False + probe_closed = False + gc_started = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 60.0: on_timeout() + + +def _test_ssh_server_gc_cleanup(t: testing.EnvT): + """An unreachable server actor should eventually close itself.""" + ServerCleanupTester(t) + + +actor ClientCleanupTester(t: testing.EnvT): + t.require("gc_cleanup") + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_close_requested = False + var server_listened = False + var client_connected = False + var session_opened = False + var session_closed = False + var server_closed = False + var gc_started = False + + def maybe_finish(): + if done: + return + if server_listened and client_connected and session_opened and session_closed and server_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + if not server_close_requested and server is not None: + server_close_requested = True + server.close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for GC-cleaned client " + "(server_listened=" + str(server_listened) + + ", client_connected=" + str(client_connected) + + ", session_opened=" + str(session_opened) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def drive_gc(): + if done: + return + blob = b"G" * (4 * 1024 * 1024) + work = 0 + for i in range(32768): + work += i + if len(blob) == work: + finish_error("unreachable GC drive path") + return + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + after 0.02: drive_gc() + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT: + if not server_close_requested and server is not None: + server_close_requested = True + server.close() + start_server() + return + finish_error("server listen error: " + err) + return + server_listened = True + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session_opened = True + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open on GC-cleaned client test") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request on GC-cleaned client test") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request on GC-cleaned client test") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + if not server_close_requested and server is not None: + server_close_requested = True + server.close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client_connected = True + client = None + if not gc_started: + gc_started = True + after 0: drive_gc() + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + return + + def start_client(): + client = None + client_connected = False + session_opened = False + session_closed = False + gc_started = False + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = _pick_test_port(35000, 25000, attempts, 557) + client = None + server_close_requested = False + server_listened = False + client_connected = False + session_opened = False + session_closed = False + server_closed = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 20.0: on_timeout() + + +def _test_ssh_client_gc_cleanup(t: testing.EnvT): + """An unreachable client actor should eventually close itself.""" + ClientCleanupTester(t) + + __modname__ = "test_ssh_server" __unit_tests: dict[str, testing.UnitTest] = { @@ -5279,6 +5621,8 @@ __env_tests: dict[str, testing.EnvTest] = { "_test_CloseRejectsLateWritesTester": testing.EnvTest(_test_ssh_close_rejects_late_writes, "_test_CloseRejectsLateWritesTester", "", __modname__), "_test_SessionLimitTester": testing.EnvTest(_test_ssh_session_limit, "_test_SessionLimitTester", "", __modname__), "_test_ChannelLimitTester": testing.EnvTest(_test_ssh_channel_limit, "_test_ChannelLimitTester", "", __modname__), + "_test_ServerCleanupTester": testing.EnvTest(_test_ssh_server_gc_cleanup, "_test_ServerCleanupTester", "", __modname__), + "_test_ClientCleanupTester": testing.EnvTest(_test_ssh_client_gc_cleanup, "_test_ClientCleanupTester", "", __modname__), } From 88fbc37667e5ecfbd33d297e5bbceb414f749742 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Sat, 21 Mar 2026 08:54:20 +0100 Subject: [PATCH 36/38] Harden attach and channel teardown Server session attach still exposed stale native state across the deferred actor handoff. The server passed a raw session pointer through the pending-session callback, and ServerSession attach drove libssh before the server had been told that the session actor was ready. Under rapid disconnect churn that could reorder session close ahead of session ready, and it left the attach path dependent on a raw pointer remaining valid until the deferred attach ran. This change gives each pending server session a stable token, resolves affinity and attach against the server's live session list, and only drives the attached session after the ready callback has been queued. It also keeps closed client and server channel contexts on retired lists owned by the session or client and frees them only when the owning transport finalizes. That keeps the async attach handoff from chasing stale session pointers, preserves the expected ready-before-close ordering at the Acton boundary, and avoids reusing channel callback userdata while late close processing is still unwinding. --- src/ssh.act | 4 ++ src/ssh.ext.c | 120 +++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 107 insertions(+), 17 deletions(-) diff --git a/src/ssh.act b/src/ssh.act index 601c9a3..eb5b29e 100644 --- a/src/ssh.act +++ b/src/ssh.act @@ -394,6 +394,9 @@ actor ServerSession(server: Server, proc def _attach(session_id: u64) -> None: NotImplemented + proc def _drive_attached() -> None: + NotImplemented + action def _attach_ready() -> None: _debug("server session attach ready: " + str(session_id)) _attach(session_id) @@ -401,6 +404,7 @@ actor ServerSession(server: Server, if _session_id != 0: server.on_session_ready(self) _debug("server session attach: on_session_ready sent for " + str(session_id)) + _drive_attached() after 0: _attach_ready() diff --git a/src/ssh.ext.c b/src/ssh.ext.c index a4cc2e6..0b3ea10 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -186,6 +186,7 @@ typedef struct ssh_client_ctx { char *close_reason; enum ssh_known_hosts_e hostkey_state; ssh_channel_ctx *channels; + ssh_channel_ctx *retired_channels; $action2 on_connect; $action2 on_close; $action3 on_hostkey; @@ -275,10 +276,12 @@ typedef struct ssh_server_session_ctx { int close_notified; int close_finalized; int close_force; + uint64_t pending_id; char *close_reason; ssh_message pending_auth; ssh_message pending_channel_open; ssh_server_channel_ctx *channels; + ssh_server_channel_ctx *retired_channels; $action2 on_auth; $action on_channel_open; $action3 on_exec; @@ -463,6 +466,8 @@ static const char *hostkey_state_str(enum ssh_known_hosts_e state) { static ssh_client_ctx *client_from_actor(sshQ_Client self) { if (self == NULL) return NULL; + if (self->_client == NULL) + return NULL; unsigned long ptr = fromB_u64(self->_client); if (ptr == 0) return NULL; @@ -472,6 +477,8 @@ static ssh_client_ctx *client_from_actor(sshQ_Client self) { static ssh_channel_ctx *channel_from_actor(sshQ_Channel channel) { if (channel == NULL) return NULL; + if (channel->_channel_id == NULL) + return NULL; unsigned long ptr = fromB_u64(channel->_channel_id); if (ptr == 0) return NULL; @@ -481,15 +488,34 @@ static ssh_channel_ctx *channel_from_actor(sshQ_Channel channel) { static ssh_server_ctx *server_from_actor(sshQ_Server self) { if (self == NULL) return NULL; + if (self->_server == NULL) + return NULL; unsigned long ptr = fromB_u64(self->_server); if (ptr == 0) return NULL; return (ssh_server_ctx *)ptr; } +static ssh_server_session_ctx *session_from_pending_token(ssh_server_ctx *server, B_u64 session_id) { + if (server == NULL) + return NULL; + uint64_t token = fromB_u64(session_id); + if (token == 0) + return NULL; + ssh_server_session_ctx *cur = server->sessions; + while (cur != NULL) { + if (cur->pending_id == token) + return cur; + cur = cur->next; + } + return NULL; +} + static ssh_server_session_ctx *session_from_actor(sshQ_ServerSession self) { if (self == NULL) return NULL; + if (self->_session_id == NULL) + return NULL; unsigned long ptr = fromB_u64(self->_session_id); if (ptr == 0) return NULL; @@ -499,6 +525,8 @@ static ssh_server_session_ctx *session_from_actor(sshQ_ServerSession self) { static ssh_server_channel_ctx *server_channel_from_actor(sshQ_ServerChannel channel) { if (channel == NULL) return NULL; + if (channel->_channel_id == NULL) + return NULL; unsigned long ptr = fromB_u64(channel->_channel_id); if (ptr == 0) return NULL; @@ -605,6 +633,52 @@ static int session_has_pending_write(ssh_session session) { return (pending & SSH_WRITE_PENDING) != 0; } +static uint64_t next_pending_session_id = 1; + +static uint64_t alloc_pending_session_id(void) { + return __atomic_fetch_add(&next_pending_session_id, 1, __ATOMIC_RELAXED); +} + +static void client_retire_channel(ssh_client_ctx *c, ssh_channel_ctx *ch) { + if (c == NULL || ch == NULL) + return; + ch->next = c->retired_channels; + c->retired_channels = ch; +} + +static void client_free_retired_channels(ssh_client_ctx *c) { + if (c == NULL) + return; + ssh_channel_ctx *ch = c->retired_channels; + c->retired_channels = NULL; + while (ch != NULL) { + ssh_channel_ctx *next = ch->next; + ch->next = NULL; + acton_free(ch); + ch = next; + } +} + +static void session_retire_channel(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) { + if (s == NULL || ch == NULL) + return; + ch->next = s->retired_channels; + s->retired_channels = ch; +} + +static void session_free_retired_channels(ssh_server_session_ctx *s) { + if (s == NULL) + return; + ssh_server_channel_ctx *ch = s->retired_channels; + s->retired_channels = NULL; + while (ch != NULL) { + ssh_server_channel_ctx *next = ch->next; + ch->next = NULL; + acton_free(ch); + ch = next; + } +} + static void client_poll_close_cb(uv_handle_t *handle) { ssh_client_ctx *c = (ssh_client_ctx *)handle->data; if (c == NULL) { @@ -1260,8 +1334,7 @@ static void client_drive_channels(ssh_client_ctx *c) { } else { c->channels = next; } - ch->next = NULL; - acton_free(ch); + client_retire_channel(c, ch); } else { prev = ch; } @@ -1681,6 +1754,7 @@ static void client_finalize(ssh_client_ctx *c) { ssh_free(c->session); c->session = NULL; } + client_free_retired_channels(c); client_notify_close(c, c->close_reason ? c->close_reason : "closed"); c->state = CLIENT_STATE_CLOSED; @@ -1699,8 +1773,7 @@ static void client_abort_channels(ssh_client_ctx *c, int notify_channel_error) { channel_notify_error(ch, "Session closed"); channel_notify_eof(ch); channel_finalize(c, ch); - ch->next = NULL; - acton_free(ch); + client_retire_channel(c, ch); ch = next; } c->channels = NULL; @@ -2572,8 +2645,7 @@ static void session_drive_channels(ssh_server_session_ctx *s) { } else { s->channels = next; } - ch->next = NULL; - acton_free(ch); + session_retire_channel(s, ch); } else { prev = ch; } @@ -3156,6 +3228,7 @@ static void server_accept(ssh_server_ctx *s) { sess->server = s; sess->session = session; sess->state = SESSION_STATE_KEYEX; + sess->pending_id = alloc_pending_session_id(); sess->fd = ssh_get_fd(session); sshQ_Server act = server_actor_ref(s); sess->owner_wt = act ? (int)act->$affinity : 0; @@ -3177,11 +3250,13 @@ static void server_accept(ssh_server_ctx *s) { if (act) { if (ssh_debug_enabled) { - ssh_debug_log("server accept: scheduling session pending act=%p session=%p", (void *)act, (void *)sess); + ssh_debug_log("server accept: scheduling session pending act=%p session=%llu", + (void *)act, (unsigned long long)sess->pending_id); } - act->$class->on_session_pending(act, toB_u64((unsigned long)sess)); + act->$class->on_session_pending(act, toB_u64(sess->pending_id)); if (ssh_debug_enabled) { - ssh_debug_log("server accept: on_session_pending call returned session=%p", (void *)sess); + ssh_debug_log("server accept: on_session_pending call returned session=%llu", + (unsigned long long)sess->pending_id); } } } @@ -3239,10 +3314,12 @@ static void session_finalize(ssh_server_session_ctx *s) { ssh_free(s->session); s->session = NULL; } + session_free_retired_channels(s); server_remove_session(s->server, s); session_notify_close(s, s->close_reason ? s->close_reason : "closed"); s->state = SESSION_STATE_CLOSED; + s->pending_id = 0; sshQ_ServerSession actor = session_actor_ref(s); if (actor) actor->_session_id = toB_u64(0); @@ -3269,8 +3346,7 @@ static void session_abort_channels(ssh_server_session_ctx *s) { ssh_server_channel_ctx *next = ch->next; server_channel_notify_close(ch, "Session closed"); server_channel_finalize(ch); - ch->next = NULL; - acton_free(ch); + session_retire_channel(s, ch); ch = next; } s->channels = NULL; @@ -3551,7 +3627,8 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o } $R sshQ_ServerSessionD__pin_affinityG_local(sshQ_ServerSession self, $Cont c$cont) { - ssh_server_session_ctx *s = (ssh_server_session_ctx *)(unsigned long)fromB_u64(self->session_id); + ssh_server_ctx *server = server_from_actor(self->server); + ssh_server_session_ctx *s = session_from_pending_token(server, self->session_id); if (s != NULL && s->owner_wt >= 0) { set_actor_affinity(s->owner_wt); } else { @@ -3561,7 +3638,8 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o } $R sshQ_ServerSessionD__attachG_local(sshQ_ServerSession self, $Cont c$cont, B_u64 session_id) { - ssh_server_session_ctx *s = (ssh_server_session_ctx *)(unsigned long)fromB_u64(session_id); + ssh_server_ctx *server = server_from_actor(self->server); + ssh_server_session_ctx *s = session_from_pending_token(server, session_id); if (s == NULL) return $R_CONT(c$cont, B_None); if (s->session == NULL || s->state == SESSION_STATE_CLOSED || @@ -3581,8 +3659,9 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o return $R_CONT(c$cont, B_None); } s->actor = self; - self->_session_id = session_id; + self->_session_id = toB_u64((unsigned long)s); s->attached = 1; + s->pending_id = 0; s->on_auth = ($action2)self->_on_auth; s->on_channel_open = ($action)self->_on_channel_open; s->on_exec = (self->_on_exec == B_None) ? NULL : ($action3)self->_on_exec; @@ -3596,12 +3675,19 @@ static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_o session_start_keyex_timer(s); } if (ssh_debug_enabled) { - ssh_debug_log("server session attach: callbacks set, driving session"); + ssh_debug_log("server session attach: callbacks set"); } - session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD__drive_attachedG_local(sshQ_ServerSession self, $Cont c$cont) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL || !s->attached) + return $R_CONT(c$cont, B_None); if (ssh_debug_enabled) { - ssh_debug_log("server session attach: session_drive returned"); + ssh_debug_log("server session drive attached"); } + session_drive(s); return $R_CONT(c$cont, B_None); } From a66c745189877dad8866adcd219a08af31c5536e Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Sat, 21 Mar 2026 09:37:06 +0100 Subject: [PATCH 37/38] Cover keepalive traffic Timer-driven keepalive packets were one of the few remaining SSH transport paths without direct regression coverage. The suite already stressed channel close, multiplexed writes, auth churn, and raw disconnects, but it did not hold an authenticated session open long enough for repeated keepalives to overlap with channel progress. Add a keepalive traffic tester that drives one channel over a session with aggressive client and server keepalive intervals while writes are still in flight. The test keeps the session alive long enough for multiple keepalive ticks, verifies that the full payload reaches the server, and requires the reply and close path to complete cleanly. That gives the transport suite coverage for timer-driven libssh traffic interleaving with normal channel I/O, which is exactly the control-path overlap that would otherwise stay invisible until long-running stress. --- src/test_ssh_server.act | 251 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act index 239548b..d59c60f 100644 --- a/src/test_ssh_server.act +++ b/src/test_ssh_server.act @@ -1460,6 +1460,256 @@ def _test_ssh_concurrent_channel_writes(t: testing.EnvT): ConcurrentChannelWriteTester(t) +actor KeepaliveTrafficTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + PING = b"keepalive-ping" + ACK = b"keepalive-ack" + ROUNDS = 10 + KEEPALIVE_INTERVAL = 0.05 + SEND_DELAY = 0.06 + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var session: ?ssh.ServerSession = None + var client_channel: ?ssh.Channel = None + var client_close_requested = False + var server_close_requested = False + var client_closed = False + var server_closed = False + var session_closed = False + var server_channel_closed = False + var ack_received = False + var server_recv_bytes = 0 + + def total_bytes() -> int: + return len(PING) * ROUNDS + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if ack_received and server_recv_bytes == total_bytes() and \ + server_channel_closed and client_closed and session_closed and \ + server_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + if session is not None: + session.close() + request_server_close() + t.error(Exception(msg)) + + def send_ping(remaining: int): + if done or remaining <= 0: + return + if client_channel is None: + finish_error("client channel missing during keepalive write burst") + return + if client_channel is not None: + client_channel.write(PING) + if remaining == 1: + client_channel.send_eof() + return + after SEND_DELAY: send_ping(remaining - 1) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for keepalive traffic test " + "(ack_received=" + str(ack_received) + + ", server_recv_bytes=" + str(server_recv_bytes) + + ", server_channel_closed=" + str(server_channel_closed) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < LISTEN_RETRY_LIMIT and _is_retryable_listen_error(err): + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + session = sess + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "echo": + finish_error("unexpected subsystem request: " + name) + return + ch.accept_request() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server_recv_bytes += len(data) + if server_recv_bytes > total_bytes(): + finish_error("server received too much data during keepalive test") + return + if server_recv_bytes == total_bytes(): + ch.write(ACK) + ch.send_eof() + ch.close() + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server received unexpected stderr during keepalive test") + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + log.info("server channel closed: " + reason, None) + server_channel_closed = True + if session is not None: + session.close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + maybe_finish() + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + client_channel = ch + ch.request_subsystem("echo") + after SEND_DELAY: send_ping(ROUNDS) + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + if data != ACK: + finish_error("client received unexpected keepalive reply") + return + if ack_received: + finish_error("client received duplicate keepalive reply") + return + ack_received = True + + def ch_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("client received unexpected stderr during keepalive test") + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + log.info("client channel closed: " + reason, None) + request_client_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + keepalive_interval=KEEPALIVE_INTERVAL, + ) + + def start_server(): + attempts += 1 + port = _pick_test_port(36000, 24000, attempts, 271) + server_close_requested = False + client_close_requested = False + client_closed = False + server_closed = False + session_closed = False + server_channel_closed = False + ack_received = False + server_recv_bytes = 0 + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + keepalive_interval=KEEPALIVE_INTERVAL, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_keepalive_traffic(t: testing.EnvT): + """Keepalives should not disrupt an active channel.""" + KeepaliveTrafficTester(t) + + actor CrossHandleOwnershipTester(t: testing.EnvT): log = logging.Logger(t.log_handler) @@ -5606,6 +5856,7 @@ __env_tests: dict[str, testing.EnvTest] = { "_test_ServerSessionCloseFlushTester": testing.EnvTest(_test_ssh_server_session_close_flush, "_test_ServerSessionCloseFlushTester", "", __modname__), "_test_ConcurrentChannelTester": testing.EnvTest(_test_ssh_concurrent_channels, "_test_ConcurrentChannelTester", "", __modname__), "_test_ConcurrentChannelWriteTester": testing.EnvTest(_test_ssh_concurrent_channel_writes, "_test_ConcurrentChannelWriteTester", "", __modname__), + "_test_KeepaliveTrafficTester": testing.EnvTest(_test_ssh_keepalive_traffic, "_test_KeepaliveTrafficTester", "", __modname__), "_test_CrossHandleOwnershipTester": testing.EnvTest(_test_ssh_cross_handle_ownership, "_test_CrossHandleOwnershipTester", "", __modname__), "_test_UnsupportedRequestTester": testing.EnvTest(_test_ssh_unsupported_requests, "_test_UnsupportedRequestTester", "", __modname__), "_test_RunCommandTimeoutTester": testing.EnvTest(_test_ssh_runcommand_timeout, "_test_RunCommandTimeoutTester", "", __modname__), From b5d2c2f58e320d8b461950fd1b61c55b4aca13e9 Mon Sep 17 00:00:00 2001 From: Kristian Larsson Date: Tue, 5 May 2026 21:57:54 +0200 Subject: [PATCH 38/38] Add Build.act --- Build.act | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 Build.act diff --git a/Build.act b/Build.act new file mode 100644 index 0000000..d12b2b6 --- /dev/null +++ b/Build.act @@ -0,0 +1,9 @@ +name = "ssh" +fingerprint = 0xee8dcc44255e0a2e +zig_dependencies = { + "libssh": ( + path="../acton-deps/libssh", + options={"WITH_SERVER": "true"}, + artifacts=["ssh"] + ) +}