diff --git a/Build.act b/Build.act new file mode 100644 index 0000000..d12b2b6 --- /dev/null +++ b/Build.act @@ -0,0 +1,9 @@ +name = "ssh" +fingerprint = 0xee8dcc44255e0a2e +zig_dependencies = { + "libssh": ( + path="../acton-deps/libssh", + options={"WITH_SERVER": "true"}, + artifacts=["ssh"] + ) +} diff --git a/build.act.json b/build.act.json index ec3c9b7..ac51969 100644 --- a/build.act.json +++ b/build.act.json @@ -2,8 +2,10 @@ "dependencies": {}, "zig_dependencies": { "libssh": { - "url": "https://github.com/actonlang/libssh/archive/refs/heads/zig-build.tar.gz", - "hash": "122076b4676ec7d777e5efc24ae4b7f1ab4fa3df407ab9649fdd55f9fc594ca053ec", + "path": "../acton-deps/libssh", + "options": { + "WITH_SERVER": "true" + }, "artifacts": [ "ssh" ] diff --git a/src/ssh.act b/src/ssh.act index a8b0268..eb5b29e 100644 --- a/src/ssh.act +++ b/src/ssh.act @@ -1,69 +1,590 @@ +""" +Acton SSH module + +This module exposes a callback-driven SSH client/server API built on libssh +and integrated with Acton's libuv loop. The main actors are: + +- Client: connects to a server, reports on_connect(err) and on_close(reason), + and optionally on_hostkey(state, HostKeyInfo) for strict host key checks. +- Channel: client-side data stream (stdout/stderr/exit/close callbacks). +- RunCommand: convenience wrapper that buffers stdout/stderr and reports exit. +- Server: listens for incoming connections and creates ServerSession actors. +- ServerSession: per-connection state; emits on_auth(AuthRequest), + on_channel_open, and optional on_exec/on_subsystem. +- ServerChannel: server-side data stream for a single channel. + +Usage sketch (client): + def on_connect(c, err): + if err is not None: ... + Channel(c, on_open, on_stdout, on_stderr, on_exit, on_close) + + def on_hostkey(c, state, info): + if state == HOSTKEY_OK: c.accept_hostkey() + else: c.reject_hostkey("...") + +Usage sketch (server): + def on_auth(sess, req): sess.accept_auth() + def on_channel_open(sess): sess.accept_channel(ServerChannel(...)) + def on_subsystem(sess, ch, name): ch.accept_request(); ch.write(...) + +Notes: + - All operations are async actor actions; callbacks receive errors as ?str. + - ?bytes callbacks use None to signal EOF. + - Timeouts and keepalive are configurable; host keys are generated in-memory + unless host_key_path is set. + - Server admission limits are configurable; values <= 0 disable a limit. + - For debugging, set ACTON_SSH_DEBUG and ACTON_SSH_LIBSSH_LOG. +""" + import net -import testing + def version() -> str: """Get the libssh version""" NotImplemented -def _test_version(): - testing.assertEqual("0.11.0", version()) + +def _debug(msg: str) -> None: + """Internal debug logging (gated by ACTON_SSH_DEBUG).""" + NotImplemented + + +HOSTKEY_OK = "ok" +HOSTKEY_UNKNOWN = "unknown" +HOSTKEY_NOT_FOUND = "not_found" +HOSTKEY_CHANGED = "changed" +HOSTKEY_OTHER = "other" +HOSTKEY_ERROR = "error" + +class HostKeyInfo(): + """Server host key details""" + key_type: str + fingerprint: str + + def __init__(self, key_type: str, fingerprint: str): + self.key_type = key_type + self.fingerprint = fingerprint + + +class AuthRequest(): + """Authentication request details""" + method: str + user: str + password: ?str + pubkey: ?bytes + + def __init__(self, method: str, user: str, password: ?str=None, pubkey: ?bytes=None): + self.method = method + self.user = user + self.password = password + self.pubkey = pubkey + actor Client(cap: net.TCPConnectCap, host: str, username: str, - on_connect: action(Client) -> None, + on_connect: action(Client, ?str) -> None, on_close: action(Client, str) -> None, - key: ?str=None, + on_hostkey: ?action(Client, str, HostKeyInfo) -> None=None, password: ?str=None, port: u16=22, + known_hosts: ?str=None, + connect_timeout: float=10.0, + auth_timeout: float=10.0, + keepalive_interval: float=30.0, + keepalive_enabled: bool=True, ): """SSH Client""" # haha, this is really a pointer :P - var _ssh_session: u64 = 0 + var _client: u64 = 0 + + proc def _pin_affinity() -> None: + NotImplemented + _pin_affinity() proc def _init() -> None: """Initialize the SSH client""" NotImplemented _init() - print("SSH Client connected") -# action def close(on_close: action(TLSConnection) -> None) -> None: -# """Close the connection""" -# NotImplemented -# -# def reconnect(): -# close(_connect) -# -# def _connect(c): -# NotImplemented -# + action def accept_hostkey() -> None: + """Accept the currently pending host key (not persisted)""" + NotImplemented + + action def reject_hostkey(reason: str) -> None: + """Reject the currently pending host key""" + NotImplemented + + action def close() -> None: + """Close the SSH client""" + NotImplemented + + proc def _cleanup_native() -> None: + NotImplemented + + action def __cleanup__() -> None: + if _client != 0: + _cleanup_native() + + # Internal channel operations (used by Channel actor) + action def channel_create(channel: Channel, + on_open: action(Channel, ?str) -> None, + on_stdout: action(Channel, ?bytes) -> None, + on_stderr: action(Channel, ?bytes) -> None, + on_exit: action(Channel, int, ?str) -> None, + on_close: action(Channel, str) -> None) -> None: + NotImplemented + + action def channel_request_exec(channel: Channel, cmd: str) -> None: + NotImplemented + + action def channel_request_shell(channel: Channel, + term: str, + cols: int, + rows: int, + width_px: int, + height_px: int, + with_pty: bool) -> None: + NotImplemented + + action def channel_request_subsystem(channel: Channel, name: str) -> None: + NotImplemented + + action def channel_write(channel: Channel, data: bytes) -> None: + NotImplemented + + action def channel_send_eof(channel: Channel) -> None: + NotImplemented + + action def channel_close(channel: Channel) -> None: + NotImplemented + + +actor Channel(client: Client, + on_open: action(Channel, ?str) -> None, + on_stdout: action(Channel, ?bytes) -> None, + on_stderr: action(Channel, ?bytes) -> None, + on_exit: action(Channel, int, ?str) -> None, + on_close: action(Channel, str) -> None): + """SSH Channel""" + + # pointer to internal channel state + var _channel_id: u64 = 0 + var _on_open: action(Channel, ?str) -> None = on_open + var _on_stdout: action(Channel, ?bytes) -> None = on_stdout + var _on_stderr: action(Channel, ?bytes) -> None = on_stderr + var _on_exit: action(Channel, int, ?str) -> None = on_exit + var _on_close: action(Channel, str) -> None = on_close + + proc def _init() -> None: + client.channel_create(self, _on_open, _on_stdout, _on_stderr, _on_exit, _on_close) + _init() + + action def request_exec(cmd: str) -> None: + """Request to execute a command""" + client.channel_request_exec(self, cmd) + + action def request_shell(term: str="xterm-256color", + cols: int=80, + rows: int=24, + width_px: int=0, + height_px: int=0, + with_pty: bool=True) -> None: + """Request an interactive shell""" + client.channel_request_shell(self, term, cols, rows, width_px, height_px, with_pty) + + action def request_subsystem(name: str) -> None: + """Request a subsystem""" + client.channel_request_subsystem(self, name) + + action def write(data: bytes) -> None: + """Write data to the channel""" + client.channel_write(self, data) + + action def send_eof() -> None: + """Half-close writes while keeping the read side open""" + client.channel_send_eof(self) + + action def close() -> None: + """Close the channel and discard unread inbound data""" + client.channel_close(self) + + proc def _cleanup_native() -> None: + NotImplemented + + action def __cleanup__() -> None: + if _channel_id != 0: + _cleanup_native() + + +actor RunCommand(client: Client, + cmd: str, + on_exit: action(Channel, int, ?str, bytes, bytes, ?str) -> None, + timeout: ?float=None): + """Run a command and collect output""" + + var out_buf = b"" + var err_buf = b"" + var _out_done = False + var _err_done = False + var _exited = False + var _exit_code = 0 + var _exit_signal: ?str = None + var _error: ?str = None + var _done = False + + def _finish(ch: Channel): + if _done: + return + _done = True + on_exit(ch, _exit_code, _exit_signal, out_buf, err_buf, _error) + + def _on_open(ch: Channel, err: ?str): + if _done: + return + if err is not None: + _error = err + _finish(ch) + return + ch.request_exec(cmd) + + def _on_stdout(ch: Channel, data: ?bytes): + if _done: + return + if data is not None: + out_buf += data + else: + _out_done = True + _check_done(ch) + + def _on_stderr(ch: Channel, data: ?bytes): + if _done: + return + if data is not None: + err_buf += data + else: + _err_done = True + _check_done(ch) + + def _on_exit(ch: Channel, code: int, sig: ?str): + if _done: + return + _exited = True + _exit_code = code + _exit_signal = sig + _check_done(ch) + + def _on_close(ch: Channel, reason: str): + if _done: + return + if _error is None and reason != "closed": + _error = reason + if _error is not None: + _finish(ch) + return + + def _check_done(ch: Channel): + if _out_done and _err_done and _exited and _error is None: + _finish(ch) + + var _channel = Channel(client, _on_open, _on_stdout, _on_stderr, _on_exit, _on_close) + + if timeout is not None: + def _on_timeout(): + if _done: + return + _error = "timeout" + _finish(_channel) + _channel.close() + after timeout: _on_timeout() + + +actor Server(cap: net.TCPListenCap, + host: str, + port: u16, + on_listen: action(Server, ?str) -> None, + on_close: action(Server, str) -> None, + on_session: action(ServerSession) -> None, + on_auth: action(ServerSession, AuthRequest) -> None, + on_channel_open: action(ServerSession) -> None, + on_exec: ?action(ServerSession, ServerChannel, str) -> None=None, + on_subsystem: ?action(ServerSession, ServerChannel, str) -> None=None, + on_session_close: ?action(ServerSession, str) -> None=None, + host_key_path: ?str=None, + host_key_type: str="ed25519", + host_key_bits: int=2048, + auth_timeout: float=10.0, + keepalive_interval: float=30.0, + keepalive_enabled: bool=True, + max_sessions: int=128, + max_channels_per_session: int=32, + ): + """SSH Server""" + + var _server: u64 = 0 + var _host: str = host + var _port: u16 = port + var _host_key_path: ?str = host_key_path + var _host_key_type: str = host_key_type + var _host_key_bits: int = host_key_bits + var _auth_timeout: float = auth_timeout + var _keepalive_interval: float = keepalive_interval + var _keepalive_enabled: bool = keepalive_enabled + var _max_sessions: int = max_sessions + var _max_channels_per_session: int = max_channels_per_session + var _on_listen: action(Server, ?str) -> None = on_listen + var _on_close: action(Server, str) -> None = on_close + + proc def _pin_affinity() -> None: + NotImplemented + _pin_affinity() + + proc def _init() -> None: + """Initialize the SSH server""" + NotImplemented + _init() + + action def close() -> None: + """Close the SSH server""" + NotImplemented -# TODO: implement support for channels -# AFAIK, all things over ssh are done via channels, so need some channel -# primitive, maybe an actor per channel that then multiplexes into the Client -# session? Prolly need some higher level wrappers for common things like -# starting a shell or running a single command. SFTP / SCP would be nice too, -# but for sometime in the future. Custom subsystems need to be supported too. + proc def _cleanup_native() -> None: + NotImplemented + action def __cleanup__() -> None: + if _server != 0: + _cleanup_native() + action def on_session_pending(session_id: u64) -> None: + _debug("server on_session_pending: " + str(session_id)) + ServerSession(self, + session_id, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close) + _debug("server on_session_pending: created session actor for " + str(session_id)) + + action def on_session_ready(session: ServerSession) -> None: + on_session(session) + + +actor ServerSession(server: Server, + session_id: u64, + on_auth: action(ServerSession, AuthRequest) -> None, + on_channel_open: action(ServerSession) -> None, + on_exec: ?action(ServerSession, ServerChannel, str) -> None=None, + on_subsystem: ?action(ServerSession, ServerChannel, str) -> None=None, + on_close: ?action(ServerSession, str) -> None=None): + """SSH Server Session""" + + var _session_id: u64 = 0 + var _on_auth: action(ServerSession, AuthRequest) -> None = on_auth + var _on_channel_open: action(ServerSession) -> None = on_channel_open + var _on_exec: ?action(ServerSession, ServerChannel, str) -> None = on_exec + var _on_subsystem: ?action(ServerSession, ServerChannel, str) -> None = on_subsystem + var _on_close: ?action(ServerSession, str) -> None = on_close + + proc def _pin_affinity() -> None: + NotImplemented + _pin_affinity() + + proc def _attach(session_id: u64) -> None: + NotImplemented + + proc def _drive_attached() -> None: + NotImplemented + + action def _attach_ready() -> None: + _debug("server session attach ready: " + str(session_id)) + _attach(session_id) + _debug("server session attach: _attach returned for " + str(session_id)) + if _session_id != 0: + server.on_session_ready(self) + _debug("server session attach: on_session_ready sent for " + str(session_id)) + _drive_attached() + + after 0: _attach_ready() + + action def accept_auth() -> None: + """Accept the pending authentication request""" + NotImplemented + + action def reject_auth(reason: str) -> None: + """Reject the pending authentication request""" + NotImplemented + + action def accept_channel(channel: ServerChannel) -> None: + """Accept the pending channel open request""" + channel.accept_open() + + action def accept_channel_open(channel: ServerChannel, + on_data: action(ServerChannel, ?bytes) -> None, + on_stderr: action(ServerChannel, ?bytes) -> None, + on_close: action(ServerChannel, str) -> None) -> None: + NotImplemented + + action def reject_channel(reason: str) -> None: + """Reject the pending channel open request""" + NotImplemented + + action def close() -> None: + """Close the SSH session""" + NotImplemented + + proc def _cleanup_native() -> None: + NotImplemented + + action def __cleanup__() -> None: + if _session_id != 0: + _cleanup_native() + + # Internal channel operations (used by ServerChannel actor) + action def channel_accept_request(channel: ServerChannel) -> None: + NotImplemented + + action def channel_reject_request(channel: ServerChannel, reason: str) -> None: + NotImplemented + + action def channel_write(channel: ServerChannel, data: bytes) -> None: + NotImplemented + + action def channel_write_stderr(channel: ServerChannel, data: bytes) -> None: + NotImplemented + + action def channel_send_eof(channel: ServerChannel) -> None: + NotImplemented + + action def channel_send_exit_status(channel: ServerChannel, status: int) -> None: + NotImplemented + + action def channel_close(channel: ServerChannel) -> None: + NotImplemented + + +actor ServerChannel(session: ServerSession, + on_data: action(ServerChannel, ?bytes) -> None, + on_stderr: action(ServerChannel, ?bytes) -> None, + on_close: action(ServerChannel, str) -> None): + """SSH Server Channel""" + + var _channel_id: u64 = 0 + var _on_data: action(ServerChannel, ?bytes) -> None = on_data + var _on_stderr: action(ServerChannel, ?bytes) -> None = on_stderr + var _on_close: action(ServerChannel, str) -> None = on_close + + action def accept_request() -> None: + """Accept the pending channel request""" + session.channel_accept_request(self) + + action def accept_open() -> None: + """Accept the pending channel open request""" + session.accept_channel_open(self, _on_data, _on_stderr, _on_close) + + action def reject_request(reason: str) -> None: + """Reject the pending channel request""" + session.channel_reject_request(self, reason) + + action def write(data: bytes) -> None: + """Write data to the channel (stdout)""" + session.channel_write(self, data) + + action def write_stderr(data: bytes) -> None: + """Write data to the channel (stderr)""" + session.channel_write_stderr(self, data) + + action def send_eof() -> None: + """Half-close writes while keeping the read side open""" + session.channel_send_eof(self) + + action def send_exit_status(status: int) -> None: + """Send an exit status for the channel""" + session.channel_send_exit_status(self, status) + + action def close() -> None: + """Close the channel and discard unread inbound data""" + session.channel_close(self) + + proc def _cleanup_native() -> None: + NotImplemented + + action def __cleanup__() -> None: + if _channel_id != 0: + _cleanup_native() actor main(env): - def on_connect(client: Client): - print("Connected") + NETCONF_HELLO = b'\n' + \ + b'\n' + \ + b' \n' + \ + b' urn:ietf:params:netconf:base:1.0\n' + \ + b' \n' + \ + b']]>]]>' + var out_buf = b"" + var c: ?Client = None - def on_close(client: Client, error: str): - print("Error", error) + def on_connect(client: Client, err: ?str): + if err is not None: + print("SSH error", err) + env.exit(1) + return + print("SSH connected") - print(version()) - c = Client( + def on_close(client: Client, reason: str): + print("SSH closed", reason) + env.exit(0) + + def on_hostkey(client: Client, state: str, info: HostKeyInfo): + print("Host key", state, info.key_type, info.fingerprint) + # WARNING: accepting any host key is insecure; do this only for testing + client.accept_hostkey() + + def ch_open(ch: Channel, err: ?str): + if err is not None: + print("Channel open error", err) + if c is not None: + c.close() + return + ch.request_subsystem("netconf") + ch.write(NETCONF_HELLO) + + def ch_out(ch: Channel, data: ?bytes): + if data is not None: + out_buf += data + if out_buf.find(b"]]>]]>") >= 0: + print(out_buf) + ch.close() + + def ch_err(ch: Channel, data: ?bytes): + if data is not None: + print(data) + + def ch_exit(ch: Channel, code: int, sig: ?str): + print("exit", code, sig) + + def ch_close(ch: Channel, reason: str): + print("channel closed", reason) + if c is not None: + c.close() + + client = Client( net.TCPConnectCap(net.TCPCap(net.NetCap(env.cap))), - "localhost", - "foo", + "127.0.0.1", + "admin", on_connect, on_close, - password="bar", - port=2223, + on_hostkey, + password="admin", + port=42830, + known_hosts="/tmp/acton_ssh_known_hosts", ) - env.exit(0) + + c = client + Channel(client, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def _on_close_timeout(): + if c is not None: + c.close() + after 10.0: _on_close_timeout() diff --git a/src/ssh.ext.c b/src/ssh.ext.c index 2112598..0b3ea10 100644 --- a/src/ssh.ext.c +++ b/src/ssh.ext.c @@ -1,120 +1,3941 @@ +/* + * Acton <-> libssh integration overview + * + * This file is the external-C glue that drives libssh from Acton's libuv loop + * and exposes it to Acton actors. The core goals are: + * - Nonblocking SSH I/O integrated with libuv (no blocking syscalls). + * - Actor-safe, async callback-driven API in Acton. + * - GC-safe memory: libssh allocations use Acton's allocator. + * + * Event loop integration + * - Each libssh session is created nonblocking. + * - We attach a uv_poll watcher to the libssh socket fd. + * - On poll events we call ssh_session_handle_poll() (via + * session_apply_poll_events), then drive a small state machine + * (connect/auth/ready for client, keyex/auth/ready for server). + * - ssh_get_poll_flags()/ssh_get_status() decide which poll events to arm. + * + * Buffered data + SSH_AGAIN (why we keep driving without fd readability) + * - libssh maintains its own internal buffers. After a poll callback, libssh + * may have already read bytes into those buffers even though the socket is + * no longer readable at the OS level. + * - When a nonblocking API returns SSH_AGAIN and ssh_get_status() includes + * SSH_READ_PENDING, it means "call again, there is buffered data to + * process" even if the fd will not trigger another readable event. + * - If we only wait for uv_poll readability, we can deadlock: + * 1) uv_poll READABLE fires; ssh_session_handle_poll() drains the fd. + * 2) ssh_connect()/ssh_handle_key_exchange()/ssh_userauth_password() + * returns SSH_AGAIN. + * 3) No more kernel readability events happen, but libssh still has + * buffered protocol bytes (SSH_READ_PENDING). + * 4) We wait for an event that never comes and eventually time out. + * - The fix is to keep driving the state machine in a bounded loop while + * SSH_READ_PENDING is set, even without fd readability. + * We cap iterations with SSH_IO_PUMP_LIMIT to avoid CPU spin. + * + * Channel I/O + * - SSH channels carry two streams: "data" and "extended data". We expose + * these as stdout/stderr callbacks (client on_stdout/on_stderr, server + * on_data/on_stderr). This is protocol-level stdout/stderr, not host OS + * process stdio. + * - Channels install libssh callbacks for data/extended-data/EOF/close. + * - Inbound data always flows through these callbacks. As we drive libssh + * (via ssh_session_handle_poll), libssh invokes the registered C callback + * functions, and those callbacks call the corresponding Acton action + * methods (foo->$class->on_stdout/on_stderr/on_close, etc.). We do not run + * manual read loops; libssh owns buffering and read state. + * - Channel writes are queued and flushed when libssh reports write pending. + * + * Actor/GC/threading model + * - Client and ServerSession actors own libssh state and are pinned to a + * worker thread. Channel actors invoke action methods on their owning + * Client/ServerSession actor for all operations; there is no hidden + * cross-actor C magic. + * - We replace libssh allocators with Acton's GC allocator so libuv/GC + * roots remain visible (libssh structures can reference GC memory). + * + * Config & filesystem + * - libssh config processing is disabled by default; known_hosts is only + * read if explicitly configured by the Acton API. + * - Server host keys are generated in-memory unless a path is provided. + */ +#include +#include #include -#include -// TODO: figure out how to include rts/log so we get access to log_error etc +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include -void noop_free(void *ptr) { +#include "rts/log.h" + +uv_loop_t *get_uv_loop(void); +extern struct $Cont $Done$instance; + +#define SSH_READ_BUFSIZE 4096 +#define SSH_IO_PUMP_LIMIT 128 +#define SSH_ATTACH_TIMEOUT_SEC 5.0 +#define SSH_KEYEX_TIMEOUT_SEC 2.0 +#define SSH_SERVER_ACCEPT_LIMIT 64 +static int ssh_debug_enabled = 0; +static int ssh_libssh_log_level = SSH_LOG_NOLOG; + +typedef enum { + CLIENT_STATE_INIT = 0, + CLIENT_STATE_CONNECTING, + CLIENT_STATE_HOSTKEY, + CLIENT_STATE_HOSTKEY_WAIT, + CLIENT_STATE_AUTH, + CLIENT_STATE_READY, + CLIENT_STATE_ERROR, + CLIENT_STATE_CLOSING, + CLIENT_STATE_CLOSED, +} client_state_t; + +typedef enum { + CHAN_STATE_INIT = 0, + CHAN_STATE_OPENING, + CHAN_STATE_OPEN, + CHAN_STATE_RUNNING, + CHAN_STATE_CLOSING, + CHAN_STATE_CLOSED, + CHAN_STATE_ERROR, +} channel_state_t; + +typedef enum { + CHAN_REQ_NONE = 0, + CHAN_REQ_SHELL, + CHAN_REQ_EXEC, + CHAN_REQ_SUBSYSTEM, +} channel_request_t; + +typedef struct write_chunk { + B_bytes data; + size_t offset; + struct write_chunk *next; +} write_chunk_t; + +typedef struct ssh_channel_ctx { + struct ssh_channel_ctx *next; + ssh_channel channel; + struct ssh_client_ctx *client; + struct ssh_channel_callbacks_struct *callbacks; + sshQ_Channel actor; + channel_state_t state; + channel_request_t pending_req; + int pty_pending; + int pty_done; + B_str exec_cmd; + B_str subsystem; + B_str term; + int cols; + int rows; + int width_px; + int height_px; + int send_eof; + int eof_sent; + int close_requested; + int close_sent; + int remote_close_seen; + int write_wontblock; + int stdout_eof; + int stderr_eof; + int exit_sent; + int open_notified; + int open_succeeded; + int close_notified; + $action2 on_open; + $action2 on_close; + $action2 on_stdout; + $action2 on_stderr; + $action3 on_exit; + write_chunk_t *write_head; + write_chunk_t *write_tail; +} ssh_channel_ctx; + +typedef struct ssh_client_ctx { + GC_hidden_pointer actor; + ssh_session session; + uv_poll_t *poll; + int poll_events; + uv_timer_t *connect_timer; + uv_timer_t *auth_timer; + uv_timer_t *keepalive_timer; + double connect_timeout; + double auth_timeout; + double keepalive_interval; + int keepalive_enabled; + client_state_t state; + int fd; + int connect_notified; + int connected_ok; + int close_notified; + int close_finalized; + int close_force; + int write_ready; + char *close_reason; + enum ssh_known_hosts_e hostkey_state; + ssh_channel_ctx *channels; + ssh_channel_ctx *retired_channels; + $action2 on_connect; + $action2 on_close; + $action3 on_hostkey; +} ssh_client_ctx; + +typedef enum { + SERVER_STATE_INIT = 0, + SERVER_STATE_LISTENING, + SERVER_STATE_ERROR, + SERVER_STATE_CLOSING, + SERVER_STATE_CLOSED, +} server_state_t; + +typedef enum { + SESSION_STATE_PENDING = 0, + SESSION_STATE_KEYEX, + SESSION_STATE_AUTH, + SESSION_STATE_READY, + SESSION_STATE_ERROR, + SESSION_STATE_CLOSING, + SESSION_STATE_CLOSED, +} session_state_t; + +typedef enum { + SCHAN_STATE_OPEN = 0, + SCHAN_STATE_CLOSING, + SCHAN_STATE_CLOSED, + SCHAN_STATE_ERROR, +} schan_state_t; + +typedef enum { + SCHAN_REQ_NONE = 0, + SCHAN_REQ_EXEC, + SCHAN_REQ_SUBSYSTEM, +} schan_req_t; + +typedef struct server_write_chunk { + B_bytes data; + size_t offset; + int is_stderr; + struct server_write_chunk *next; +} server_write_chunk_t; + +typedef struct ssh_server_channel_ctx { + struct ssh_server_channel_ctx *next; + ssh_channel channel; + struct ssh_channel_callbacks_struct *callbacks; + struct ssh_server_session_ctx *session; + sshQ_ServerChannel actor; + schan_state_t state; + int send_eof; + int close_requested; + int close_sent; + int remote_close_seen; + int write_wontblock; + int eof_sent; + int stdout_eof; + int stderr_eof; + int close_notified; + ssh_message pending_req; + schan_req_t pending_req_type; + $action2 on_data; + $action2 on_stderr; + $action2 on_close; + server_write_chunk_t *write_head; + server_write_chunk_t *write_tail; +} ssh_server_channel_ctx; + +typedef struct ssh_server_session_ctx { + struct ssh_server_session_ctx *next; + sshQ_ServerSession actor; + struct ssh_server_ctx *server; + ssh_session session; + uv_poll_t *poll; + int poll_events; + uv_timer_t *attach_timer; + uv_timer_t *auth_timer; + uv_timer_t *keepalive_timer; + double auth_timeout; + double keepalive_interval; + int keepalive_enabled; + session_state_t state; + int attached; + int fd; + int owner_wt; + int write_ready; + int close_notified; + int close_finalized; + int close_force; + uint64_t pending_id; + char *close_reason; + ssh_message pending_auth; + ssh_message pending_channel_open; + ssh_server_channel_ctx *channels; + ssh_server_channel_ctx *retired_channels; + $action2 on_auth; + $action on_channel_open; + $action3 on_exec; + $action3 on_subsystem; + $action2 on_close; +} ssh_server_session_ctx; + +typedef struct ssh_server_ctx { + GC_hidden_pointer actor; + ssh_bind bind; + ssh_key hostkey; + uv_poll_t *poll; + int fd; + server_state_t state; + int max_sessions; + int max_channels_per_session; + int listen_notified; + int listen_ok; + int close_notified; + int close_finalized; + char *close_reason; + ssh_server_session_ctx *sessions; + $action2 on_listen; + $action2 on_close; +} ssh_server_ctx; + +static void client_drive(ssh_client_ctx *c); +static void client_update_poll(ssh_client_ctx *c); +static void client_close_internal(ssh_client_ctx *c, const char *reason, int force_close); +static void client_finalize(ssh_client_ctx *c); +static void client_finish_close(ssh_client_ctx *c); +static void client_maybe_release(ssh_client_ctx *c); +static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch); +static int client_needs_write(ssh_client_ctx *c); +static void client_pump_io(ssh_client_ctx *c); + +static void server_accept(ssh_server_ctx *s); +static void server_close_internal(ssh_server_ctx *s, const char *reason); +static void server_finalize(ssh_server_ctx *s); +static void server_remove_session(ssh_server_ctx *s, ssh_server_session_ctx *sess); +static void server_maybe_release(ssh_server_ctx *s); +static void session_drive(ssh_server_session_ctx *s); +static void session_update_poll(ssh_server_session_ctx *s); +static void session_pump_io(ssh_server_session_ctx *s); +static void session_close_internal(ssh_server_session_ctx *s, const char *reason, int force_close); +static void session_finalize(ssh_server_session_ctx *s); +static void session_finish_close(ssh_server_session_ctx *s); +static void session_maybe_release(ssh_server_session_ctx *s); +static void session_fail(ssh_server_session_ctx *s, const char *msg); +static void session_start_attach_timer(ssh_server_session_ctx *s); +static void session_start_keyex_timer(ssh_server_session_ctx *s); +static int session_start_poll(ssh_server_session_ctx *s, char *errmsg, size_t errmsg_len); +static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch); +static int session_needs_write(ssh_server_session_ctx *s); +static void session_poll_cb(uv_poll_t *handle, int status, int events); +static void server_poll_cb(uv_poll_t *handle, int status, int events); +static void session_auth_timeout_cb(uv_timer_t *timer); +static void session_keepalive_cb(uv_timer_t *timer); +static void client_poll_close_cb(uv_handle_t *handle); +static void client_timer_close_cb(uv_handle_t *handle); +static void server_poll_close_cb(uv_handle_t *handle); +static void session_poll_close_cb(uv_handle_t *handle); +static void session_timer_close_cb(uv_handle_t *handle); + +#define STORE_HIDDEN_PTR(slot, ptr) \ + ((slot) = (ptr) ? (GC_hidden_pointer)GC_HIDE_POINTER(ptr) : (GC_hidden_pointer)0) +#define LOAD_HIDDEN_PTR(type, slot) \ + ((slot) ? (type)GC_REVEAL_POINTER(slot) : NULL) + +static sshQ_Client client_actor_ref(const ssh_client_ctx *c) { + return c ? LOAD_HIDDEN_PTR(sshQ_Client, c->actor) : NULL; +} + +static sshQ_Channel channel_actor_ref(const ssh_channel_ctx *ch) { + return ch ? ch->actor : NULL; +} + +static sshQ_Server server_actor_ref(const ssh_server_ctx *s) { + return s ? LOAD_HIDDEN_PTR(sshQ_Server, s->actor) : NULL; +} + +static sshQ_ServerSession session_actor_ref(const ssh_server_session_ctx *s) { + return s ? s->actor : NULL; +} + +static sshQ_ServerChannel server_channel_actor_ref(const ssh_server_channel_ctx *ch) { + return ch ? ch->actor : NULL; +} + +static void ssh_debug_log(const char *fmt, ...) { + if (!ssh_debug_enabled) + return; + va_list ap; + va_start(ap, fmt); + vfprintf(stderr, fmt, ap); + va_end(ap); + fprintf(stderr, "\n"); + fflush(stderr); +} + +static int parse_libssh_log_level(const char *value) { + if (value == NULL || value[0] == '\0') + return SSH_LOG_NOLOG; + char *endptr = NULL; + long lvl = strtol(value, &endptr, 10); + if (endptr != value && endptr && *endptr == '\0') { + if (lvl < 0) + return SSH_LOG_NOLOG; + if (lvl > SSH_LOG_TRACE) + lvl = SSH_LOG_TRACE; + return (int)lvl; + } + if (strcasecmp(value, "warn") == 0 || strcasecmp(value, "warning") == 0) + return SSH_LOG_WARN; + if (strcasecmp(value, "info") == 0 || strcasecmp(value, "protocol") == 0) + return SSH_LOG_INFO; + if (strcasecmp(value, "debug") == 0 || strcasecmp(value, "packet") == 0) + return SSH_LOG_DEBUG; + if (strcasecmp(value, "trace") == 0 || strcasecmp(value, "functions") == 0) + return SSH_LOG_TRACE; + if (strcasecmp(value, "none") == 0 || strcasecmp(value, "off") == 0) + return SSH_LOG_NOLOG; + return SSH_LOG_NOLOG; +} + +static void ssh_log_cb(int priority, const char *function, const char *buffer, void *userdata) { + (void)userdata; + if (buffer == NULL) + return; + fprintf(stderr, "libssh[%d] %s: %s\n", priority, function ? function : "", buffer); + fflush(stderr); +} + +static void ssh_configure_libssh_logging(void) { + if (ssh_libssh_log_level <= SSH_LOG_NOLOG) + return; + ssh_set_log_callback(ssh_log_cb); + ssh_set_log_level(ssh_libssh_log_level); +} + +static int session_apply_poll_events(ssh_session session, int events) { + if (session == NULL) + return -1; + int revents = 0; + if (events & UV_READABLE) + revents |= POLLIN; + if (events & UV_WRITABLE) + revents |= POLLOUT; +#ifdef UV_DISCONNECT + if (events & UV_DISCONNECT) + revents |= POLLHUP; +#endif +#ifdef UV_PRIORITIZED + if (events & UV_PRIORITIZED) + revents |= POLLPRI; +#endif + if (revents == 0) + return 0; + if (ssh_session_handle_poll(session, revents) != SSH_OK) + return -1; + return 0; +} + +static const char *hostkey_state_str(enum ssh_known_hosts_e state) { + switch (state) { + case SSH_KNOWN_HOSTS_OK: + return "ok"; + case SSH_KNOWN_HOSTS_UNKNOWN: + return "unknown"; + case SSH_KNOWN_HOSTS_NOT_FOUND: + return "not_found"; + case SSH_KNOWN_HOSTS_CHANGED: + return "changed"; + case SSH_KNOWN_HOSTS_OTHER: + return "other"; + case SSH_KNOWN_HOSTS_ERROR: + default: + return "error"; + } +} + +static ssh_client_ctx *client_from_actor(sshQ_Client self) { + if (self == NULL) + return NULL; + if (self->_client == NULL) + return NULL; + unsigned long ptr = fromB_u64(self->_client); + if (ptr == 0) + return NULL; + return (ssh_client_ctx *)ptr; +} + +static ssh_channel_ctx *channel_from_actor(sshQ_Channel channel) { + if (channel == NULL) + return NULL; + if (channel->_channel_id == NULL) + return NULL; + unsigned long ptr = fromB_u64(channel->_channel_id); + if (ptr == 0) + return NULL; + return (ssh_channel_ctx *)ptr; +} + +static ssh_server_ctx *server_from_actor(sshQ_Server self) { + if (self == NULL) + return NULL; + if (self->_server == NULL) + return NULL; + unsigned long ptr = fromB_u64(self->_server); + if (ptr == 0) + return NULL; + return (ssh_server_ctx *)ptr; +} + +static ssh_server_session_ctx *session_from_pending_token(ssh_server_ctx *server, B_u64 session_id) { + if (server == NULL) + return NULL; + uint64_t token = fromB_u64(session_id); + if (token == 0) + return NULL; + ssh_server_session_ctx *cur = server->sessions; + while (cur != NULL) { + if (cur->pending_id == token) + return cur; + cur = cur->next; + } + return NULL; +} + +static ssh_server_session_ctx *session_from_actor(sshQ_ServerSession self) { + if (self == NULL) + return NULL; + if (self->_session_id == NULL) + return NULL; + unsigned long ptr = fromB_u64(self->_session_id); + if (ptr == 0) + return NULL; + return (ssh_server_session_ctx *)ptr; +} + +static ssh_server_channel_ctx *server_channel_from_actor(sshQ_ServerChannel channel) { + if (channel == NULL) + return NULL; + if (channel->_channel_id == NULL) + return NULL; + unsigned long ptr = fromB_u64(channel->_channel_id); + if (ptr == 0) + return NULL; + return (ssh_server_channel_ctx *)ptr; +} + +static ssh_server_channel_ctx *server_channel_from_ssh(ssh_server_session_ctx *s, ssh_channel chan) { + if (s == NULL || chan == NULL) + return NULL; + ssh_server_channel_ctx *cur = s->channels; + while (cur != NULL) { + if (cur->channel == chan) + return cur; + cur = cur->next; + } + return NULL; +} + +static void close_poll(uv_poll_t **poll, uv_close_cb close_cb) { + if (*poll != NULL) { + if (uv_is_closing((uv_handle_t *)*poll)) + return; + uv_poll_stop(*poll); + uv_close((uv_handle_t *)*poll, close_cb); + } +} + +static void stop_timer(uv_timer_t **timer, uv_close_cb close_cb) { + if (*timer != NULL) { + if (uv_is_closing((uv_handle_t *)*timer)) + return; + uv_timer_stop(*timer); + uv_close((uv_handle_t *)*timer, close_cb); + } +} + +static int fd_has_data(int fd) { + if (fd < 0) + return 0; + char byte; + ssize_t rc; + do { + rc = recv(fd, &byte, 1, MSG_PEEK | MSG_DONTWAIT); + } while (rc < 0 && errno == EINTR); + if (rc > 0) + return 1; + if (rc == 0) + return 1; + if (errno == EAGAIN || errno == EWOULDBLOCK) + return 0; + if (errno == ECONNRESET || errno == ECONNABORTED || errno == ENOTCONN) + return 1; + if (errno == EBADF || errno == ENOTSOCK || errno == EINVAL) + return 0; + return 1; +} + +static int fd_can_write(int fd) { + if (fd < 0) + return 0; + struct pollfd pfd; + pfd.fd = fd; + pfd.events = POLLOUT; + pfd.revents = 0; + int rc; + do { + rc = poll(&pfd, 1, 0); + } while (rc < 0 && errno == EINTR); + if (rc <= 0) + return 0; + if (pfd.revents & POLLOUT) + return 1; + if (pfd.revents & POLLNVAL) + return 0; + return 0; +} + +static int fd_set_nonblocking(int fd) { + if (fd < 0) + return -1; + int flags = fcntl(fd, F_GETFL, 0); + if (flags < 0) + return -1; + if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0) + return -1; + return 0; +} + +static void format_session_error(ssh_session session, const char *prefix, + char *buf, size_t buflen) { + const char *err = NULL; + if (session != NULL) + err = ssh_get_error(session); + if (err != NULL && err[0] != '\0') + snprintf(buf, buflen, "%s: %s", prefix, err); + else + snprintf(buf, buflen, "%s", prefix); +} + +static int session_has_pending_write(ssh_session session) { + if (session == NULL) + return 0; + int pending = ssh_get_status(session) | ssh_get_poll_flags(session); + return (pending & SSH_WRITE_PENDING) != 0; +} + +static uint64_t next_pending_session_id = 1; + +static uint64_t alloc_pending_session_id(void) { + return __atomic_fetch_add(&next_pending_session_id, 1, __ATOMIC_RELAXED); +} + +static void client_retire_channel(ssh_client_ctx *c, ssh_channel_ctx *ch) { + if (c == NULL || ch == NULL) + return; + ch->next = c->retired_channels; + c->retired_channels = ch; +} + +static void client_free_retired_channels(ssh_client_ctx *c) { + if (c == NULL) + return; + ssh_channel_ctx *ch = c->retired_channels; + c->retired_channels = NULL; + while (ch != NULL) { + ssh_channel_ctx *next = ch->next; + ch->next = NULL; + acton_free(ch); + ch = next; + } +} + +static void session_retire_channel(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) { + if (s == NULL || ch == NULL) + return; + ch->next = s->retired_channels; + s->retired_channels = ch; +} + +static void session_free_retired_channels(ssh_server_session_ctx *s) { + if (s == NULL) + return; + ssh_server_channel_ctx *ch = s->retired_channels; + s->retired_channels = NULL; + while (ch != NULL) { + ssh_server_channel_ctx *next = ch->next; + ch->next = NULL; + acton_free(ch); + ch = next; + } +} + +static void client_poll_close_cb(uv_handle_t *handle) { + ssh_client_ctx *c = (ssh_client_ctx *)handle->data; + if (c == NULL) { + acton_free(handle); + return; + } + if (c->poll == (uv_poll_t *)handle) { + c->poll = NULL; + c->poll_events = 0; + } + if (c->state == CLIENT_STATE_CLOSING) { + client_finalize(c); + } + acton_free(handle); +} + +static void client_timer_close_cb(uv_handle_t *handle) { + ssh_client_ctx *c = (ssh_client_ctx *)handle->data; + if (c == NULL) { + acton_free(handle); + return; + } + if ((uv_timer_t *)handle == c->connect_timer) + c->connect_timer = NULL; + if ((uv_timer_t *)handle == c->auth_timer) + c->auth_timer = NULL; + if ((uv_timer_t *)handle == c->keepalive_timer) + c->keepalive_timer = NULL; + client_maybe_release(c); + acton_free(handle); +} + +static void server_poll_close_cb(uv_handle_t *handle) { + ssh_server_ctx *s = (ssh_server_ctx *)handle->data; + if (s == NULL) { + acton_free(handle); + return; + } + if (s->poll == (uv_poll_t *)handle) + s->poll = NULL; + if (s->state == SERVER_STATE_CLOSING) { + server_finalize(s); + } + acton_free(handle); +} + +static void session_poll_close_cb(uv_handle_t *handle) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data; + if (s == NULL) { + acton_free(handle); + return; + } + if (s->poll == (uv_poll_t *)handle) { + s->poll = NULL; + s->poll_events = 0; + } + if (s->state == SESSION_STATE_CLOSING) { + session_finalize(s); + } + acton_free(handle); +} + +static void session_timer_close_cb(uv_handle_t *handle) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data; + if (s == NULL) { + acton_free(handle); + return; + } + if ((uv_timer_t *)handle == s->attach_timer) + s->attach_timer = NULL; + if ((uv_timer_t *)handle == s->auth_timer) + s->auth_timer = NULL; + if ((uv_timer_t *)handle == s->keepalive_timer) + s->keepalive_timer = NULL; + session_maybe_release(s); + acton_free(handle); +} + +static void client_maybe_release(ssh_client_ctx *c) { + if (c == NULL || !c->close_finalized) + return; + if (c->poll != NULL || c->connect_timer != NULL || + c->auth_timer != NULL || c->keepalive_timer != NULL) + return; + if (c->close_reason != NULL) { + acton_free(c->close_reason); + c->close_reason = NULL; + } + acton_free(c); +} + +static void server_maybe_release(ssh_server_ctx *s) { + if (s == NULL || !s->close_finalized) + return; + if (s->poll != NULL || s->sessions != NULL) + return; + if (s->close_reason != NULL) { + acton_free(s->close_reason); + s->close_reason = NULL; + } + acton_free(s); +} + +static void session_maybe_release(ssh_server_session_ctx *s) { + if (s == NULL || !s->close_finalized) + return; + if (s->poll != NULL || s->attach_timer != NULL || + s->auth_timer != NULL || s->keepalive_timer != NULL) + return; + if (s->close_reason != NULL) { + acton_free(s->close_reason); + s->close_reason = NULL; + } + acton_free(s); +} + +static int server_session_count(ssh_server_ctx *s) { + int count = 0; + if (s == NULL) + return 0; + ssh_server_session_ctx *sess = s->sessions; + while (sess != NULL) { + count++; + sess = sess->next; + } + return count; +} + +static int session_channel_count(ssh_server_session_ctx *s) { + int count = 0; + if (s == NULL) + return 0; + ssh_server_channel_ctx *ch = s->channels; + while (ch != NULL) { + count++; + ch = ch->next; + } + return count; +} + +static int server_session_limit_reached(ssh_server_ctx *s) { + if (s == NULL || s->max_sessions <= 0) + return 0; + return server_session_count(s) >= s->max_sessions; +} + +static int session_channel_limit_reached(ssh_server_session_ctx *s) { + if (s == NULL || s->server == NULL || s->server->max_channels_per_session <= 0) + return 0; + return session_channel_count(s) >= s->server->max_channels_per_session; +} + +static void client_notify_connect(ssh_client_ctx *c, const char *err) { + if (c == NULL) + return; + if (c->connect_notified) + return; + sshQ_Client actor = client_actor_ref(c); + if (c->on_connect) { + $action2 f = ($action2)c->on_connect; + f->$class->__asyn__(f, actor, err ? to$str((char *)err) : B_None); + } + c->connect_notified = 1; + if (err == NULL) + c->connected_ok = 1; +} + +static void client_notify_close(ssh_client_ctx *c, const char *reason) { + if (c->close_notified) + return; + if (!c->connected_ok) + return; + sshQ_Client actor = client_actor_ref(c); + if (c->on_close) { + $action2 f = ($action2)c->on_close; + f->$class->__asyn__(f, actor, to$str((char *)reason)); + } + c->close_notified = 1; +} + +static void client_fail(ssh_client_ctx *c, const char *msg) { + if (c == NULL || c->state == CLIENT_STATE_CLOSED || c->state == CLIENT_STATE_ERROR) + return; + c->state = CLIENT_STATE_ERROR; + if (!c->connected_ok) + client_notify_connect(c, msg); + client_close_internal(c, msg, 1); +} + +static void channel_notify_open(ssh_channel_ctx *ch, const char *err) { + if (ch->open_notified) + return; + sshQ_Channel actor = channel_actor_ref(ch); + if (ch->on_open) { + $action2 f = ($action2)ch->on_open; + f->$class->__asyn__(f, actor, err ? to$str((char *)err) : B_None); + } + ch->open_notified = 1; + if (err == NULL) + ch->open_succeeded = 1; +} + +static void channel_notify_close(ssh_channel_ctx *ch, const char *reason) { + if (ch->close_notified) + return; + if (!ch->open_succeeded) { + ch->close_notified = 1; + return; + } + sshQ_Channel actor = channel_actor_ref(ch); + if (ch->on_close) { + $action2 f = ($action2)ch->on_close; + f->$class->__asyn__(f, actor, to$str((char *)reason)); + } + ch->close_notified = 1; +} + +static void channel_notify_error(ssh_channel_ctx *ch, const char *msg) { + if (ch->open_succeeded) + channel_notify_close(ch, msg); + else + channel_notify_open(ch, msg); +} + +static void channel_notify_exit(ssh_channel_ctx *ch, int exit_status, B_str signal) { + if (ch->exit_sent) + return; + sshQ_Channel actor = channel_actor_ref(ch); + if (ch->on_exit) { + $action3 f = ($action3)ch->on_exit; + f->$class->__asyn__(f, actor, toB_int(exit_status), signal); + } + ch->exit_sent = 1; +} + +static int client_channel_data_cb(ssh_session session, ssh_channel channel, void *data, + uint32_t len, int is_stderr, void *userdata) { + ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL || ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR) + return 0; + if (len == 0) + return 0; + B_bytes out = to$bytesD_len((char *)data, (size_t)len); + sshQ_Channel actor = channel_actor_ref(ch); + if (is_stderr) { + if (ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, actor, out); + } + } else { + if (ch->on_stdout) { + $action2 f = ($action2)ch->on_stdout; + f->$class->__asyn__(f, actor, out); + } + } + return (int)len; +} + +static void client_channel_eof_cb(ssh_session session, ssh_channel channel, void *userdata) { + ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL) + return; + sshQ_Channel actor = channel_actor_ref(ch); + if (!ch->stdout_eof && ch->on_stdout) { + $action2 f = ($action2)ch->on_stdout; + f->$class->__asyn__(f, actor, B_None); + ch->stdout_eof = 1; + } + if (!ch->stderr_eof && ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, actor, B_None); + ch->stderr_eof = 1; + } +} + +static void client_channel_close_cb(ssh_session session, ssh_channel channel, void *userdata) { + ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL) + return; + ch->remote_close_seen = 1; +} + +static int client_channel_write_wontblock_cb(ssh_session session, ssh_channel channel, + uint32_t bytes, void *userdata) { + ssh_channel_ctx *ch = (ssh_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL || ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR) + return 0; + ch->write_wontblock = bytes > 0 ? 1 : 0; + if (ssh_debug_enabled) { + ssh_debug_log("client channel write_wontblock: bytes=%u ch=%p", bytes, (void *)ch); + } + return 0; +} + +static int client_channel_setup_callbacks(ssh_channel_ctx *ch) { + if (ch == NULL || ch->channel == NULL) + return SSH_ERROR; + if (ch->callbacks != NULL) + return SSH_OK; + struct ssh_channel_callbacks_struct *cb = acton_calloc(1, sizeof(*cb)); + ssh_callbacks_init(cb); + cb->userdata = ch; + cb->channel_data_function = client_channel_data_cb; + cb->channel_eof_function = client_channel_eof_cb; + cb->channel_close_function = client_channel_close_cb; + cb->channel_write_wontblock_function = client_channel_write_wontblock_cb; + if (ssh_add_channel_callbacks(ch->channel, cb) != SSH_OK) { + acton_free(cb); + return SSH_ERROR; + } + ch->callbacks = cb; + if (ssh_debug_enabled) { + ssh_debug_log("client channel callbacks set ch=%p", (void *)ch); + } + return SSH_OK; +} + +static void channel_notify_eof(ssh_channel_ctx *ch) { + if (ch->channel == NULL) + return; + if (ssh_channel_is_eof(ch->channel)) { + sshQ_Channel actor = channel_actor_ref(ch); + if (!ch->stdout_eof && ch->on_stdout) { + $action2 f = ($action2)ch->on_stdout; + f->$class->__asyn__(f, actor, B_None); + ch->stdout_eof = 1; + } + if (!ch->stderr_eof && ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, actor, B_None); + ch->stderr_eof = 1; + } + } +} + +static void channel_finalize(ssh_client_ctx *c, ssh_channel_ctx *ch) { + int exit_status = -1; + B_str exit_signal = B_None; + sshQ_Channel actor = channel_actor_ref(ch); + + while (ch->write_head != NULL) { + write_chunk_t *chunk = ch->write_head; + ch->write_head = chunk->next; + acton_free(chunk); + } + ch->write_tail = NULL; + if (ch->pending_req != CHAN_REQ_NONE) { + ch->pending_req = CHAN_REQ_NONE; + channel_notify_error(ch, "SSH channel request failed: channel closed"); + } + if (ch->channel != NULL) { + if (ssh_channel_is_closed(ch->channel)) { + uint32_t exit_code = 0; + char *signal = NULL; + int core_dumped = 0; + int rc = ssh_channel_get_exit_state(ch->channel, &exit_code, &signal, &core_dumped); + if (rc == SSH_OK) { + exit_status = (int)exit_code; + if (signal != NULL) + exit_signal = to$str(signal); + } + if (signal) + ssh_string_free_char(signal); + (void)core_dumped; + } + if (ch->callbacks) { + ssh_remove_channel_callbacks(ch->channel, ch->callbacks); + acton_free(ch->callbacks); + ch->callbacks = NULL; + } + ssh_channel_free(ch->channel); + ch->channel = NULL; + } + ch->state = CHAN_STATE_CLOSED; + if (actor) + actor->_channel_id = toB_u64(0); + + channel_notify_exit(ch, exit_status, exit_signal); + if (!ch->stdout_eof && ch->on_stdout) { + $action2 f = ($action2)ch->on_stdout; + f->$class->__asyn__(f, actor, B_None); + ch->stdout_eof = 1; + } + if (!ch->stderr_eof && ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, actor, B_None); + ch->stderr_eof = 1; + } + channel_notify_close(ch, "closed"); + ch->actor = NULL; + (void)c; +} + +static void channel_fail(ssh_client_ctx *c, ssh_channel_ctx *ch, const char *msg) { + if (ch->state == CHAN_STATE_ERROR || ch->state == CHAN_STATE_CLOSED) + return; + ch->state = CHAN_STATE_ERROR; + channel_notify_error(ch, msg); + if (ch->channel) { + ssh_channel_close(ch->channel); + } + channel_finalize(c, ch); +} + +static void channel_queue_write(ssh_channel_ctx *ch, B_bytes data) { + write_chunk_t *chunk = acton_calloc(1, sizeof(write_chunk_t)); + chunk->data = data; + chunk->offset = 0; + chunk->next = NULL; + if (ch->write_tail) { + ch->write_tail->next = chunk; + } else { + ch->write_head = chunk; + } + ch->write_tail = chunk; +} + +static void channel_try_write(ssh_client_ctx *c, ssh_channel_ctx *ch) { + while (ch->write_head != NULL && ch->write_head->data->nbytes == ch->write_head->offset) { + write_chunk_t *chunk = ch->write_head; + ch->write_head = chunk->next; + acton_free(chunk); + if (ch->write_head == NULL) + ch->write_tail = NULL; + } + + if (ch->write_head == NULL || !ch->write_wontblock || session_has_pending_write(c->session)) + return; + + write_chunk_t *chunk = ch->write_head; + size_t remaining = chunk->data->nbytes - chunk->offset; + ch->write_wontblock = 0; + int rc = ssh_channel_write(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); + if (rc > 0) { + chunk->offset += (size_t)rc; + if (chunk->offset >= chunk->data->nbytes) { + ch->write_head = chunk->next; + acton_free(chunk); + if (ch->write_head == NULL) + ch->write_tail = NULL; + } + } else if (rc == 0 || rc == SSH_AGAIN) { + c->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH channel write error: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } +} + +static int channel_read_stream(ssh_client_ctx *c, ssh_channel_ctx *ch, int is_stderr) { + char buf[SSH_READ_BUFSIZE]; + int read_any = 0; + for (;;) { + int n = ssh_channel_read_buffered(ch->channel, buf, sizeof(buf), is_stderr); + if (ssh_debug_enabled) { + ssh_debug_log("client channel read: rc=%d stderr=%d", n, is_stderr); + } + if (n > 0) { + read_any = 1; + B_bytes out = to$bytesD_len(buf, n); + sshQ_Channel actor = channel_actor_ref(ch); + if (is_stderr) { + if (ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, actor, out); + } + } else { + if (ch->on_stdout) { + $action2 f = ($action2)ch->on_stdout; + f->$class->__asyn__(f, actor, out); + } + } + continue; + } + if (n == 0 || n == SSH_AGAIN) { + break; + } + if (n == SSH_EOF) { + break; + } + if (n == SSH_ERROR) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH channel read error: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + break; + } + } + return read_any; +} + +static void channel_drive(ssh_client_ctx *c, ssh_channel_ctx *ch) { + if (ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR) + return; + + if (ch->state == CHAN_STATE_INIT) { + ch->channel = ssh_channel_new(c->session); + if (ch->channel == NULL) { + channel_fail(c, ch, "Failed to create SSH channel"); + return; + } + if (client_channel_setup_callbacks(ch) != SSH_OK) { + channel_fail(c, ch, "Failed to set SSH channel callbacks"); + return; + } + if (ssh_debug_enabled) { + ssh_debug_log("client channel new ch=%p callbacks=%p", (void *)ch, (void *)ch->callbacks); + } + ch->state = CHAN_STATE_OPENING; + } + + if (ch->state == CHAN_STATE_OPENING) { + int rc = ssh_channel_open_session(ch->channel); + if (rc == SSH_OK) { + ch->state = CHAN_STATE_OPEN; + channel_notify_open(ch, NULL); + } else if (rc == SSH_AGAIN) { + c->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Failed to open SSH channel: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } + } + + if (ch->state == CHAN_STATE_OPEN || ch->state == CHAN_STATE_RUNNING) { + if (ch->pty_pending && !ch->pty_done) { + const char *term = ch->term ? (const char *)fromB_str(ch->term) : "xterm-256color"; + int rc = ssh_channel_request_pty_size(ch->channel, term, ch->cols, ch->rows); + if (rc == SSH_OK) { + ch->pty_done = 1; + } else if (rc == SSH_AGAIN) { + c->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Failed to request PTY: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } + } + + if (ch->pending_req != CHAN_REQ_NONE) { + int rc = SSH_ERROR; + if (ch->pending_req == CHAN_REQ_SHELL) { + rc = ssh_channel_request_shell(ch->channel); + } else if (ch->pending_req == CHAN_REQ_EXEC) { + rc = ssh_channel_request_exec(ch->channel, (const char *)fromB_str(ch->exec_cmd)); + } else if (ch->pending_req == CHAN_REQ_SUBSYSTEM) { + rc = ssh_channel_request_subsystem(ch->channel, (const char *)fromB_str(ch->subsystem)); + } + + if (rc == SSH_OK) { + if (ssh_debug_enabled) { + ssh_debug_log("client channel request ok"); + } + ch->pending_req = CHAN_REQ_NONE; + ch->state = CHAN_STATE_RUNNING; + } else if (rc == SSH_AGAIN) { + c->write_ready = 0; + if (ssh_debug_enabled) { + ssh_debug_log("client channel request again"); + } + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH channel request failed: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } + } + } + + if (ch->state == CHAN_STATE_OPEN || ch->state == CHAN_STATE_RUNNING) { + if (ch->write_head != NULL) { + channel_try_write(c, ch); + if (ch->state == CHAN_STATE_ERROR) + return; + } + + if (ch->send_eof && !ch->eof_sent && ch->write_head == NULL && + !session_has_pending_write(c->session)) { + int rc = ssh_channel_send_eof(ch->channel); + if (rc == SSH_OK) { + ch->eof_sent = 1; + } else if (rc == SSH_AGAIN) { + c->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Failed to send EOF: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } + } + + if (ch->close_requested && !ch->close_sent && ch->write_head == NULL && + (!ch->send_eof || ch->eof_sent) && + !session_has_pending_write(c->session)) { + int rc = ssh_channel_close(ch->channel); + if (rc == SSH_OK) { + ch->close_sent = 1; + ch->state = CHAN_STATE_CLOSING; + } else if (rc == SSH_AGAIN) { + c->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Failed to close channel: %s", ssh_get_error(c->session)); + channel_fail(c, ch, errmsg); + return; + } + } + + if (ch->callbacks == NULL) { + for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int did = 0; + did |= channel_read_stream(c, ch, 0); + did |= channel_read_stream(c, ch, 1); + if (!did) + break; + } + } + } + + if (ch->channel != NULL && ch->remote_close_seen && + ssh_channel_is_closed(ch->channel)) { + channel_finalize(c, ch); + } +} + +static void client_drive_channels(ssh_client_ctx *c) { + ssh_channel_ctx *prev = NULL; + ssh_channel_ctx *ch = c->channels; + while (ch != NULL) { + ssh_channel_ctx *next = ch->next; + channel_drive(c, ch); + if (ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR) { + if (prev != NULL) { + prev->next = next; + } else { + c->channels = next; + } + client_retire_channel(c, ch); + } else { + prev = ch; + } + ch = next; + } +} + +static int client_get_hostkey_info(ssh_client_ctx *c, B_str *key_type_out, B_str *fingerprint_out) { + ssh_key key = NULL; + unsigned char *hash = NULL; + size_t hash_len = 0; + char *fingerprint = NULL; + + int rc = ssh_get_server_publickey(c->session, &key); + if (rc != SSH_OK || key == NULL) + return -1; + + enum ssh_keytypes_e key_type = ssh_key_type(key); + const char *key_type_str = ssh_key_type_to_char(key_type); + + rc = ssh_get_publickey_hash(key, SSH_PUBLICKEY_HASH_SHA256, &hash, &hash_len); + if (rc != SSH_OK) { + ssh_key_free(key); + return -1; + } + + fingerprint = ssh_get_fingerprint_hash(SSH_PUBLICKEY_HASH_SHA256, hash, hash_len); + + if (key_type_out) + *key_type_out = to$str((char *)(key_type_str ? key_type_str : "")); + if (fingerprint_out) + *fingerprint_out = to$str((char *)(fingerprint ? fingerprint : "")); + + ssh_clean_pubkey_hash(&hash); + ssh_key_free(key); + if (fingerprint) + ssh_string_free_char(fingerprint); + + return 0; +} + +static int client_check_hostkey(ssh_client_ctx *c) { + enum ssh_known_hosts_e state = SSH_KNOWN_HOSTS_UNKNOWN; + int use_known_hosts = 0; + sshQ_Client actor = client_actor_ref(c); + + if (actor != NULL && actor->known_hosts != B_None) + use_known_hosts = 1; + + if (use_known_hosts) { + state = ssh_session_is_known_server(c->session); + if (state == SSH_KNOWN_HOSTS_OK) + return 0; + + if (state == SSH_KNOWN_HOSTS_ERROR) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Host key check error: %s", ssh_get_error(c->session)); + client_fail(c, errmsg); + return -1; + } + } + + c->hostkey_state = state; + + if (!c->on_hostkey) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Host key not accepted (%s)", hostkey_state_str(state)); + client_fail(c, errmsg); + return -1; + } + + B_str key_type = to$str((char *)""); + B_str fingerprint = to$str((char *)""); + if (client_get_hostkey_info(c, &key_type, &fingerprint) != 0) { + key_type = to$str((char *)""); + fingerprint = to$str((char *)""); + } + + sshQ_HostKeyInfo info = sshQ_HostKeyInfoG_new(key_type, fingerprint); + $action3 f = ($action3)c->on_hostkey; + f->$class->__asyn__(f, actor, to$str((char *)hostkey_state_str(state)), info); + return 1; +} + +static void connect_timeout_cb(uv_timer_t *timer) { + ssh_client_ctx *c = (ssh_client_ctx *)timer->data; + if (c == NULL) + return; + if (c->state == CLIENT_STATE_CONNECTING || c->state == CLIENT_STATE_HOSTKEY || c->state == CLIENT_STATE_HOSTKEY_WAIT) { + client_fail(c, "SSH connect timeout"); + } +} + +static void auth_timeout_cb(uv_timer_t *timer) { + ssh_client_ctx *c = (ssh_client_ctx *)timer->data; + if (c == NULL) + return; + if (c->state == CLIENT_STATE_AUTH) { + client_fail(c, "SSH authentication timeout"); + } +} + +static void keepalive_cb(uv_timer_t *timer) { + ssh_client_ctx *c = (ssh_client_ctx *)timer->data; + if (c == NULL) + return; + if (c->state != CLIENT_STATE_READY || c->session == NULL) + return; + int rc = ssh_send_ignore(c->session, "keepalive"); + if (rc == SSH_AGAIN) { + client_update_poll(c); + return; + } + if (rc != SSH_OK) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH keepalive failed: %s", ssh_get_error(c->session)); + client_fail(c, errmsg); + return; + } + client_update_poll(c); +} + +static void client_start_connect_timer(ssh_client_ctx *c) { + if (c->connect_timeout <= 0.0 || c->connect_timer != NULL) + return; + c->connect_timer = acton_calloc(1, sizeof(uv_timer_t)); + c->connect_timer->data = c; + uv_timer_init(get_uv_loop(), c->connect_timer); + uv_timer_start(c->connect_timer, connect_timeout_cb, (uint64_t)(c->connect_timeout * 1000), 0); +} + +static void client_start_auth_timer(ssh_client_ctx *c) { + if (c->auth_timeout <= 0.0 || c->auth_timer != NULL) + return; + c->auth_timer = acton_calloc(1, sizeof(uv_timer_t)); + c->auth_timer->data = c; + uv_timer_init(get_uv_loop(), c->auth_timer); + uv_timer_start(c->auth_timer, auth_timeout_cb, (uint64_t)(c->auth_timeout * 1000), 0); +} + +static void client_start_keepalive(ssh_client_ctx *c) { + if (!c->keepalive_enabled || c->keepalive_interval <= 0.0 || c->keepalive_timer != NULL) + return; + c->keepalive_timer = acton_calloc(1, sizeof(uv_timer_t)); + c->keepalive_timer->data = c; + uv_timer_init(get_uv_loop(), c->keepalive_timer); + uint64_t interval_ms = (uint64_t)(c->keepalive_interval * 1000); + uv_timer_start(c->keepalive_timer, keepalive_cb, interval_ms, interval_ms); +} + +static void poll_cb(uv_poll_t *handle, int status, int events) { + ssh_client_ctx *c = (ssh_client_ctx *)handle->data; + if (c == NULL) + return; + if (ssh_debug_enabled) { + ssh_debug_log("client poll: status=%d events=0x%x state=%d", status, events, c->state); + } + if (status < 0) { + char errmsg[256] = {0}; + uv_strerror_r(status, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + client_fail(c, errmsg); + return; + } + int libssh_events = 0; + if ((events & UV_READABLE) && fd_has_data(c->fd)) { + ssh_set_fd_toread(c->session); + libssh_events |= UV_READABLE; + } +#ifdef UV_DISCONNECT + if (events & UV_DISCONNECT) { + ssh_set_fd_toread(c->session); + libssh_events |= UV_DISCONNECT; + } +#endif + if ((events & UV_WRITABLE) && fd_can_write(c->fd)) { + c->write_ready = 1; + ssh_set_fd_towrite(c->session); + libssh_events |= UV_WRITABLE; + } + if (session_apply_poll_events(c->session, libssh_events) != 0) { + char errmsg[256] = {0}; + format_session_error(c->session, "SSH poll callback error", errmsg, sizeof(errmsg)); + client_fail(c, errmsg); + return; + } + client_drive(c); + client_pump_io(c); + c->write_ready = 0; +} + +static void client_update_poll(ssh_client_ctx *c) { + if (c->poll == NULL || c->session == NULL) + return; + if (c->state == CLIENT_STATE_CLOSED) + return; + if (uv_is_closing((uv_handle_t *)c->poll)) + return; + int status = ssh_get_status(c->session); + if (status & SSH_CLOSED_ERROR) { + client_fail(c, "SSH session closed with error"); + return; + } + if (status & SSH_CLOSED) { + client_close_internal(c, "SSH session closed", 1); + return; + } + int flags = ssh_get_poll_flags(c->session); + int pending = flags | status; + int events = UV_READABLE; +#ifdef UV_DISCONNECT + events |= UV_DISCONNECT; +#endif + if (pending & SSH_WRITE_PENDING) + events |= UV_WRITABLE; + if ((events & UV_WRITABLE) == 0 && client_needs_write(c)) + events |= UV_WRITABLE; + if (ssh_debug_enabled && (events != c->poll_events || (pending & SSH_WRITE_PENDING))) { + ssh_debug_log("client update poll: status=0x%x flags=0x%x pending=0x%x events=0x%x state=%d", + status, flags, pending, events, c->state); + } + if (events != c->poll_events) { + int uv_rc = uv_poll_start(c->poll, events, poll_cb); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + client_fail(c, errmsg); + return; + } + c->poll_events = events; + } +} + +static void client_pump_io(ssh_client_ctx *c) { + if (c == NULL || c->session == NULL) + return; + int i; + for (i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int did = 0; + if (c->session == NULL) + return; + int has_data = fd_has_data(c->fd); + if (has_data) { + if (ssh_debug_enabled) { + ssh_debug_log("client pump: readable"); + } + ssh_set_fd_toread(c->session); + if (session_apply_poll_events(c->session, UV_READABLE) != 0) { + char errmsg[256] = {0}; + format_session_error(c->session, "SSH poll callback error", errmsg, sizeof(errmsg)); + client_fail(c, errmsg); + return; + } + client_drive(c); + did = 1; + } + if (c->session == NULL) + return; + if (!did) { + int status = ssh_get_status(c->session); + if (status & SSH_READ_PENDING) { + if (ssh_debug_enabled) { + ssh_debug_log("client pump: buffered read pending=0x%x", status); + } + client_drive(c); + /* Avoid spinning when only buffered data remains. */ + break; + } + } + if (!did) + break; + } + if (ssh_debug_enabled && i >= SSH_IO_PUMP_LIMIT) { + int status = ssh_get_status(c->session); + int flags = ssh_get_poll_flags(c->session); + ssh_debug_log("client pump: hit limit status=0x%x flags=0x%x", status, flags); + } +} + +static int client_needs_write(ssh_client_ctx *c) { + if (c == NULL) + return 0; + ssh_channel_ctx *ch = c->channels; + while (ch != NULL) { + if (ch->pending_req != CHAN_REQ_NONE || ch->write_head != NULL) + return 1; + if (ch->send_eof && !ch->eof_sent) + return 1; + if (ch->close_requested && !ch->close_sent) + return 1; + ch = ch->next; + } + return 0; +} + +static void client_on_ready(ssh_client_ctx *c) { + c->state = CLIENT_STATE_READY; + stop_timer(&c->connect_timer, client_timer_close_cb); + stop_timer(&c->auth_timer, client_timer_close_cb); + client_notify_connect(c, NULL); + client_start_keepalive(c); + client_drive_channels(c); + client_update_poll(c); +} + +static void client_drive(ssh_client_ctx *c) { + if (c == NULL) + return; + if (c->state == CLIENT_STATE_ERROR || c->state == CLIENT_STATE_CLOSED) + return; + if (c->state == CLIENT_STATE_CLOSING) { + client_drive_channels(c); + client_finish_close(c); + return; + } + + int spin = 0; + while (1) { + if (c->state == CLIENT_STATE_CONNECTING) { + int rc = ssh_connect(c->session); + if (ssh_debug_enabled) { + log_debug("ssh_connect rc=%d state=%d", rc, c->state); + } + if (rc == SSH_OK) { + stop_timer(&c->connect_timer, client_timer_close_cb); + if (fd_set_nonblocking(c->fd) != 0) { + client_fail(c, "Failed to restore SSH session fd nonblocking"); + return; + } + c->state = CLIENT_STATE_HOSTKEY; + continue; + } else if (rc == SSH_AGAIN) { + int status = ssh_get_status(c->session); + if (status & SSH_WRITE_PENDING) + c->write_ready = 0; + if ((status & SSH_READ_PENDING) && spin++ < SSH_IO_PUMP_LIMIT) { + if (ssh_debug_enabled) { + ssh_debug_log("client drive: buffered read pending during connect"); + } + continue; + } + client_update_poll(c); + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH connect failed: %s", ssh_get_error(c->session)); + client_fail(c, errmsg); + return; + } + } + + if (c->state == CLIENT_STATE_HOSTKEY) { + int rc = client_check_hostkey(c); + if (rc == 0) { + c->state = CLIENT_STATE_AUTH; + client_start_auth_timer(c); + continue; + } else if (rc == 1) { + c->state = CLIENT_STATE_HOSTKEY_WAIT; + return; + } else { + return; + } + } + + if (c->state == CLIENT_STATE_HOSTKEY_WAIT) { + client_update_poll(c); + return; + } + + if (c->state == CLIENT_STATE_AUTH) { + sshQ_Client actor = client_actor_ref(c); + if (actor == NULL || actor->password == B_None) { + client_fail(c, "Password auth requested but no password provided"); + return; + } + int rc = ssh_userauth_password(c->session, NULL, (const char *)fromB_str(actor->password)); + if (rc == SSH_AUTH_SUCCESS) { + client_on_ready(c); + return; + } else if (rc == SSH_AUTH_AGAIN) { + int status = ssh_get_status(c->session); + if (status & SSH_WRITE_PENDING) + c->write_ready = 0; + if ((status & SSH_READ_PENDING) && spin++ < SSH_IO_PUMP_LIMIT) { + if (ssh_debug_enabled) { + ssh_debug_log("client drive: buffered read pending during auth"); + } + continue; + } + client_update_poll(c); + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH auth failed: %s", ssh_get_error(c->session)); + client_fail(c, errmsg); + return; + } + } + + if (c->state == CLIENT_STATE_READY) { + client_drive_channels(c); + client_update_poll(c); + return; + } + + return; + } +} + +static void client_finalize(ssh_client_ctx *c) { + if (c == NULL || c->close_finalized) + return; + c->close_finalized = 1; + + if (c->session != NULL) { + ssh_disconnect(c->session); + ssh_free(c->session); + c->session = NULL; + } + client_free_retired_channels(c); + + client_notify_close(c, c->close_reason ? c->close_reason : "closed"); + c->state = CLIENT_STATE_CLOSED; + sshQ_Client actor = client_actor_ref(c); + if (actor) + actor->_client = toB_u64(0); + STORE_HIDDEN_PTR(c->actor, NULL); + client_maybe_release(c); +} + +static void client_abort_channels(ssh_client_ctx *c, int notify_channel_error) { + ssh_channel_ctx *ch = c->channels; + while (ch != NULL) { + ssh_channel_ctx *next = ch->next; + if (notify_channel_error) + channel_notify_error(ch, "Session closed"); + channel_notify_eof(ch); + channel_finalize(c, ch); + client_retire_channel(c, ch); + ch = next; + } + c->channels = NULL; +} + +static void client_request_channel_close(ssh_client_ctx *c) { + ssh_channel_ctx *ch = c->channels; + while (ch != NULL) { + if (ch->state != CHAN_STATE_CLOSED && ch->state != CHAN_STATE_ERROR) { + ch->send_eof = 1; + ch->close_requested = 1; + } + ch = ch->next; + } +} + +static void client_finish_close(ssh_client_ctx *c) { + if (c == NULL || c->state != CLIENT_STATE_CLOSING) + return; + if (c->close_force) { + if (c->poll != NULL) { + close_poll(&c->poll, client_poll_close_cb); + c->poll_events = 0; + return; + } + client_finalize(c); + return; + } + if (c->channels != NULL) { + client_update_poll(c); + return; + } + if (c->session != NULL && session_has_pending_write(c->session)) { + client_update_poll(c); + return; + } + if (c->poll != NULL) { + close_poll(&c->poll, client_poll_close_cb); + c->poll_events = 0; + return; + } + client_finalize(c); +} + +static void client_close_internal(ssh_client_ctx *c, const char *reason, int force_close) { + if (c == NULL || c->state == CLIENT_STATE_CLOSED) + return; + if (!force_close && c->state != CLIENT_STATE_READY) + force_close = 1; + + if (!c->connected_ok && !c->connect_notified) { + client_notify_connect(c, reason ? reason : "closed"); + } + if (reason != NULL && c->close_reason == NULL) + c->close_reason = acton_strdup(reason); + + if (c->state == CLIENT_STATE_CLOSING) { + if (force_close && !c->close_force) { + c->close_force = 1; + client_abort_channels(c, 1); + } + client_finish_close(c); + return; + } + + stop_timer(&c->connect_timer, client_timer_close_cb); + stop_timer(&c->auth_timer, client_timer_close_cb); + stop_timer(&c->keepalive_timer, client_timer_close_cb); + + c->state = CLIENT_STATE_CLOSING; + c->close_force = force_close; + if (force_close) { + client_abort_channels(c, 1); + client_finish_close(c); + return; + } + + client_request_channel_close(c); + client_drive(c); +} + +void sshQ___ext_init__() { + const char *dbg_env = getenv("ACTON_SSH_DEBUG"); + const char *log_env = getenv("ACTON_SSH_LIBSSH_LOG"); + if (dbg_env != NULL && dbg_env[0] != '\0') + ssh_debug_enabled = 1; + if (log_env != NULL && log_env[0] != '\0') { + ssh_libssh_log_level = parse_libssh_log_level(log_env); + } + int r = libssh_replace_allocator(acton_malloc, + acton_realloc, + acton_calloc, + acton_free, + acton_strdup, + acton_strndup); + if (r != SSH_OK) { + log_warn("SSH allocator replacement failed"); + } + r = ssh_threads_set_callbacks(ssh_threads_get_default()); + if (r != SSH_OK) { + log_warn("SSH thread callbacks setup failed"); + } + r = ssh_init(); + if (r != SSH_OK) { + log_warn("SSH init failed"); + } +} + +B_str sshQ_version() { + return to$str((char *)ssh_version(0)); +} + +B_NoneType sshQ__debug(B_str msg) { + if (ssh_debug_enabled) { + log_info("%s", fromB_str(msg)); + } + return B_None; +} + +$R sshQ_ClientD__pin_affinityG_local(sshQ_Client self, $Cont c$cont) { + pin_actor_affinity(); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD__initG_local(sshQ_Client self, $Cont c$cont) { + ssh_configure_libssh_logging(); + ssh_client_ctx *c = acton_calloc(1, sizeof(ssh_client_ctx)); + STORE_HIDDEN_PTR(c->actor, self); + c->on_connect = ($action2)self->on_connect; + c->on_close = ($action2)self->on_close; + c->on_hostkey = ($action3)self->on_hostkey; + c->connect_timeout = fromB_float(self->connect_timeout); + c->auth_timeout = fromB_float(self->auth_timeout); + c->keepalive_interval = fromB_float(self->keepalive_interval); + c->keepalive_enabled = fromB_bool(self->keepalive_enabled) ? 1 : 0; + + self->_client = toB_u64((unsigned long)c); + + c->session = ssh_new(); + if (c->session == NULL) { + client_fail(c, "Failed to create SSH session"); + return $R_CONT(c$cont, B_None); + } + + int rc; + int strict = 1; + int port = (int)fromB_u16(self->port); + bool process_config = false; + + rc = ssh_options_set(c->session, SSH_OPTIONS_PROCESS_CONFIG, &process_config); + if (rc != SSH_OK) { + client_fail(c, "Failed to disable SSH config processing"); + return $R_CONT(c$cont, B_None); + } + + rc = ssh_options_set(c->session, SSH_OPTIONS_HOST, fromB_str(self->host)); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH host"); + return $R_CONT(c$cont, B_None); + } + rc = ssh_options_set(c->session, SSH_OPTIONS_PORT, &port); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH port"); + return $R_CONT(c$cont, B_None); + } + rc = ssh_options_set(c->session, SSH_OPTIONS_USER, fromB_str(self->username)); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH username"); + return $R_CONT(c$cont, B_None); + } + if (self->known_hosts != B_None) { + const char *known_hosts = (const char *)fromB_str(self->known_hosts); + rc = ssh_options_set(c->session, SSH_OPTIONS_KNOWNHOSTS, known_hosts); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH known_hosts path"); + return $R_CONT(c$cont, B_None); + } + rc = ssh_options_set(c->session, SSH_OPTIONS_GLOBAL_KNOWNHOSTS, known_hosts); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH global known_hosts path"); + return $R_CONT(c$cont, B_None); + } + } + rc = ssh_options_set(c->session, SSH_OPTIONS_STRICTHOSTKEYCHECK, &strict); + if (rc != SSH_OK) { + client_fail(c, "Failed to set SSH strict host key checking"); + return $R_CONT(c$cont, B_None); + } + + ssh_set_blocking(c->session, 0); + + c->state = CLIENT_STATE_CONNECTING; + rc = ssh_connect(c->session); + if (rc == SSH_OK) { + c->state = CLIENT_STATE_HOSTKEY; + } else if (rc == SSH_AGAIN) { + c->state = CLIENT_STATE_CONNECTING; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH connect failed: %s", ssh_get_error(c->session)); + client_fail(c, errmsg); + return $R_CONT(c$cont, B_None); + } + + c->fd = ssh_get_fd(c->session); + if (c->fd < 0) { + client_fail(c, "Failed to get SSH session fd"); + return $R_CONT(c$cont, B_None); + } + if (c->state != CLIENT_STATE_CONNECTING && fd_set_nonblocking(c->fd) != 0) { + client_fail(c, "Failed to set SSH session fd nonblocking"); + return $R_CONT(c$cont, B_None); + } + + c->poll = acton_calloc(1, sizeof(uv_poll_t)); + c->poll->data = c; + int uv_rc = uv_poll_init(get_uv_loop(), c->poll, c->fd); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + client_fail(c, errmsg); + return $R_CONT(c$cont, B_None); + } + c->poll_events = UV_READABLE | UV_WRITABLE; + uv_rc = uv_poll_start(c->poll, c->poll_events, poll_cb); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + client_fail(c, errmsg); + return $R_CONT(c$cont, B_None); + } + + client_start_connect_timer(c); + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_accept_hostkeyG_local(sshQ_Client self, $Cont c$cont) { + ssh_client_ctx *c = client_from_actor(self); + if (c == NULL) + return $R_CONT(c$cont, B_None); + if (c->state != CLIENT_STATE_HOSTKEY_WAIT) + return $R_CONT(c$cont, B_None); + + c->state = CLIENT_STATE_AUTH; + client_start_auth_timer(c); + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_reject_hostkeyG_local(sshQ_Client self, $Cont c$cont, B_str reason) { + ssh_client_ctx *c = client_from_actor(self); + if (c == NULL) + return $R_CONT(c$cont, B_None); + if (c->state != CLIENT_STATE_HOSTKEY_WAIT) + return $R_CONT(c$cont, B_None); + + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Host key rejected: %s", fromB_str(reason)); + client_fail(c, errmsg); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_closeG_local(sshQ_Client self, $Cont c$cont) { + ssh_client_ctx *c = client_from_actor(self); + if (c == NULL) + return $R_CONT(c$cont, B_None); + client_close_internal(c, "closed", 0); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD__cleanup_nativeG_local(sshQ_Client self, $Cont c$cont) { + ssh_client_ctx *c = client_from_actor(self); + if (c != NULL) + client_close_internal(c, "collected", 1); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_createG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, + $action on_open, + $action on_stdout, + $action on_stderr, + $action on_exit, + $action on_close) { + ssh_client_ctx *c = client_from_actor(self); + if (c == NULL) { + if (on_open) { + $action2 f = ($action2)on_open; + f->$class->__asyn__(f, channel, to$str((char *)"Client not initialized")); + } + return $R_CONT(c$cont, B_None); + } + + ssh_channel_ctx *ch = acton_calloc(1, sizeof(ssh_channel_ctx)); + ch->client = c; + ch->actor = channel; + ch->callbacks = NULL; + ch->state = CHAN_STATE_INIT; + ch->pending_req = CHAN_REQ_NONE; + ch->pty_pending = 0; + ch->pty_done = 0; + ch->term = to$str((char *)"xterm-256color"); + ch->cols = 80; + ch->rows = 24; + ch->width_px = 0; + ch->height_px = 0; + ch->on_open = ($action2)on_open; + ch->on_close = ($action2)on_close; + ch->on_stdout = ($action2)on_stdout; + ch->on_stderr = ($action2)on_stderr; + ch->on_exit = ($action3)on_exit; + + ch->next = c->channels; + c->channels = ch; + + channel->_channel_id = toB_u64((unsigned long)ch); + + if (c->state == CLIENT_STATE_READY) + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +static int channel_validate(ssh_client_ctx *c, ssh_channel_ctx *ch) { + if (c == NULL || ch == NULL) { + return 1; + } + if (c->state != CLIENT_STATE_READY) { + return 1; + } + if (ch->client != c) { + if (ssh_debug_enabled) { + ssh_debug_log("client channel ownership mismatch: client=%p owner=%p ch=%p", + (void *)c, (void *)ch->client, (void *)ch); + } + return 1; + } + if (ch->state == CHAN_STATE_ERROR || ch->state == CHAN_STATE_CLOSED) { + return 2; + } + return 0; +} + +static int server_channel_validate(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) { + if (s == NULL || ch == NULL) { + return -1; + } + if (s->state != SESSION_STATE_READY) { + return -1; + } + if (ch->session != s) { + if (ssh_debug_enabled) { + ssh_debug_log("server channel ownership mismatch: session=%p owner=%p ch=%p", + (void *)s, (void *)ch->session, (void *)ch); + } + return -1; + } + if (ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) { + return -1; + } + return 0; +} + +$R sshQ_ClientD_channel_request_execG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str cmd) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + int valid = channel_validate(c, ch); + if (valid != 0) { + if (valid == 2 && ch != NULL) + channel_notify_error(ch, "Channel not ready"); + return $R_CONT(c$cont, B_None); + } + + if (ch->pending_req != CHAN_REQ_NONE) { + channel_notify_error(ch, "Channel already has a pending request"); + return $R_CONT(c$cont, B_None); + } + if (ch->state == CHAN_STATE_RUNNING || ch->state == CHAN_STATE_CLOSING) { + channel_notify_error(ch, "Channel already running"); + return $R_CONT(c$cont, B_None); + } + + ch->exec_cmd = cmd; + ch->pending_req = CHAN_REQ_EXEC; + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_request_shellG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str term, B_int cols, B_int rows, B_int width_px, B_int height_px, B_bool with_pty) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + int valid = channel_validate(c, ch); + if (valid != 0) { + if (valid == 2 && ch != NULL) + channel_notify_error(ch, "Channel not ready"); + return $R_CONT(c$cont, B_None); + } + + if (ch->pending_req != CHAN_REQ_NONE) { + channel_notify_error(ch, "Channel already has a pending request"); + return $R_CONT(c$cont, B_None); + } + if (ch->state == CHAN_STATE_RUNNING || ch->state == CHAN_STATE_CLOSING) { + channel_notify_error(ch, "Channel already running"); + return $R_CONT(c$cont, B_None); + } + + ch->term = term; + ch->cols = fromB_int(cols); + ch->rows = fromB_int(rows); + ch->width_px = fromB_int(width_px); + ch->height_px = fromB_int(height_px); + ch->pty_pending = fromB_bool(with_pty) ? 1 : 0; + + ch->pending_req = CHAN_REQ_SHELL; + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_request_subsystemG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_str name) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + int valid = channel_validate(c, ch); + if (valid != 0) { + if (valid == 2 && ch != NULL) + channel_notify_error(ch, "Channel not ready"); + return $R_CONT(c$cont, B_None); + } + + if (ch->pending_req != CHAN_REQ_NONE) { + channel_notify_error(ch, "Channel already has a pending request"); + return $R_CONT(c$cont, B_None); + } + if (ch->state == CHAN_STATE_RUNNING || ch->state == CHAN_STATE_CLOSING) { + channel_notify_error(ch, "Channel already running"); + return $R_CONT(c$cont, B_None); + } + + ch->subsystem = name; + ch->pending_req = CHAN_REQ_SUBSYSTEM; + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_writeG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel, B_bytes data) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + int valid = channel_validate(c, ch); + if (valid != 0) { + if (valid == 2 && ch != NULL) + channel_notify_error(ch, "Channel not ready"); + return $R_CONT(c$cont, B_None); + } + + channel_queue_write(ch, data); + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_send_eofG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + if (channel_validate(c, ch) != 0) + return $R_CONT(c$cont, B_None); + + ch->send_eof = 1; + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ClientD_channel_closeG_local(sshQ_Client self, $Cont c$cont, sshQ_Channel channel) { + ssh_client_ctx *c = client_from_actor(self); + ssh_channel_ctx *ch = channel_from_actor(channel); + if (channel_validate(c, ch) != 0) + return $R_CONT(c$cont, B_None); + + ch->send_eof = 1; + ch->close_requested = 1; + client_drive(c); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ChannelD__cleanup_nativeG_local(sshQ_Channel self, $Cont c$cont) { + ssh_channel_ctx *ch = channel_from_actor(self); + if (ch == NULL || ch->state == CHAN_STATE_CLOSED || ch->state == CHAN_STATE_ERROR) + return $R_CONT(c$cont, B_None); + ssh_client_ctx *c = ch->client; + if (c == NULL) + return $R_CONT(c$cont, B_None); + ch->send_eof = 1; + ch->close_requested = 1; + client_drive(c); + return $R_CONT(c$cont, B_None); +} + +// --- Server implementation + +static void server_notify_listen(ssh_server_ctx *s, const char *err) { + if (s == NULL) + return; + if (s->listen_notified) + return; + sshQ_Server actor = server_actor_ref(s); + if (s->on_listen) { + $action2 f = ($action2)s->on_listen; + f->$class->__asyn__(f, actor, err ? to$str((char *)err) : B_None); + } + s->listen_notified = 1; + if (err == NULL) + s->listen_ok = 1; +} + +static void server_notify_close(ssh_server_ctx *s, const char *reason) { + if (s->close_notified) + return; + if (!s->listen_ok) + return; + sshQ_Server actor = server_actor_ref(s); + if (s->on_close) { + $action2 f = ($action2)s->on_close; + f->$class->__asyn__(f, actor, to$str((char *)reason)); + } + s->close_notified = 1; +} + +static void session_notify_close(ssh_server_session_ctx *s, const char *reason) { + if (s->close_notified) + return; + sshQ_ServerSession actor = session_actor_ref(s); + if (s->on_close) { + $action2 f = ($action2)s->on_close; + f->$class->__asyn__(f, actor, to$str((char *)reason)); + } + s->close_notified = 1; +} + +static void server_channel_notify_close(ssh_server_channel_ctx *ch, const char *reason) { + if (ch->close_notified) + return; + sshQ_ServerChannel actor = server_channel_actor_ref(ch); + if (ch->on_close) { + $action2 f = ($action2)ch->on_close; + f->$class->__asyn__(f, actor, to$str((char *)reason)); + } + ch->close_notified = 1; +} + +static int server_channel_data_cb(ssh_session session, ssh_channel channel, void *data, + uint32_t len, int is_stderr, void *userdata) { + ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL || ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) + return 0; + if (len == 0) + return 0; + B_bytes out = to$bytesD_len((char *)data, (size_t)len); + sshQ_ServerChannel actor = server_channel_actor_ref(ch); + if (is_stderr) { + if (ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, actor, out); + } + } else { + if (ch->on_data) { + $action2 f = ($action2)ch->on_data; + f->$class->__asyn__(f, actor, out); + } + } + return (int)len; +} + +static void server_channel_eof_cb(ssh_session session, ssh_channel channel, void *userdata) { + ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL) + return; + sshQ_ServerChannel actor = server_channel_actor_ref(ch); + if (!ch->stdout_eof && ch->on_data) { + $action2 f = ($action2)ch->on_data; + f->$class->__asyn__(f, actor, B_None); + ch->stdout_eof = 1; + } + if (!ch->stderr_eof && ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, actor, B_None); + ch->stderr_eof = 1; + } +} + +static void server_channel_close_cb(ssh_session session, ssh_channel channel, void *userdata) { + ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL) + return; + ch->remote_close_seen = 1; +} + +static int server_channel_write_wontblock_cb(ssh_session session, ssh_channel channel, + uint32_t bytes, void *userdata) { + ssh_server_channel_ctx *ch = (ssh_server_channel_ctx *)userdata; + (void)session; + (void)channel; + if (ch == NULL || ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) + return 0; + ch->write_wontblock = bytes > 0 ? 1 : 0; + if (ssh_debug_enabled) { + ssh_debug_log("server channel write_wontblock: bytes=%u ch=%p", bytes, (void *)ch); + } + return 0; +} + +static int server_channel_setup_callbacks(ssh_server_channel_ctx *ch) { + if (ch == NULL || ch->channel == NULL) + return SSH_ERROR; + if (ch->callbacks != NULL) + return SSH_OK; + struct ssh_channel_callbacks_struct *cb = acton_calloc(1, sizeof(*cb)); + ssh_callbacks_init(cb); + cb->userdata = ch; + cb->channel_data_function = server_channel_data_cb; + cb->channel_eof_function = server_channel_eof_cb; + cb->channel_close_function = server_channel_close_cb; + cb->channel_write_wontblock_function = server_channel_write_wontblock_cb; + if (ssh_add_channel_callbacks(ch->channel, cb) != SSH_OK) { + acton_free(cb); + return SSH_ERROR; + } + ch->callbacks = cb; + if (ssh_debug_enabled) { + ssh_debug_log("server channel callbacks set ch=%p", (void *)ch); + } + return SSH_OK; +} + +static void server_channel_finalize(ssh_server_channel_ctx *ch) { + sshQ_ServerChannel actor = server_channel_actor_ref(ch); + while (ch->write_head != NULL) { + server_write_chunk_t *chunk = ch->write_head; + ch->write_head = chunk->next; + acton_free(chunk); + } + ch->write_tail = NULL; + if (ch->pending_req) { + ssh_message_reply_default(ch->pending_req); + ssh_message_free(ch->pending_req); + ch->pending_req = NULL; + ch->pending_req_type = SCHAN_REQ_NONE; + } + if (ch->channel != NULL) { + if (ch->callbacks) { + ssh_remove_channel_callbacks(ch->channel, ch->callbacks); + acton_free(ch->callbacks); + ch->callbacks = NULL; + } + ssh_channel_free(ch->channel); + ch->channel = NULL; + } + ch->state = SCHAN_STATE_CLOSED; + if (actor) + actor->_channel_id = toB_u64(0); + if (!ch->stdout_eof && ch->on_data) { + $action2 f = ($action2)ch->on_data; + f->$class->__asyn__(f, actor, B_None); + ch->stdout_eof = 1; + } + if (!ch->stderr_eof && ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, actor, B_None); + ch->stderr_eof = 1; + } + server_channel_notify_close(ch, "closed"); + ch->actor = NULL; +} + +static void server_channel_queue_write(ssh_server_channel_ctx *ch, B_bytes data, int is_stderr) { + server_write_chunk_t *chunk = acton_calloc(1, sizeof(server_write_chunk_t)); + chunk->data = data; + chunk->offset = 0; + chunk->is_stderr = is_stderr; + chunk->next = NULL; + if (ch->write_tail) { + ch->write_tail->next = chunk; + } else { + ch->write_head = chunk; + } + ch->write_tail = chunk; + if (ssh_debug_enabled) { + ssh_debug_log("server channel queue write: %zu bytes", data ? data->nbytes : 0); + } +} + +static void server_channel_try_write(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) { + while (ch->write_head != NULL && ch->write_head->data->nbytes == ch->write_head->offset) { + server_write_chunk_t *chunk = ch->write_head; + ch->write_head = chunk->next; + acton_free(chunk); + if (ch->write_head == NULL) + ch->write_tail = NULL; + } + + if (ch->write_head == NULL || !ch->write_wontblock || session_has_pending_write(s->session)) + return; + + server_write_chunk_t *chunk = ch->write_head; + size_t remaining = chunk->data->nbytes - chunk->offset; + ch->write_wontblock = 0; + int rc; + if (chunk->is_stderr) { + rc = ssh_channel_write_stderr(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); + } else { + rc = ssh_channel_write(ch->channel, chunk->data->str + chunk->offset, (uint32_t)remaining); + } + if (rc > 0) { + chunk->offset += (size_t)rc; + if (ssh_debug_enabled) { + ssh_debug_log("server channel wrote %d bytes", rc); + } + if (chunk->offset >= chunk->data->nbytes) { + ch->write_head = chunk->next; + acton_free(chunk); + if (ch->write_head == NULL) + ch->write_tail = NULL; + } + } else if (rc == 0 || rc == SSH_AGAIN) { + s->write_ready = 0; + if (ssh_debug_enabled) { + unsigned int window = ssh_channel_window_size(ch->channel); + ssh_debug_log("server channel write pending rc=%d remaining=%zu window=%u", + rc, remaining, window); + } + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server channel write error: %s", ssh_get_error(s->session)); + server_channel_notify_close(ch, errmsg); + ch->state = SCHAN_STATE_ERROR; + return; + } +} + +static int server_channel_read_stream(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch, int is_stderr) { + char buf[SSH_READ_BUFSIZE]; + int read_any = 0; + for (;;) { + int n = ssh_channel_read_buffered(ch->channel, buf, sizeof(buf), is_stderr); + if (ssh_debug_enabled) { + ssh_debug_log("server channel read: rc=%d stderr=%d", n, is_stderr); + } + if (n > 0) { + read_any = 1; + B_bytes out = to$bytesD_len(buf, n); + sshQ_ServerChannel actor = server_channel_actor_ref(ch); + if (is_stderr) { + if (ch->on_stderr) { + $action2 f = ($action2)ch->on_stderr; + f->$class->__asyn__(f, actor, out); + } + } else { + if (ch->on_data) { + $action2 f = ($action2)ch->on_data; + f->$class->__asyn__(f, actor, out); + } + } + continue; + } + if (n == 0 || n == SSH_AGAIN) { + break; + } + if (n == SSH_EOF) { + break; + } + if (n == SSH_ERROR) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server channel read error: %s", ssh_get_error(s->session)); + server_channel_notify_close(ch, errmsg); + ch->state = SCHAN_STATE_ERROR; + break; + } + } + return read_any; +} + +static void server_channel_drive(ssh_server_session_ctx *s, ssh_server_channel_ctx *ch) { + if (ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) + return; + + if (ch->write_head != NULL) { + server_channel_try_write(s, ch); + if (ch->state == SCHAN_STATE_ERROR) + return; + } + + if (ch->send_eof && !ch->eof_sent && ch->write_head == NULL && + !session_has_pending_write(s->session)) { + int rc = ssh_channel_send_eof(ch->channel); + if (rc == SSH_OK) { + ch->eof_sent = 1; + } else if (rc == SSH_AGAIN) { + s->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server send EOF failed: %s", ssh_get_error(s->session)); + server_channel_notify_close(ch, errmsg); + ch->state = SCHAN_STATE_ERROR; + return; + } + } + + if (ch->close_requested && !ch->close_sent && ch->write_head == NULL && + (!ch->send_eof || ch->eof_sent) && + !session_has_pending_write(s->session)) { + int rc = ssh_channel_close(ch->channel); + if (rc == SSH_OK) { + ch->close_sent = 1; + ch->state = SCHAN_STATE_CLOSING; + } else if (rc == SSH_AGAIN) { + s->write_ready = 0; + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server channel close failed: %s", ssh_get_error(s->session)); + server_channel_notify_close(ch, errmsg); + ch->state = SCHAN_STATE_ERROR; + return; + } + } + + if (ch->callbacks == NULL) { + for (int i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int did = 0; + did |= server_channel_read_stream(s, ch, 0); + did |= server_channel_read_stream(s, ch, 1); + if (!did) + break; + } + } + + if (ch->channel != NULL && ch->remote_close_seen && + ssh_channel_is_closed(ch->channel)) { + server_channel_finalize(ch); + } +} + +static void session_drive_channels(ssh_server_session_ctx *s) { + ssh_server_channel_ctx *prev = NULL; + ssh_server_channel_ctx *ch = s->channels; + while (ch != NULL) { + ssh_server_channel_ctx *next = ch->next; + server_channel_drive(s, ch); + if (ch->state == SCHAN_STATE_ERROR) { + server_channel_finalize(ch); + } + if (ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) { + if (prev != NULL) { + prev->next = next; + } else { + s->channels = next; + } + session_retire_channel(s, ch); + } else { + prev = ch; + } + ch = next; + } +} + +static void session_fail(ssh_server_session_ctx *s, const char *msg) { + if (s == NULL || s->state == SESSION_STATE_CLOSED || s->state == SESSION_STATE_ERROR) + return; + s->state = SESSION_STATE_ERROR; + session_close_internal(s, msg, 1); +} + +static int session_check_reply_rc(ssh_server_session_ctx *s, int rc, const char *context) { + if (rc == SSH_OK || rc == SSH_AGAIN) + return 0; + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "%s: %s", + context, s != NULL && s->session != NULL ? ssh_get_error(s->session) : "unknown error"); + session_fail(s, errmsg); + return -1; +} + +static void session_restart_auth_timer(ssh_server_session_ctx *s, double timeout_sec) { + if (s == NULL) + return; + if (timeout_sec <= 0.0) + return; + if (s->auth_timer == NULL) { + s->auth_timer = acton_calloc(1, sizeof(uv_timer_t)); + s->auth_timer->data = s; + uv_timer_init(get_uv_loop(), s->auth_timer); + } else { + uv_timer_stop(s->auth_timer); + } + uv_timer_start(s->auth_timer, session_auth_timeout_cb, + (uint64_t)(timeout_sec * 1000.0), 0); +} + +static void session_start_attach_timer(ssh_server_session_ctx *s) { + if (s == NULL || s->attach_timer != NULL) + return; + s->attach_timer = acton_calloc(1, sizeof(uv_timer_t)); + s->attach_timer->data = s; + uv_timer_init(get_uv_loop(), s->attach_timer); + uv_timer_start(s->attach_timer, session_auth_timeout_cb, + (uint64_t)(SSH_ATTACH_TIMEOUT_SEC * 1000.0), 0); +} + +static void session_start_keyex_timer(ssh_server_session_ctx *s) { + if (s == NULL) + return; + double timeout = s->auth_timeout > SSH_KEYEX_TIMEOUT_SEC ? + s->auth_timeout : SSH_KEYEX_TIMEOUT_SEC; + session_restart_auth_timer(s, timeout); +} + +static void session_start_auth_timer(ssh_server_session_ctx *s) { + session_restart_auth_timer(s, s != NULL ? s->auth_timeout : 0.0); +} + +static void session_start_keepalive(ssh_server_session_ctx *s) { + if (s == NULL || !s->keepalive_enabled || s->keepalive_interval <= 0.0 || s->keepalive_timer != NULL) + return; + s->keepalive_timer = acton_calloc(1, sizeof(uv_timer_t)); + s->keepalive_timer->data = s; + uv_timer_init(get_uv_loop(), s->keepalive_timer); + uv_timer_start(s->keepalive_timer, session_keepalive_cb, (uint64_t)(s->keepalive_interval * 1000.0), (uint64_t)(s->keepalive_interval * 1000.0)); +} + +static void server_fail(ssh_server_ctx *s, const char *msg) { + if (s == NULL || s->state == SERVER_STATE_CLOSED || s->state == SERVER_STATE_ERROR) + return; + s->state = SERVER_STATE_ERROR; + if (!s->listen_ok) + server_notify_listen(s, msg); + server_close_internal(s, msg); +} + +static void session_auth_timeout_cb(uv_timer_t *timer) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)timer->data; + if (s == NULL) + return; + if (!s->attached && s->state == SESSION_STATE_KEYEX) { + session_fail(s, "SSH session attach timeout"); + return; + } + if (s->attached && s->state == SESSION_STATE_KEYEX) { + session_fail(s, "SSH key exchange timeout"); + return; + } + if (s->state == SESSION_STATE_AUTH) { + session_fail(s, "SSH authentication timeout"); + } } -void sshQ___ext_init__() { - // TODO: can we avoid custom malloc in libssh? like let libssh use stock - // malloc and instead we would explicitly call free() from a finalizer() - // All things related to buffers for receiving data and similarly would have - // to be allocated on the GC-heap though since that data is passed outside - // of the SSH actor - libssh_replace_allocator( - acton_gc_malloc, - acton_gc_realloc, - acton_gc_calloc, - noop_free, - acton_gc_strdup, - acton_gc_strndup); - int r = ssh_init(); - printf("SSH extension initialized %d\n", r); -} - -B_str sshQ_version () { - if (LIBSSH_VERSION_MINOR != 11) - return to$str("invalid"); - return to$str("0.11.0"); -} - -// TODO: crap function for test, to be replaced with something -int show_remote_processes(ssh_session session) -{ - ssh_channel channel; - int rc; - char buffer[256]; - int nbytes; - - channel = ssh_channel_new(session); - if (channel == NULL) - return SSH_ERROR; - - rc = ssh_channel_open_session(channel); - if (rc != SSH_OK) - { - ssh_channel_free(channel); - return rc; - } - - rc = ssh_channel_request_exec(channel, "ps aux"); - if (rc != SSH_OK) - { - ssh_channel_close(channel); - ssh_channel_free(channel); - return rc; - } - - nbytes = ssh_channel_read(channel, buffer, sizeof(buffer), 0); - while (nbytes > 0) - { - if (write(1, buffer, nbytes) != (unsigned int) nbytes) - { - ssh_channel_close(channel); - ssh_channel_free(channel); - return SSH_ERROR; - } - nbytes = ssh_channel_read(channel, buffer, sizeof(buffer), 0); - } - - if (nbytes < 0) - { - ssh_channel_close(channel); - ssh_channel_free(channel); - return SSH_ERROR; - } - - ssh_channel_send_eof(channel); - ssh_channel_close(channel); - ssh_channel_free(channel); - - return SSH_OK; -} - -$R sshQ_ClientD__initG_local (sshQ_Client self, $Cont c$cont) { - ssh_session session = ssh_new(); - if (session == NULL) { - //log_error("Failed to create SSH session"); - return $R_CONT(c$cont, B_None); - } - printf("session: %p\n", session); - self->_ssh_session = toB_u64((unsigned long)session); - printf("init self->session: %p\n", self->_ssh_session); - - ssh_options_set(session, SSH_OPTIONS_HOST, fromB_str(self->host)); - ssh_options_set(session, SSH_OPTIONS_PORT, &self->port->val); - ssh_options_set(session, SSH_OPTIONS_USER, fromB_str(self->username)); - - ssh_set_blocking(session, 1); - printf("Connecting to \n"); - int rc = ssh_connect(session); +static int session_start_poll(ssh_server_session_ctx *s, char *errmsg, size_t errmsg_len) { + if (s == NULL || s->session == NULL || s->fd < 0) { + snprintf(errmsg, errmsg_len, "Failed to start SSH session poll"); + return -1; + } + if (s->poll != NULL) + return 0; + + s->poll = acton_calloc(1, sizeof(uv_poll_t)); + s->poll->data = s; + int uv_rc = uv_poll_init(get_uv_loop(), s->poll, s->fd); + if (uv_rc != 0) { + uv_strerror_r(uv_rc, errmsg, errmsg_len); + acton_free(s->poll); + s->poll = NULL; + return -1; + } + s->poll_events = UV_READABLE | UV_WRITABLE; + uv_rc = uv_poll_start(s->poll, s->poll_events, session_poll_cb); + if (uv_rc != 0) { + uv_strerror_r(uv_rc, errmsg, errmsg_len); + return -1; + } + return 0; +} + +static void session_keepalive_cb(uv_timer_t *timer) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)timer->data; + if (s == NULL || s->session == NULL) + return; + if (s->state != SESSION_STATE_READY) + return; + int rc = ssh_send_ignore(s->session, "keepalive"); + if (rc == SSH_AGAIN) { + session_update_poll(s); + return; + } + if (rc != SSH_OK) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server keepalive failed: %s", ssh_get_error(s->session)); + session_fail(s, errmsg); + return; + } + session_update_poll(s); +} + +static void session_update_poll(ssh_server_session_ctx *s) { + if (s->poll == NULL || s->session == NULL) + return; + if (s->state == SESSION_STATE_CLOSED) + return; + if (uv_is_closing((uv_handle_t *)s->poll)) + return; + int status = ssh_get_status(s->session); + if (status & SSH_CLOSED_ERROR) { + session_fail(s, "SSH session closed with error"); + return; + } + if (status & SSH_CLOSED) { + session_close_internal(s, "SSH session closed", 1); + return; + } + int flags = ssh_get_poll_flags(s->session); + int pending = flags | status; + int events = UV_READABLE; +#ifdef UV_DISCONNECT + events |= UV_DISCONNECT; +#endif + if (pending & SSH_WRITE_PENDING) + events |= UV_WRITABLE; + if ((events & UV_WRITABLE) == 0 && session_needs_write(s)) + events |= UV_WRITABLE; + if (ssh_debug_enabled && (events != s->poll_events || (pending & SSH_WRITE_PENDING))) { + ssh_debug_log("server update poll: status=0x%x flags=0x%x pending=0x%x events=0x%x state=%d", + status, flags, pending, events, s->state); + } + if (events != s->poll_events) { + int uv_rc = uv_poll_start(s->poll, events, session_poll_cb); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + session_fail(s, errmsg); + return; + } + s->poll_events = events; + } +} + +static void session_pump_io(ssh_server_session_ctx *s) { + if (s == NULL || s->session == NULL) + return; + int i; + for (i = 0; i < SSH_IO_PUMP_LIMIT; i++) { + int did = 0; + if (s->session == NULL) + return; + int has_data = fd_has_data(s->fd); + if (has_data) { + if (ssh_debug_enabled) { + ssh_debug_log("server pump: readable"); + } + ssh_set_fd_toread(s->session); + if (session_apply_poll_events(s->session, UV_READABLE) != 0) { + char errmsg[256] = {0}; + format_session_error(s->session, "SSH poll callback error", errmsg, sizeof(errmsg)); + session_fail(s, errmsg); + return; + } + session_drive(s); + did = 1; + } + if (s->session == NULL) + return; + if (!did) { + int status = ssh_get_status(s->session); + if (status & SSH_READ_PENDING) { + if (ssh_debug_enabled) { + ssh_debug_log("server pump: buffered read pending=0x%x", status); + } + session_drive(s); + /* Avoid spinning when only buffered data remains. */ + break; + } + } + if (!did) + break; + } + if (ssh_debug_enabled && i >= SSH_IO_PUMP_LIMIT) { + int status = ssh_get_status(s->session); + int flags = ssh_get_poll_flags(s->session); + ssh_debug_log("server pump: hit limit status=0x%x flags=0x%x", status, flags); + } +} + +static int session_needs_write(ssh_server_session_ctx *s) { + if (s == NULL) + return 0; + ssh_server_channel_ctx *ch = s->channels; + while (ch != NULL) { + if (ch->pending_req != NULL || ch->write_head != NULL) + return 1; + if (ch->send_eof && !ch->eof_sent) + return 1; + if (ch->close_requested && !ch->close_sent) + return 1; + ch = ch->next; + } + return 0; +} + +static void session_drive(ssh_server_session_ctx *s) { + if (s == NULL) + return; + if (s->state == SESSION_STATE_ERROR || s->state == SESSION_STATE_CLOSED) + return; + if (!s->attached) + return; + if (s->state == SESSION_STATE_CLOSING) { + session_drive_channels(s); + session_finish_close(s); + return; + } + + int spin = 0; + while (1) { + if (s->state == SESSION_STATE_KEYEX) { + int rc = ssh_handle_key_exchange(s->session); + if (rc == SSH_OK) { + s->state = SESSION_STATE_AUTH; + stop_timer(&s->attach_timer, session_timer_close_cb); + session_start_auth_timer(s); + ssh_set_auth_methods(s->session, SSH_AUTH_METHOD_PASSWORD); + continue; + } else if (rc == SSH_AGAIN) { + int status = ssh_get_status(s->session); + if (status & SSH_WRITE_PENDING) + s->write_ready = 0; + if ((status & SSH_READ_PENDING) && spin++ < SSH_IO_PUMP_LIMIT) { + if (ssh_debug_enabled) { + ssh_debug_log("server drive: buffered read pending during key exchange"); + } + continue; + } + session_update_poll(s); + return; + } else { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH key exchange failed: %s", ssh_get_error(s->session)); + session_fail(s, errmsg); + return; + } + } + + if (s->state == SESSION_STATE_AUTH) { + if (s->pending_auth != NULL) { + session_update_poll(s); + return; + } + ssh_message msg = ssh_message_get(s->session); + if (msg == NULL) { + session_update_poll(s); + return; + } + int type = ssh_message_type(msg); + if (type == SSH_REQUEST_SERVICE) { + int rc; + const char *service = ssh_message_service_service(msg); + if (service && strcmp(service, "ssh-userauth") == 0) { + rc = ssh_message_service_reply_success(msg); + } else { + rc = ssh_message_reply_default(msg); + } + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH service reply failed") != 0) + return; + continue; + } + if (type == SSH_REQUEST_AUTH && ssh_message_subtype(msg) == SSH_AUTH_METHOD_PASSWORD) { + if (s->on_auth == NULL) { + int rc; + ssh_message_auth_set_methods(msg, SSH_AUTH_METHOD_PASSWORD); + rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH auth reject failed") != 0) + return; + continue; + } + const char *user = ssh_message_auth_user(msg); +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif + const char *pass = ssh_message_auth_password(msg); +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + s->pending_auth = msg; + sshQ_AuthRequest req = sshQ_AuthRequestG_new( + to$str((char *)"password"), + to$str((char *)(user ? user : "")), + pass ? to$str((char *)(pass)) : B_None, + B_None); + $action2 f = ($action2)s->on_auth; + f->$class->__asyn__(f, session_actor_ref(s), req); + session_update_poll(s); + return; + } + int rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH auth reply failed") != 0) + return; + continue; + } + + if (s->state == SESSION_STATE_READY) { + if (s->pending_channel_open != NULL) { + session_drive_channels(s); + session_update_poll(s); + return; + } + while (1) { + ssh_message msg = ssh_message_get(s->session); + if (msg == NULL) + break; + int type = ssh_message_type(msg); + if (type == SSH_REQUEST_CHANNEL_OPEN) { + if (s->pending_channel_open != NULL) { + int rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0) + return; + } else if (ssh_message_subtype(msg) != SSH_CHANNEL_SESSION) { + int rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0) + return; + } else if (session_channel_limit_reached(s)) { + if (ssh_debug_enabled) { + ssh_debug_log("server session: rejecting channel open limit=%d", + s->server ? s->server->max_channels_per_session : 0); + } + int rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0) + return; + } else if (s->on_channel_open == NULL) { + int rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0) + return; + } else { + s->pending_channel_open = msg; + $action f = ($action)s->on_channel_open; + f->$class->__asyn__(f, session_actor_ref(s)); + break; + } + } else if (type == SSH_REQUEST_CHANNEL) { + ssh_channel chan = ssh_message_channel_request_channel(msg); + ssh_server_channel_ctx *ch = server_channel_from_ssh(s, chan); + if (ch == NULL || ch->pending_req != NULL) { + int rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel request reject failed") != 0) + return; + } else if (ssh_message_subtype(msg) == SSH_CHANNEL_REQUEST_EXEC) { + if (s->on_exec == NULL) { + int rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH exec reject failed") != 0) + return; + } else { + const char *cmd = ssh_message_channel_request_command(msg); + ch->pending_req = msg; + ch->pending_req_type = SCHAN_REQ_EXEC; + $action3 f = ($action3)s->on_exec; + f->$class->__asyn__(f, session_actor_ref(s), server_channel_actor_ref(ch), + to$str((char *)(cmd ? cmd : ""))); + break; + } + } else if (ssh_message_subtype(msg) == SSH_CHANNEL_REQUEST_SUBSYSTEM) { + if (s->on_subsystem == NULL) { + int rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH subsystem reject failed") != 0) + return; + } else { + const char *name = ssh_message_channel_request_subsystem(msg); + ch->pending_req = msg; + ch->pending_req_type = SCHAN_REQ_SUBSYSTEM; + $action3 f = ($action3)s->on_subsystem; + f->$class->__asyn__(f, session_actor_ref(s), server_channel_actor_ref(ch), + to$str((char *)(name ? name : ""))); + break; + } + } else { + int rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH channel request reject failed") != 0) + return; + } + } else if (type == SSH_REQUEST_SERVICE) { + int rc; + const char *service = ssh_message_service_service(msg); + if (service && strcmp(service, "ssh-connection") == 0) { + rc = ssh_message_service_reply_success(msg); + } else { + rc = ssh_message_reply_default(msg); + } + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH connection service reply failed") != 0) + return; + } else { + int rc = ssh_message_reply_default(msg); + ssh_message_free(msg); + if (session_check_reply_rc(s, rc, "SSH request reply failed") != 0) + return; + } + } + session_drive_channels(s); + session_update_poll(s); + return; + } + return; + } +} + +static void session_poll_cb(uv_poll_t *handle, int status, int events) { + ssh_server_session_ctx *s = (ssh_server_session_ctx *)handle->data; + if (s == NULL) + return; + if (ssh_debug_enabled) { + ssh_debug_log("server poll: status=%d events=0x%x state=%d", status, events, s->state); + } + if (status < 0) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH session poll error: %s", uv_strerror(status)); + session_fail(s, errmsg); + return; + } + int libssh_events = 0; + if ((events & UV_READABLE) && fd_has_data(s->fd)) { + ssh_set_fd_toread(s->session); + libssh_events |= UV_READABLE; + } +#ifdef UV_DISCONNECT + if (events & UV_DISCONNECT) { + ssh_set_fd_toread(s->session); + libssh_events |= UV_DISCONNECT; + } +#endif + if ((events & UV_WRITABLE) && fd_can_write(s->fd)) { + s->write_ready = 1; + ssh_set_fd_towrite(s->session); + libssh_events |= UV_WRITABLE; + } + if (session_apply_poll_events(s->session, libssh_events) != 0) { + char errmsg[256] = {0}; + format_session_error(s->session, "SSH poll callback error", errmsg, sizeof(errmsg)); + session_fail(s, errmsg); + return; + } + session_drive(s); + session_pump_io(s); + s->write_ready = 0; +} + +static void server_poll_cb(uv_poll_t *handle, int status, int events) { + ssh_server_ctx *s = (ssh_server_ctx *)handle->data; + if (s == NULL) + return; + if (ssh_debug_enabled) { + ssh_debug_log("server listen poll: status=%d events=0x%x", status, events); + } + if (status < 0) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server poll error: %s", uv_strerror(status)); + server_fail(s, errmsg); + return; + } + if (events & UV_READABLE) + server_accept(s); +} + +static void server_accept(ssh_server_ctx *s) { + if (s == NULL || s->state != SERVER_STATE_LISTENING) + return; + + int accepted = 0; + while (accepted < SSH_SERVER_ACCEPT_LIMIT) { + socket_t fd = accept(s->fd, NULL, NULL); + if (fd == SSH_INVALID_SOCKET) { + if (errno == EINTR) { + continue; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH accept failed: %s", strerror(errno)); + server_fail(s, errmsg); + return; + } + accepted++; + + if (server_session_limit_reached(s)) { + if (ssh_debug_enabled) { + ssh_debug_log("server accept: rejecting fd=%d session limit=%d", + (int)fd, s->max_sessions); + } + close(fd); + continue; + } + + if (fd_set_nonblocking(fd) != 0) { + log_warn("SSH accept: failed to set accepted fd nonblocking"); + close(fd); + continue; + } + + ssh_session session = ssh_new(); + if (session == NULL) { + log_warn("SSH accept: failed to create SSH session"); + close(fd); + continue; + } + + if (ssh_debug_enabled) { + ssh_debug_log("server accept: actor=%p", (void *)server_actor_ref(s)); + } + + int rc = ssh_bind_accept_fd(s->bind, session, fd); + if (rc != SSH_OK) { + if (ssh_debug_enabled) { + ssh_debug_log("server accept: accept_fd failed: %s", ssh_get_error(s->bind)); + } + close(fd); + ssh_free(session); + continue; + } + + ssh_set_blocking(session, 0); + ssh_server_session_ctx *sess = acton_calloc(1, sizeof(ssh_server_session_ctx)); + sess->server = s; + sess->session = session; + sess->state = SESSION_STATE_KEYEX; + sess->pending_id = alloc_pending_session_id(); + sess->fd = ssh_get_fd(session); + sshQ_Server act = server_actor_ref(s); + sess->owner_wt = act ? (int)act->$affinity : 0; + sess->auth_timeout = act ? fromB_float(act->_auth_timeout) : 0.0; + if (sess->fd < 0) { + log_warn("SSH accept: failed to get accepted session fd"); + ssh_disconnect(session); + ssh_free(session); + continue; + } + session_start_attach_timer(sess); + + sess->next = s->sessions; + s->sessions = sess; + + if (ssh_debug_enabled) { + ssh_debug_log("server accept: session fd=%d", sess->fd); + } + + if (act) { + if (ssh_debug_enabled) { + ssh_debug_log("server accept: scheduling session pending act=%p session=%llu", + (void *)act, (unsigned long long)sess->pending_id); + } + act->$class->on_session_pending(act, toB_u64(sess->pending_id)); + if (ssh_debug_enabled) { + ssh_debug_log("server accept: on_session_pending call returned session=%llu", + (unsigned long long)sess->pending_id); + } + } + } +} + +static void server_remove_session(ssh_server_ctx *s, ssh_server_session_ctx *sess) { + if (s == NULL || sess == NULL) + return; + ssh_server_session_ctx *prev = NULL; + ssh_server_session_ctx *cur = s->sessions; + while (cur != NULL) { + if (cur == sess) { + if (prev != NULL) + prev->next = cur->next; + else + s->sessions = cur->next; + break; + } + prev = cur; + cur = cur->next; + } + server_maybe_release(s); +} + +static void server_finalize(ssh_server_ctx *s) { + if (s == NULL || s->close_finalized) + return; + s->close_finalized = 1; + + if (s->bind != NULL) { + ssh_bind_free(s->bind); + s->bind = NULL; + } + if (s->hostkey != NULL) { + ssh_key_free(s->hostkey); + s->hostkey = NULL; + } + + server_notify_close(s, s->close_reason ? s->close_reason : "closed"); + s->state = SERVER_STATE_CLOSED; + sshQ_Server actor = server_actor_ref(s); + if (actor) + actor->_server = toB_u64(0); + STORE_HIDDEN_PTR(s->actor, NULL); + server_maybe_release(s); +} + +static void session_finalize(ssh_server_session_ctx *s) { + if (s == NULL || s->close_finalized) + return; + s->close_finalized = 1; + + if (s->session != NULL) { + ssh_disconnect(s->session); + ssh_free(s->session); + s->session = NULL; + } + session_free_retired_channels(s); + + server_remove_session(s->server, s); + session_notify_close(s, s->close_reason ? s->close_reason : "closed"); + s->state = SESSION_STATE_CLOSED; + s->pending_id = 0; + sshQ_ServerSession actor = session_actor_ref(s); + if (actor) + actor->_session_id = toB_u64(0); + s->actor = NULL; + session_maybe_release(s); +} + +static void session_reject_pending_messages(ssh_server_session_ctx *s) { + if (s->pending_auth) { + ssh_message_reply_default(s->pending_auth); + ssh_message_free(s->pending_auth); + s->pending_auth = NULL; + } + if (s->pending_channel_open) { + ssh_message_reply_default(s->pending_channel_open); + ssh_message_free(s->pending_channel_open); + s->pending_channel_open = NULL; + } +} + +static void session_abort_channels(ssh_server_session_ctx *s) { + ssh_server_channel_ctx *ch = s->channels; + while (ch != NULL) { + ssh_server_channel_ctx *next = ch->next; + server_channel_notify_close(ch, "Session closed"); + server_channel_finalize(ch); + session_retire_channel(s, ch); + ch = next; + } + s->channels = NULL; +} + +static void session_request_channel_close(ssh_server_session_ctx *s) { + ssh_server_channel_ctx *ch = s->channels; + while (ch != NULL) { + if (ch->pending_req) { + ssh_message_reply_default(ch->pending_req); + ssh_message_free(ch->pending_req); + ch->pending_req = NULL; + ch->pending_req_type = SCHAN_REQ_NONE; + } + if (ch->state != SCHAN_STATE_CLOSED && ch->state != SCHAN_STATE_ERROR) { + ch->send_eof = 1; + ch->close_requested = 1; + } + ch = ch->next; + } +} + +static void session_finish_close(ssh_server_session_ctx *s) { + if (s == NULL || s->state != SESSION_STATE_CLOSING) + return; + if (s->close_force) { + if (s->poll != NULL) { + close_poll(&s->poll, session_poll_close_cb); + s->poll_events = 0; + return; + } + session_finalize(s); + return; + } + if (s->channels != NULL) { + session_update_poll(s); + return; + } + if (s->session != NULL && session_has_pending_write(s->session)) { + session_update_poll(s); + return; + } + if (s->poll != NULL) { + close_poll(&s->poll, session_poll_close_cb); + s->poll_events = 0; + return; + } + session_finalize(s); +} + +static void server_close_internal(ssh_server_ctx *s, const char *reason) { + if (s == NULL || s->state == SERVER_STATE_CLOSED || s->state == SERVER_STATE_CLOSING) + return; + int force_sessions = (s->state == SERVER_STATE_ERROR); + + if (!s->listen_ok && !s->listen_notified) { + server_notify_listen(s, reason ? reason : "closed"); + } + + s->state = SERVER_STATE_CLOSING; + if (reason != NULL && s->close_reason == NULL) + s->close_reason = acton_strdup(reason); + + if (s->poll != NULL) { + close_poll(&s->poll, server_poll_close_cb); + } + + ssh_server_session_ctx *sess = s->sessions; + while (sess != NULL) { + ssh_server_session_ctx *next = sess->next; + session_close_internal(sess, "Server closed", force_sessions); + sess = next; + } + + if (s->poll != NULL) + return; + + server_finalize(s); +} + +static void session_close_internal(ssh_server_session_ctx *s, const char *reason, int force_close) { + if (s == NULL || s->state == SESSION_STATE_CLOSED) + return; + if (!force_close && s->state != SESSION_STATE_READY) + force_close = 1; + if (reason != NULL && s->close_reason == NULL) + s->close_reason = acton_strdup(reason); + if (s->state == SESSION_STATE_CLOSING) { + if (force_close && !s->close_force) { + s->close_force = 1; + session_reject_pending_messages(s); + session_abort_channels(s); + } + session_finish_close(s); + return; + } + + stop_timer(&s->auth_timer, session_timer_close_cb); + stop_timer(&s->attach_timer, session_timer_close_cb); + stop_timer(&s->keepalive_timer, session_timer_close_cb); + + s->state = SESSION_STATE_CLOSING; + s->close_force = force_close; + session_reject_pending_messages(s); + if (force_close) { + session_abort_channels(s); + session_finish_close(s); + return; + } + + session_request_channel_close(s); + session_drive(s); +} + +static enum ssh_keytypes_e parse_hostkey_type(const char *type_str, int *param_out) { + if (type_str == NULL) + return SSH_KEYTYPE_UNKNOWN; + if (strcmp(type_str, "ed25519") == 0) { + if (param_out) + *param_out = 0; + return SSH_KEYTYPE_ED25519; + } + if (strcmp(type_str, "rsa") == 0) { + if (param_out) + *param_out = 2048; + return SSH_KEYTYPE_RSA; + } + if (strcmp(type_str, "ecdsa") == 0) { + if (param_out) + *param_out = 256; + return SSH_KEYTYPE_ECDSA; + } + return SSH_KEYTYPE_UNKNOWN; +} + +$R sshQ_ServerD__pin_affinityG_local(sshQ_Server self, $Cont c$cont) { + pin_actor_affinity(); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerD__initG_local(sshQ_Server self, $Cont c$cont) { + ssh_configure_libssh_logging(); + if (ssh_debug_enabled) { + log_info("ssh server init: self=%p", (void *)self); + } + ssh_server_ctx *s = acton_calloc(1, sizeof(ssh_server_ctx)); + STORE_HIDDEN_PTR(s->actor, self); + s->on_listen = ($action2)self->_on_listen; + s->on_close = ($action2)self->_on_close; + s->state = SERVER_STATE_INIT; + s->max_sessions = fromB_int(self->_max_sessions); + s->max_channels_per_session = fromB_int(self->_max_channels_per_session); + + self->_server = toB_u64((unsigned long)s); + + s->bind = ssh_bind_new(); + if (s->bind == NULL) { + server_fail(s, "Failed to create SSH bind"); + return $R_CONT(c$cont, B_None); + } + + int rc; + bool process_config = false; + + rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_PROCESS_CONFIG, &process_config); if (rc != SSH_OK) { - //log_error("Error connecting to SSH server: %s", ssh_get_error(session)); - $action2 f = ($action2) self->on_close; - f->$class->__asyn__(f, self, to$str(ssh_get_error(session))); + server_fail(s, "Failed to disable SSH bind config processing"); return $R_CONT(c$cont, B_None); } - rc = ssh_userauth_password(session, NULL, fromB_str(self->password)); - if (rc == SSH_OK) { - printf("Connected\n"); - show_remote_processes(session); + const char *host = (const char *)fromB_str(self->_host); + int port = (int)fromB_u16(self->_port); + rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_BINDADDR, host); + if (rc != SSH_OK) { + server_fail(s, "Failed to set bind address"); + return $R_CONT(c$cont, B_None); + } + rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_BINDPORT, &port); + if (rc != SSH_OK) { + server_fail(s, "Failed to set bind port"); + return $R_CONT(c$cont, B_None); + } + + if (self->_host_key_path != B_None) { + const char *path = (const char *)fromB_str(self->_host_key_path); + rc = ssh_pki_import_privkey_file(path, NULL, NULL, NULL, &s->hostkey); + if (rc != SSH_OK) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "Failed to load host key: %s", ssh_get_error(s->bind)); + server_fail(s, errmsg); + return $R_CONT(c$cont, B_None); + } } else { - printf("Error: %s\n", ssh_get_error(session)); + const char *type_str = (const char *)fromB_str(self->_host_key_type); + int param = fromB_int(self->_host_key_bits); + int default_param = 0; + enum ssh_keytypes_e type = parse_hostkey_type(type_str, &default_param); + if (type == SSH_KEYTYPE_UNKNOWN) { + server_fail(s, "Unsupported host key type"); + return $R_CONT(c$cont, B_None); + } + if (type == SSH_KEYTYPE_ED25519) { + param = 0; + } else if (param <= 0) { + param = default_param; + } + rc = ssh_pki_generate(type, param, &s->hostkey); + if (rc != SSH_OK) { + server_fail(s, "Failed to generate host key"); + return $R_CONT(c$cont, B_None); + } } -// self->_connected = true; -// $action f = ($action) self->on_connect; -// f->$class->__asyn__(f, self); + rc = ssh_bind_options_set(s->bind, SSH_BIND_OPTIONS_IMPORT_KEY, s->hostkey); + if (rc != SSH_OK) { + server_fail(s, "Failed to set host key"); + return $R_CONT(c$cont, B_None); + } + /* ssh_bind takes ownership of IMPORT_KEY and frees it via ssh_bind_free(). */ + s->hostkey = NULL; + + ssh_bind_set_blocking(s->bind, 0); + + rc = ssh_bind_listen(s->bind); + if (rc != SSH_OK) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH listen failed: %s", ssh_get_error(s->bind)); + server_fail(s, errmsg); + return $R_CONT(c$cont, B_None); + } + + s->fd = ssh_bind_get_fd(s->bind); + if (s->fd < 0) { + server_fail(s, "Failed to get bind fd"); + return $R_CONT(c$cont, B_None); + } + if (fd_set_nonblocking(s->fd) != 0) { + server_fail(s, "Failed to set bind fd nonblocking"); + return $R_CONT(c$cont, B_None); + } + + s->poll = acton_calloc(1, sizeof(uv_poll_t)); + s->poll->data = s; + int uv_rc = uv_poll_init(get_uv_loop(), s->poll, s->fd); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + server_fail(s, errmsg); + return $R_CONT(c$cont, B_None); + } + uv_rc = uv_poll_start(s->poll, UV_READABLE, server_poll_cb); + if (uv_rc != 0) { + char errmsg[256] = {0}; + uv_strerror_r(uv_rc, errmsg + strlen(errmsg), sizeof(errmsg) - strlen(errmsg)); + server_fail(s, errmsg); + return $R_CONT(c$cont, B_None); + } + + s->state = SERVER_STATE_LISTENING; + server_notify_listen(s, NULL); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerD_closeG_local(sshQ_Server self, $Cont c$cont) { + ssh_server_ctx *s = server_from_actor(self); + if (s == NULL) + return $R_CONT(c$cont, B_None); + server_close_internal(s, "closed"); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerD__cleanup_nativeG_local(sshQ_Server self, $Cont c$cont) { + ssh_server_ctx *s = server_from_actor(self); + if (s != NULL) + server_close_internal(s, "collected"); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD__pin_affinityG_local(sshQ_ServerSession self, $Cont c$cont) { + ssh_server_ctx *server = server_from_actor(self->server); + ssh_server_session_ctx *s = session_from_pending_token(server, self->session_id); + if (s != NULL && s->owner_wt >= 0) { + set_actor_affinity(s->owner_wt); + } else { + pin_actor_affinity(); + } + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD__attachG_local(sshQ_ServerSession self, $Cont c$cont, B_u64 session_id) { + ssh_server_ctx *server = server_from_actor(self->server); + ssh_server_session_ctx *s = session_from_pending_token(server, session_id); + if (s == NULL) + return $R_CONT(c$cont, B_None); + if (s->session == NULL || s->state == SESSION_STATE_CLOSED || + s->state == SESSION_STATE_CLOSING || s->state == SESSION_STATE_ERROR) { + if (ssh_debug_enabled) { + ssh_debug_log("server session attach skipped closed id=%llu", + (unsigned long long)fromB_u64(session_id)); + } + return $R_CONT(c$cont, B_None); + } + if (ssh_debug_enabled) { + ssh_debug_log("server session attach: id=%llu", (unsigned long long)fromB_u64(session_id)); + } + char errmsg[256] = {0}; + if (session_start_poll(s, errmsg, sizeof(errmsg)) != 0) { + session_close_internal(s, errmsg, 1); + return $R_CONT(c$cont, B_None); + } + s->actor = self; + self->_session_id = toB_u64((unsigned long)s); + s->attached = 1; + s->pending_id = 0; + s->on_auth = ($action2)self->_on_auth; + s->on_channel_open = ($action)self->_on_channel_open; + s->on_exec = (self->_on_exec == B_None) ? NULL : ($action3)self->_on_exec; + s->on_subsystem = (self->_on_subsystem == B_None) ? NULL : ($action3)self->_on_subsystem; + s->on_close = (self->_on_close == B_None) ? NULL : ($action2)self->_on_close; + s->auth_timeout = fromB_float(self->server->_auth_timeout); + s->keepalive_interval = fromB_float(self->server->_keepalive_interval); + s->keepalive_enabled = fromB_bool(self->server->_keepalive_enabled) ? 1 : 0; + if (s->state == SESSION_STATE_KEYEX) { + stop_timer(&s->attach_timer, session_timer_close_cb); + session_start_keyex_timer(s); + } + if (ssh_debug_enabled) { + ssh_debug_log("server session attach: callbacks set"); + } + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD__drive_attachedG_local(sshQ_ServerSession self, $Cont c$cont) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL || !s->attached) + return $R_CONT(c$cont, B_None); + if (ssh_debug_enabled) { + ssh_debug_log("server session drive attached"); + } + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_accept_authG_local(sshQ_ServerSession self, $Cont c$cont) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL || s->pending_auth == NULL) + return $R_CONT(c$cont, B_None); + + int rc = ssh_message_auth_reply_success(s->pending_auth, 0); + ssh_message_free(s->pending_auth); + s->pending_auth = NULL; + if (session_check_reply_rc(s, rc, "SSH auth accept failed") != 0) + return $R_CONT(c$cont, B_None); + s->state = SESSION_STATE_READY; + stop_timer(&s->auth_timer, session_timer_close_cb); + session_start_keepalive(s); + session_drive(s); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_reject_authG_local(sshQ_ServerSession self, $Cont c$cont, B_str reason) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL || s->pending_auth == NULL) + return $R_CONT(c$cont, B_None); + + ssh_message_auth_set_methods(s->pending_auth, SSH_AUTH_METHOD_PASSWORD); + int rc = ssh_message_reply_default(s->pending_auth); + ssh_message_free(s->pending_auth); + s->pending_auth = NULL; + (void)reason; + if (session_check_reply_rc(s, rc, "SSH auth reject failed") != 0) + return $R_CONT(c$cont, B_None); + session_drive(s); + + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_accept_channel_openG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, + $action on_data, + $action on_stderr, + $action on_close) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL || s->pending_channel_open == NULL) + return $R_CONT(c$cont, B_None); + + ssh_channel chan = ssh_channel_new(s->session); + if (chan == NULL) { + int rc = ssh_message_reply_default(s->pending_channel_open); + ssh_message_free(s->pending_channel_open); + s->pending_channel_open = NULL; + if (on_close) { + $action2 f = ($action2)on_close; + f->$class->__asyn__(f, channel, to$str((char *)"Failed to accept channel open")); + } + if (session_check_reply_rc(s, rc, "SSH channel open accept failed") != 0) + return $R_CONT(c$cont, B_None); + session_drive(s); + return $R_CONT(c$cont, B_None); + } + int rc = ssh_message_channel_request_open_reply_accept_channel(s->pending_channel_open, chan); + if (rc != SSH_OK && rc != SSH_AGAIN) { + ssh_channel_free(chan); + rc = ssh_message_reply_default(s->pending_channel_open); + ssh_message_free(s->pending_channel_open); + s->pending_channel_open = NULL; + if (on_close) { + $action2 f = ($action2)on_close; + f->$class->__asyn__(f, channel, to$str((char *)"Failed to accept channel open")); + } + if (session_check_reply_rc(s, rc, "SSH channel open accept failed") != 0) + return $R_CONT(c$cont, B_None); + session_drive(s); + return $R_CONT(c$cont, B_None); + } + if (rc == SSH_AGAIN) + s->write_ready = 0; + ssh_channel_set_blocking(chan, 0); + ssh_message_free(s->pending_channel_open); + s->pending_channel_open = NULL; + + ssh_server_channel_ctx *ch = acton_calloc(1, sizeof(ssh_server_channel_ctx)); + ch->channel = chan; + ch->session = s; + ch->actor = channel; + ch->callbacks = NULL; + ch->state = SCHAN_STATE_OPEN; + ch->send_eof = 0; + ch->close_requested = 0; + ch->close_sent = 0; + ch->eof_sent = 0; + ch->stdout_eof = 0; + ch->stderr_eof = 0; + ch->pending_req = NULL; + ch->pending_req_type = SCHAN_REQ_NONE; + ch->on_data = ($action2)on_data; + ch->on_stderr = ($action2)on_stderr; + ch->on_close = ($action2)on_close; + if (server_channel_setup_callbacks(ch) != SSH_OK) { + server_channel_notify_close(ch, "Failed to set SSH channel callbacks"); + if (ch->channel != NULL) + ssh_channel_close(ch->channel); + server_channel_finalize(ch); + acton_free(ch); + session_drive(s); + return $R_CONT(c$cont, B_None); + } + if (ssh_debug_enabled) { + ssh_debug_log("server channel new ch=%p callbacks=%p", (void *)ch, (void *)ch->callbacks); + } + + ch->next = s->channels; + s->channels = ch; + channel->_channel_id = toB_u64((unsigned long)ch); + + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_reject_channelG_local(sshQ_ServerSession self, $Cont c$cont, B_str reason) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL || s->pending_channel_open == NULL) + return $R_CONT(c$cont, B_None); + + int rc = ssh_message_reply_default(s->pending_channel_open); + ssh_message_free(s->pending_channel_open); + s->pending_channel_open = NULL; + (void)reason; + if (session_check_reply_rc(s, rc, "SSH channel open reject failed") != 0) + return $R_CONT(c$cont, B_None); + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_closeG_local(sshQ_ServerSession self, $Cont c$cont) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s == NULL) + return $R_CONT(c$cont, B_None); + session_close_internal(s, "closed", 0); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD__cleanup_nativeG_local(sshQ_ServerSession self, $Cont c$cont) { + ssh_server_session_ctx *s = session_from_actor(self); + if (s != NULL) + session_close_internal(s, "collected", 1); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_accept_requestG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (server_channel_validate(s, ch) != 0 || ch->pending_req == NULL) + return $R_CONT(c$cont, B_None); + + int rc = ssh_message_channel_request_reply_success(ch->pending_req); + ssh_message_free(ch->pending_req); + ch->pending_req = NULL; + ch->pending_req_type = SCHAN_REQ_NONE; + if (session_check_reply_rc(s, rc, "SSH channel request accept failed") != 0) + return $R_CONT(c$cont, B_None); + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_reject_requestG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_str reason) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (server_channel_validate(s, ch) != 0 || ch->pending_req == NULL) + return $R_CONT(c$cont, B_None); + + int rc = ssh_message_reply_default(ch->pending_req); + ssh_message_free(ch->pending_req); + ch->pending_req = NULL; + ch->pending_req_type = SCHAN_REQ_NONE; + (void)reason; + if (session_check_reply_rc(s, rc, "SSH channel request reject failed") != 0) + return $R_CONT(c$cont, B_None); + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_writeG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_bytes data) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (server_channel_validate(s, ch) != 0) + return $R_CONT(c$cont, B_None); + server_channel_queue_write(ch, data, 0); + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_write_stderrG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_bytes data) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (server_channel_validate(s, ch) != 0) + return $R_CONT(c$cont, B_None); + server_channel_queue_write(ch, data, 1); + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_send_eofG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (server_channel_validate(s, ch) != 0) + return $R_CONT(c$cont, B_None); + ch->send_eof = 1; + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_send_exit_statusG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel, B_int status) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (server_channel_validate(s, ch) != 0 || ch->channel == NULL) + return $R_CONT(c$cont, B_None); + int rc = ssh_channel_request_send_exit_status(ch->channel, fromB_int(status)); + if (rc != SSH_OK && rc != SSH_AGAIN) { + char errmsg[256] = {0}; + snprintf(errmsg, sizeof(errmsg), "SSH server send exit status failed: %s", ssh_get_error(s->session)); + server_channel_notify_close(ch, errmsg); + ch->state = SCHAN_STATE_ERROR; + } + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerSessionD_channel_closeG_local(sshQ_ServerSession self, $Cont c$cont, sshQ_ServerChannel channel) { + ssh_server_session_ctx *s = session_from_actor(self); + ssh_server_channel_ctx *ch = server_channel_from_actor(channel); + if (server_channel_validate(s, ch) != 0) + return $R_CONT(c$cont, B_None); + ch->send_eof = 1; + ch->close_requested = 1; + session_drive(s); + return $R_CONT(c$cont, B_None); +} + +$R sshQ_ServerChannelD__cleanup_nativeG_local(sshQ_ServerChannel self, $Cont c$cont) { + ssh_server_channel_ctx *ch = server_channel_from_actor(self); + if (ch == NULL || ch->state == SCHAN_STATE_CLOSED || ch->state == SCHAN_STATE_ERROR) + return $R_CONT(c$cont, B_None); + ssh_server_session_ctx *s = ch->session; + if (s == NULL) + return $R_CONT(c$cont, B_None); + ch->send_eof = 1; + ch->close_requested = 1; + session_drive(s); return $R_CONT(c$cont, B_None); } diff --git a/src/test_ssh_server.act b/src/test_ssh_server.act new file mode 100644 index 0000000..d59c60f --- /dev/null +++ b/src/test_ssh_server.act @@ -0,0 +1,5881 @@ +import random +import testing +import logging + +import acton.rts +import net +import ssh + +LISTEN_RETRY_LIMIT = 16 + + +def _pick_test_port(base: int, span: int, attempt: int, salt: int) -> int: + return base + (random.randint(0, span - 1) + attempt * salt) % span + + +def _is_retryable_listen_error(err: str) -> bool: + return err.find("Address already in use") >= 0 + + +actor ServerClientTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + NETCONF_END = b"]]>]]>" + NETCONF_HELLO = b'\n' + \ + b'\n' + \ + b' \n' + \ + b' urn:ietf:params:netconf:base:1.0\n' + \ + b' \n' + \ + b']]>]]>' + SERVER_HELLO = b'urn:ietf:params:netconf:base:1.0]]>]]>' + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_session: ?ssh.ServerSession = None + var server_buf = b"" + var client_buf = b"" + var client_channel: ?ssh.Channel = None + var shutdown_started = False + var client_closed = False + var server_closed = False + var session_closed = False + var client_channel_closed = False + var server_channel_closed = False + var client_close_requested = False + var server_close_requested = False + var client_channel_close_requested = False + var session_close_requested = False + var shutdown_timer_started = False + + def maybe_finish(): + if done: + return + if shutdown_started and client_closed and server_closed and session_closed and \ + client_channel_closed and server_channel_closed: + log.info("test success", None) + done = True + t.success() + + def request_client_channel_close(): + if client_channel_close_requested: + return + client_channel_close_requested = True + if client_channel is not None: + client_channel.close() + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def request_session_close(): + if session_close_requested or session_closed: + return + session_close_requested = True + if server_session is not None: + log.info("request session close", None) + server_session.close() + else: + log.info("request session close: no session", None) + + def begin_shutdown(reason: str): + if shutdown_started: + return + shutdown_started = True + log.info("shutdown start: " + reason, None) + request_session_close() + request_client_channel_close() + request_client_close() + if not shutdown_timer_started: + shutdown_timer_started = True + after 10.0: on_timeout() + maybe_finish() + + def finish_error(msg: str): + if done: + return + log.info("test error: " + msg, None) + done = True + shutdown_started = True + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + log.info("timeout fired", None) + finish_error("timeout waiting for SSH test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if done: + return + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + log.info("server listening", None) + server = s + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + log.info("server session ready", None) + server_session = sess + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + log.info("auth request for " + req.user, None) + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + log.info("channel open requested", None) + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + server_channel = ch + sess.accept_channel(ch) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + log.info("subsystem request: " + name, None) + if name == "netconf": + ch.accept_request() + else: + ch.reject_request("unsupported subsystem") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + log.info("exec request: " + cmd, None) + ch.reject_request("exec disabled") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + if shutdown_started: + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server_buf += data + log.info("server data len " + str(len(data)), None) + if server_buf.find(NETCONF_END) >= 0: + log.info("server got NETCONF hello", None) + ch.write(SERVER_HELLO) + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + log.info("server channel closed: " + reason, None) + server_channel_closed = True + maybe_finish() + + def on_hostkey(client: ssh.Client, state: str, info: ssh.HostKeyInfo): + log.info("hostkey state: " + state, None) + client.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + log.info("client connect error: " + err, None) + finish_error("client connect error: " + err) + return + log.info("client connected", None) + client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if shutdown_started: + request_session_close() + maybe_finish() + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + log.info("client channel open error: " + err, None) + finish_error("client channel open error: " + err) + return + log.info("client channel open", None) + ch.request_subsystem("netconf") + ch.write(NETCONF_HELLO) + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + client_buf += data + log.info("client data len " + str(len(data)), None) + if client_buf.find(NETCONF_END) >= 0: + log.info("client got NETCONF hello", None) + begin_shutdown("client got NETCONF hello") + + def ch_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + log.info("client exit " + str(code), None) + return + + def ch_close(ch: ssh.Channel, reason: str): + log.info("client channel closed: " + reason, None) + client_channel_closed = True + if shutdown_started and client is not None: + client.close() + maybe_finish() + + def start_client(): + log.info("starting client", None) + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 42000 + (random.randint(0, 2000) + attempts * 37) % 2000 + log.info("starting server on " + str(port), None) + server_close_requested = False + server_closed = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + + +def _test_ssh_server_subsystem(t: testing.EnvT): + """Exercise server subsystem handling with a local client.""" + ServerClientTester(t) + + +actor ClientCloseFlushTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var payload_len = 1024 * 1024 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_buf = b"" + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def finish_success(): + if done: + return + done = True + log.info("test success", None) + request_client_close() + request_server_close() + t.success() + + def on_timeout(): + if done: + return + finish_error("timeout waiting for client close flush test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + log.info("server listening", None) + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + ch.reject_request("exec disabled") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name == "netconf": + ch.accept_request() + else: + ch.reject_request("unsupported subsystem") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server_buf += data + return + if len(server_buf) != payload_len: + finish_error("server received " + str(len(server_buf)) + " bytes, expected " + str(payload_len)) + return + finish_success() + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + return + + def on_hostkey(client: ssh.Client, state: str, info: ssh.HostKeyInfo): + client.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + ch.request_subsystem("netconf") + payload = b"A" * payload_len + ch.write(payload) + ch.close() + + def ch_out(ch: ssh.Channel, data: ?bytes): + return + + def ch_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + return + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 44000 + (random.randint(0, 2000) + attempts * 59) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 15.0: on_timeout() + + +def _test_ssh_client_close_flush(t: testing.EnvT): + """Client close should flush queued writes before EOF.""" + ClientCloseFlushTester(t) + + +actor ClientSessionCloseFlushTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var payload_len = 1024 * 1024 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_buf = b"" + var client_close_requested = False + var server_close_requested = False + var saw_terminal_event = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def finish_success(): + if done: + return + done = True + log.info("test success", None) + request_client_close() + request_server_close() + t.success() + + def check_server_bytes(reason: str): + if done or not saw_terminal_event: + return + if len(server_buf) != payload_len: + finish_error(reason + ": server received " + str(len(server_buf)) + " bytes, expected " + str(payload_len)) + return + finish_success() + + def on_timeout(): + if done: + return + finish_error("timeout waiting for client session close flush test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + log.info("server listening", None) + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + ch.reject_request("exec disabled") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name == "netconf": + ch.accept_request() + else: + ch.reject_request("unsupported subsystem") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + if done: + return + saw_terminal_event = True + check_server_bytes("session closed") + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server_buf += data + if not client_close_requested: + request_client_close() + return + if done: + return + saw_terminal_event = True + check_server_bytes("server EOF") + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + if done: + return + saw_terminal_event = True + check_server_bytes("server channel close") + + def on_hostkey(client: ssh.Client, state: str, info: ssh.HostKeyInfo): + client.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + ch.request_subsystem("netconf") + payload = b"B" * payload_len + ch.write(payload) + + def ch_out(ch: ssh.Channel, data: ?bytes): + return + + def ch_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + return + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 50000 + (random.randint(0, 2000) + attempts * 101) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_client_session_close_flush(t: testing.EnvT): + """Client.close should not drop queued channel writes.""" + ClientSessionCloseFlushTester(t) + + +actor AuthRejectTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var saw_auth = False + var got_connect_error = False + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if saw_auth and got_connect_error: + done = True + request_client_close() + request_server_close() + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error("timeout waiting for auth reject test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + saw_auth = True + sess.reject_auth("invalid credentials") + maybe_finish() + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open after auth reject") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request after auth reject") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request after auth reject") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is None: + finish_error("expected auth failure but connect succeeded") + return + got_connect_error = True + maybe_finish() + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="wrong-pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 46000 + (random.randint(0, 2000) + attempts * 73) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 2.0: on_timeout() + + +def _test_ssh_auth_reject(t: testing.EnvT): + """Server reject_auth should fail client connect without hanging.""" + AuthRejectTester(t) + + +actor ServerSessionCloseFlushTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var payload_len = 1024 * 1024 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_session: ?ssh.ServerSession = None + var client_buf = b"" + var client_close_requested = False + var server_close_requested = False + var session_close_requested = False + var saw_terminal_event = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def request_session_close(): + if session_close_requested: + return + session_close_requested = True + if server_session is not None: + server_session.close() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def finish_success(): + if done: + return + done = True + log.info("test success", None) + request_client_close() + request_server_close() + t.success() + + def check_client_bytes(reason: str): + if done or not saw_terminal_event: + return + if len(client_buf) != payload_len: + finish_error(reason + ": client received " + str(len(client_buf)) + " bytes, expected " + str(payload_len)) + return + finish_success() + + def on_timeout(): + if done: + return + finish_error("timeout waiting for server session close flush test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + + def on_session(sess: ssh.ServerSession): + server_session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + ch.reject_request("exec disabled") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "netconf": + ch.reject_request("unsupported subsystem") + return + ch.accept_request() + payload = b"C" * payload_len + ch.write(payload) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + if done: + return + saw_terminal_event = True + check_client_bytes("client closed") + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + ch.request_subsystem("netconf") + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + client_buf += data + if not session_close_requested: + request_session_close() + return + if done: + return + saw_terminal_event = True + check_client_bytes("client EOF") + + def ch_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + if done: + return + saw_terminal_event = True + check_client_bytes("client channel close") + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 52000 + (random.randint(0, 2000) + attempts * 107) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_server_session_close_flush(t: testing.EnvT): + """ServerSession.close should not drop queued channel writes.""" + ServerSessionCloseFlushTester(t) + + +actor ConcurrentChannelTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + PAYLOAD = b"channel-ok" + + var done = False + var attempts = 0 + var port: int = 0 + var num_channels = 8 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_session: ?ssh.ServerSession = None + var client_close_requested = False + var server_close_requested = False + var client_closed = False + var server_closed = False + var session_closed = False + var client_opened = 0 + var client_closed_channels = 0 + var server_subsystems = 0 + var server_closed_channels = 0 + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_shutdown(): + if done: + return + if client_closed_channels == num_channels and server_closed_channels == num_channels: + request_client_close() + if server_session is not None: + server_session.close() + request_server_close() + + def maybe_finish(): + if done: + return + if client_closed and server_closed and session_closed and \ + client_opened == num_channels and server_subsystems == num_channels and \ + client_closed_channels == num_channels and server_closed_channels == num_channels: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + if server_session is not None: + server_session.close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error("timeout waiting for concurrent channel test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + server_session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "netconf": + finish_error("unexpected subsystem: " + name) + return + server_subsystems += 1 + ch.accept_request() + ch.write(PAYLOAD) + ch.close() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("unexpected client payload on concurrent channel test") + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("unexpected client stderr on concurrent channel test") + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + server_closed_channels += 1 + maybe_shutdown() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + start_channel(c, 0) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + maybe_finish() + + def start_channel(c: ssh.Client, idx: int): + if idx >= num_channels: + return + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("channel " + str(idx) + " open error: " + err) + return + client_opened += 1 + ch.request_subsystem("netconf") + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is None: + return + if data != PAYLOAD: + finish_error("channel " + str(idx) + " received unexpected payload") + + def ch_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel " + str(idx) + " received unexpected stderr") + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + client_closed_channels += 1 + maybe_shutdown() + maybe_finish() + + ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + start_channel(c, idx + 1) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 54000 + (random.randint(0, 2000) + attempts * 109) % 2000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_concurrent_channels(t: testing.EnvT): + """Multiple concurrent channels should share one session cleanly.""" + ConcurrentChannelTester(t) + + +actor ConcurrentWriteServerChannelHandler(payload_len: int, + ack: bytes, + on_complete: action() -> None, + on_error: action(str) -> None): + var recv_buf = b"" + + action def on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + recv_buf += data + return + if len(recv_buf) != payload_len: + on_error("server received " + str(len(recv_buf)) + " bytes, expected " + str(payload_len)) + return + ch.write(ack) + ch.close() + + action def on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + on_error("unexpected client stderr on concurrent write test") + + action def on_close(ch: ssh.ServerChannel, reason: str): + on_complete() + + +actor ConcurrentChannelWriteTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + ACK = b"ack" + + var done = False + var attempts = 0 + var port: int = 0 + var num_channels = 8 + var payload_len = 256 * 1024 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_session: ?ssh.ServerSession = None + var client_close_requested = False + var server_close_requested = False + var client_closed = False + var server_closed = False + var session_closed = False + var client_opened = 0 + var client_acks = 0 + var client_closed_channels = 0 + var server_subsystems = 0 + var server_closed_channels = 0 + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_shutdown(): + if done: + return + if client_closed_channels == num_channels and server_closed_channels == num_channels: + request_client_close() + if server_session is not None: + server_session.close() + request_server_close() + + def maybe_finish(): + if done: + return + if client_closed and server_closed and session_closed and \ + client_opened == num_channels and server_subsystems == num_channels and \ + client_closed_channels == num_channels and server_closed_channels == num_channels: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + if server_session is not None: + server_session.close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error("timeout waiting for concurrent channel write test: client_opened=" + str(client_opened) + " client_acks=" + str(client_acks) + " server_subsystems=" + str(server_subsystems) + " client_closed_channels=" + str(client_closed_channels) + " server_closed_channels=" + str(server_closed_channels) + " client_closed=" + str(client_closed) + " session_closed=" + str(session_closed) + " server_closed=" + str(server_closed)) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + server_session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_server_channel_close(): + server_closed_channels += 1 + maybe_shutdown() + maybe_finish() + + def on_channel_open(sess: ssh.ServerSession): + handler = ConcurrentWriteServerChannelHandler( + payload_len, + ACK, + on_server_channel_close, + finish_error, + ) + ch = ssh.ServerChannel(sess, handler.on_data, handler.on_stderr, handler.on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "netconf": + finish_error("unexpected subsystem: " + name) + return + server_subsystems += 1 + ch.accept_request() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + start_channel(c, 0) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + maybe_finish() + + def start_channel(c: ssh.Client, idx: int): + if idx >= num_channels: + return + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("channel " + str(idx) + " open error: " + err) + return + client_opened += 1 + ch.request_subsystem("netconf") + payload = b"W" * payload_len + ch.write(payload) + ch.send_eof() + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is None: + return + if data != ACK: + finish_error("channel " + str(idx) + " received unexpected payload") + return + client_acks += 1 + ch.close() + + def ch_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel " + str(idx) + " received unexpected stderr") + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + client_closed_channels += 1 + maybe_shutdown() + maybe_finish() + + ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + start_channel(c, idx + 1) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = _pick_test_port(35000, 25000, attempts, 113) + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 20.0: on_timeout() + + +def _test_ssh_concurrent_channel_writes(t: testing.EnvT): + """Multiple channels should flush writes over one session.""" + ConcurrentChannelWriteTester(t) + + +actor KeepaliveTrafficTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + PING = b"keepalive-ping" + ACK = b"keepalive-ack" + ROUNDS = 10 + KEEPALIVE_INTERVAL = 0.05 + SEND_DELAY = 0.06 + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var session: ?ssh.ServerSession = None + var client_channel: ?ssh.Channel = None + var client_close_requested = False + var server_close_requested = False + var client_closed = False + var server_closed = False + var session_closed = False + var server_channel_closed = False + var ack_received = False + var server_recv_bytes = 0 + + def total_bytes() -> int: + return len(PING) * ROUNDS + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if ack_received and server_recv_bytes == total_bytes() and \ + server_channel_closed and client_closed and session_closed and \ + server_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + if session is not None: + session.close() + request_server_close() + t.error(Exception(msg)) + + def send_ping(remaining: int): + if done or remaining <= 0: + return + if client_channel is None: + finish_error("client channel missing during keepalive write burst") + return + if client_channel is not None: + client_channel.write(PING) + if remaining == 1: + client_channel.send_eof() + return + after SEND_DELAY: send_ping(remaining - 1) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for keepalive traffic test " + "(ack_received=" + str(ack_received) + + ", server_recv_bytes=" + str(server_recv_bytes) + + ", server_channel_closed=" + str(server_channel_closed) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < LISTEN_RETRY_LIMIT and _is_retryable_listen_error(err): + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + session = sess + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "echo": + finish_error("unexpected subsystem request: " + name) + return + ch.accept_request() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server_recv_bytes += len(data) + if server_recv_bytes > total_bytes(): + finish_error("server received too much data during keepalive test") + return + if server_recv_bytes == total_bytes(): + ch.write(ACK) + ch.send_eof() + ch.close() + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server received unexpected stderr during keepalive test") + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + log.info("server channel closed: " + reason, None) + server_channel_closed = True + if session is not None: + session.close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + maybe_finish() + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + client_channel = ch + ch.request_subsystem("echo") + after SEND_DELAY: send_ping(ROUNDS) + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + if data != ACK: + finish_error("client received unexpected keepalive reply") + return + if ack_received: + finish_error("client received duplicate keepalive reply") + return + ack_received = True + + def ch_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("client received unexpected stderr during keepalive test") + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + log.info("client channel closed: " + reason, None) + request_client_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + keepalive_interval=KEEPALIVE_INTERVAL, + ) + + def start_server(): + attempts += 1 + port = _pick_test_port(36000, 24000, attempts, 271) + server_close_requested = False + client_close_requested = False + client_closed = False + server_closed = False + session_closed = False + server_channel_closed = False + ack_received = False + server_recv_bytes = 0 + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + keepalive_interval=KEEPALIVE_INTERVAL, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_keepalive_traffic(t: testing.EnvT): + """Keepalives should not disrupt an active channel.""" + KeepaliveTrafficTester(t) + + +actor CrossHandleOwnershipTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + CLIENT1_PAYLOAD = b"client-one" + CLIENT2_PAYLOAD = b"client-two" + BAD_CLIENT_PAYLOAD = b"wrong-client" + SERVER1_REPLY = b"server-one" + SERVER2_REPLY = b"server-two" + BAD_SERVER_PAYLOAD = b"wrong-server" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client1: ?ssh.Client = None + var client2: ?ssh.Client = None + var client_channel1: ?ssh.Channel = None + var client_channel2: ?ssh.Channel = None + var server_session1: ?ssh.ServerSession = None + var server_session2: ?ssh.ServerSession = None + var server_channel1: ?ssh.ServerChannel = None + var server_channel2: ?ssh.ServerChannel = None + var client1_buf = b"" + var client2_buf = b"" + var server1_buf = b"" + var server2_buf = b"" + var io_started = False + var replies_sent = False + var shutdown_started = False + var client1_opened = False + var client2_opened = False + var client1_ok = False + var client2_ok = False + var server1_ok = False + var server2_ok = False + var client_close_count = 0 + var client_channel_close_count = 0 + var server_channel_close_count = 0 + var session_close_count = 0 + var server_closed = False + var client1_close_requested = False + var client2_close_requested = False + var server_close_requested = False + + def request_client1_close(): + if client1_close_requested: + return + client1_close_requested = True + if client1 is not None: + client1.close() + + def request_client2_close(): + if client2_close_requested: + return + client2_close_requested = True + if client2 is not None: + client2.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_request_client_closes(): + if not shutdown_started: + return + if client_channel_close_count >= 2: + request_client1_close() + request_client2_close() + + def maybe_request_server_shutdown(): + if not shutdown_started: + return + if session_close_count >= 2: + request_server_close() + + def begin_shutdown(): + if shutdown_started: + return + shutdown_started = True + if client_channel1 is not None: + client_channel1.close() + if client_channel2 is not None: + client_channel2.close() + if server_session1 is not None: + server_session1.close() + if server_session2 is not None: + server_session2.close() + maybe_request_client_closes() + maybe_request_server_shutdown() + + def maybe_finish(): + if done: + return + if client1_ok and client2_ok and server1_ok and server2_ok and \ + client_close_count == 2 and client_channel_close_count == 2 and \ + server_channel_close_count == 2 and session_close_count == 2 and \ + server_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + begin_shutdown() + t.error(Exception(msg)) + + def maybe_start_io(): + if io_started or done: + return + if not client1_opened or not client2_opened: + return + if client1 is not None: + if client_channel1 is not None: + if client_channel2 is not None: + if server_channel1 is not None: + if server_channel2 is not None: + io_started = True + client1.channel_write(client_channel2, BAD_CLIENT_PAYLOAD) + client_channel1.write(CLIENT1_PAYLOAD) + client_channel2.write(CLIENT2_PAYLOAD) + + def maybe_send_replies(): + if replies_sent or done: + return + if not server1_ok or not server2_ok: + return + if server_session1 is not None: + if server_channel1 is not None: + if server_channel2 is not None: + replies_sent = True + server_session1.channel_write(server_channel2, BAD_SERVER_PAYLOAD) + server_channel1.write(SERVER1_REPLY) + server_channel2.write(SERVER2_REPLY) + return + finish_error("server channels not ready for reply") + + def maybe_shutdown(): + if client1_ok and client2_ok: + begin_shutdown() + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for cross-handle ownership test " + "(client_close_count=" + str(client_close_count) + + ", client_channel_close_count=" + str(client_channel_close_count) + + ", server_channel_close_count=" + str(server_channel_close_count) + + ", session_close_count=" + str(session_close_count) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client1() + after 0.01: start_client2() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + if server_channel1 is None: + server_session1 = sess + ch = ssh.ServerChannel(sess, srv1_on_data, srv1_on_stderr, srv1_on_close) + server_channel1 = ch + sess.accept_channel(ch) + else: + if server_channel2 is None: + server_session2 = sess + ch = ssh.ServerChannel(sess, srv2_on_data, srv2_on_stderr, srv2_on_close) + server_channel2 = ch + sess.accept_channel(ch) + else: + finish_error("unexpected extra server channel") + return + maybe_start_io() + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "echo": + finish_error("unexpected subsystem: " + name) + return + ch.accept_request() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_close_count += 1 + maybe_request_server_shutdown() + maybe_finish() + + def srv1_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server1_buf += data + if server1_buf.find(BAD_CLIENT_PAYLOAD) >= 0: + finish_error("cross-client payload leaked into server channel 1") + return + if len(server1_buf) > len(CLIENT1_PAYLOAD): + finish_error("server channel 1 received extra data") + return + if len(server1_buf) == len(CLIENT1_PAYLOAD): + if server1_buf == CLIENT1_PAYLOAD: + if server1_ok: + finish_error("client-one payload delivered twice") + return + server1_ok = True + else: + if server1_buf == CLIENT2_PAYLOAD: + if server2_ok: + finish_error("client-two payload delivered twice") + return + server2_ok = True + else: + finish_error("server channel 1 received wrong payload") + return + maybe_send_replies() + + def srv1_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server channel 1 received unexpected stderr") + + def srv1_on_close(ch: ssh.ServerChannel, reason: str): + server_channel_close_count += 1 + maybe_finish() + + def srv2_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server2_buf += data + if server2_buf.find(BAD_CLIENT_PAYLOAD) >= 0: + finish_error("cross-client payload leaked into server channel 2") + return + if len(server2_buf) > len(CLIENT2_PAYLOAD): + finish_error("server channel 2 received extra data") + return + if len(server2_buf) == len(CLIENT2_PAYLOAD): + if server2_buf == CLIENT2_PAYLOAD: + if server2_ok: + finish_error("client-two payload delivered twice") + return + server2_ok = True + else: + if server2_buf == CLIENT1_PAYLOAD: + if server1_ok: + finish_error("client-one payload delivered twice") + return + server1_ok = True + else: + finish_error("server channel 2 received wrong payload") + return + maybe_send_replies() + + def srv2_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server channel 2 received unexpected stderr") + + def srv2_on_close(ch: ssh.ServerChannel, reason: str): + server_channel_close_count += 1 + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client1_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client 1 connect error: " + err) + return + client_channel1 = ssh.Channel(c, ch1_open, ch1_out, ch1_err, ch1_exit, ch1_close) + + def on_client2_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client 2 connect error: " + err) + return + client_channel2 = ssh.Channel(c, ch2_open, ch2_out, ch2_err, ch2_exit, ch2_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_close_count += 1 + maybe_finish() + + def ch1_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel 1 open error: " + err) + return + client1_opened = True + client_channel1 = ch + ch.request_subsystem("echo") + maybe_start_io() + + def ch1_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + client1_buf += data + if client1_buf.find(BAD_SERVER_PAYLOAD) >= 0: + finish_error("cross-session server payload leaked into client channel 1") + return + if len(client1_buf) > len(SERVER1_REPLY): + finish_error("client channel 1 received extra data") + return + if len(client1_buf) == len(SERVER1_REPLY): + if client1_buf == SERVER1_REPLY: + if client1_ok: + finish_error("server-one reply delivered twice") + return + client1_ok = True + else: + if client1_buf == SERVER2_REPLY: + if client2_ok: + finish_error("server-two reply delivered twice") + return + client2_ok = True + else: + finish_error("client channel 1 received wrong payload") + return + maybe_shutdown() + + def ch1_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("client channel 1 received unexpected stderr") + + def ch1_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch1_close(ch: ssh.Channel, reason: str): + client_channel_close_count += 1 + maybe_request_client_closes() + maybe_finish() + + def ch2_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel 2 open error: " + err) + return + client2_opened = True + client_channel2 = ch + ch.request_subsystem("echo") + maybe_start_io() + + def ch2_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + client2_buf += data + if client2_buf.find(BAD_SERVER_PAYLOAD) >= 0: + finish_error("cross-session server payload leaked into client channel 2") + return + if len(client2_buf) > len(SERVER2_REPLY): + finish_error("client channel 2 received extra data") + return + if len(client2_buf) == len(SERVER2_REPLY): + if client2_buf == SERVER2_REPLY: + if client2_ok: + finish_error("server-two reply delivered twice") + return + client2_ok = True + else: + if client2_buf == SERVER1_REPLY: + if client1_ok: + finish_error("server-one reply delivered twice") + return + client1_ok = True + else: + finish_error("client channel 2 received wrong payload") + return + maybe_shutdown() + + def ch2_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("client channel 2 received unexpected stderr") + + def ch2_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch2_close(ch: ssh.Channel, reason: str): + client_channel_close_count += 1 + maybe_request_client_closes() + maybe_finish() + + def start_client1(): + client1 = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client1_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_client2(): + client2 = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client2_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 60000 + (random.randint(0, 4000) + attempts * 131) % 4000 + server_close_requested = False + client1_close_requested = False + client2_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_cross_handle_ownership(t: testing.EnvT): + """Wrong client or session should not drive another channel.""" + CrossHandleOwnershipTester(t) + + +actor UnsupportedRequestTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var saw_subsystem = False + var saw_exec = False + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if saw_subsystem and saw_exec: + done = True + request_client_close() + request_server_close() + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for unsupported request test " + "(saw_subsystem=" + str(saw_subsystem) + + ", saw_exec=" + str(saw_exec) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + saw_exec = True + ch.reject_request("exec disabled") + maybe_finish() + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + saw_subsystem = True + ch.reject_request("unsupported subsystem") + maybe_finish() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + ssh.Channel(c, ch_sub_open, ch_sub_out, ch_sub_err, ch_sub_exit, ch_sub_close) + after 0.05: start_exec_channel(c) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + + def ch_sub_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("subsystem channel open error: " + err) + return + ch.request_subsystem("unsupported-subsystem") + + def ch_sub_out(ch: ssh.Channel, data: ?bytes): + return + + def ch_sub_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_sub_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_sub_close(ch: ssh.Channel, reason: str): + return + + def start_exec_channel(c: ssh.Client): + if done: + return + ssh.Channel(c, ch_exec_open, ch_exec_out, ch_exec_err, ch_exec_exit, ch_exec_close) + + def ch_exec_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("exec channel open error: " + err) + return + ch.request_exec("echo denied") + + def ch_exec_out(ch: ssh.Channel, data: ?bytes): + return + + def ch_exec_err(ch: ssh.Channel, data: ?bytes): + return + + def ch_exec_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_exec_close(ch: ssh.Channel, reason: str): + return + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 64000 + (random.randint(0, 1000) + attempts * 97) % 1000 + server_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_unsupported_requests(t: testing.EnvT): + """Server reject_request should handle unsupported subsystem/exec.""" + UnsupportedRequestTester(t) + + +actor RunCommandTimeoutTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + PARTIAL_OUT = b"partial-out" + PARTIAL_ERR = b"partial-err" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var run_exit_called = False + var client_closed = False + var server_closed = False + var session_closed = False + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if run_exit_called and client_closed and server_closed and session_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for RunCommand timeout test " + "(run_exit=" + str(run_exit_called) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + if cmd != "stall": + finish_error("unexpected exec command: " + cmd) + return + ch.accept_request() + ch.write(PARTIAL_OUT) + ch.write_stderr(PARTIAL_ERR) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + ssh.RunCommand(c, "stall", on_run_exit, timeout=0.1) + + def on_run_exit(ch: ssh.Channel, + code: int, + sig: ?str, + stdout: bytes, + stderr: bytes, + err: ?str): + if run_exit_called: + finish_error("RunCommand on_exit called more than once") + return + run_exit_called = True + if err != "timeout": + finish_error("expected timeout error, got " + str(err)) + return + if len(stdout) > len(PARTIAL_OUT) or PARTIAL_OUT.find(stdout) != 0: + finish_error("unexpected stdout on timeout") + return + if len(stderr) > len(PARTIAL_ERR) or PARTIAL_ERR.find(stderr) != 0: + finish_error("unexpected stderr on timeout") + return + if code != 0: + finish_error("timeout should not invent exit code") + return + if sig is not None: + finish_error("timeout should not invent exit signal") + return + request_client_close() + maybe_finish() + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 52000 + (random.randint(0, 2000) + attempts * 149) % 2000 + server_close_requested = False + client_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_runcommand_timeout(t: testing.EnvT): + """RunCommand timeout should complete exactly once with buffered output.""" + RunCommandTimeoutTester(t) + + +actor RunCommandExitStatusTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + OUT1 = b"alpha-" + OUT2 = b"beta" + ERR1 = b"warn-" + ERR2 = b"detail" + OUT = OUT1 + OUT2 + ERR = ERR1 + ERR2 + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var run_exit_called = False + var client_closed = False + var server_closed = False + var session_closed = False + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if run_exit_called and client_closed and server_closed and session_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for RunCommand exit status test " + "(run_exit=" + str(run_exit_called) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + if cmd != "report": + finish_error("unexpected exec command: " + cmd) + return + ch.accept_request() + ch.write(OUT1) + ch.write(OUT2) + ch.write_stderr(ERR1) + ch.write_stderr(ERR2) + ch.send_exit_status(23) + ch.send_eof() + ch.close() + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + ssh.RunCommand(c, "report", on_run_exit) + + def on_run_exit(ch: ssh.Channel, + code: int, + sig: ?str, + stdout: bytes, + stderr: bytes, + err: ?str): + if run_exit_called: + finish_error("RunCommand on_exit called more than once") + return + run_exit_called = True + if err is not None: + finish_error("unexpected RunCommand error: " + str(err)) + return + if stdout != OUT: + finish_error("unexpected stdout payload") + return + if stderr != ERR: + finish_error("unexpected stderr payload") + return + if code != 23: + finish_error("unexpected exit code: " + str(code)) + return + if sig is not None: + finish_error("unexpected exit signal: " + str(sig)) + return + request_client_close() + maybe_finish() + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 58000 + (random.randint(0, 2000) + attempts * 151) % 2000 + server_close_requested = False + client_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 5.0: on_timeout() + + +def _test_ssh_runcommand_exit_status(t: testing.EnvT): + """RunCommand should collect stdout, stderr, and exit status.""" + RunCommandExitStatusTester(t) + + +actor ConcurrentRunCommandTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + NUM_CMDS = 4 + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_channel_closes = 0 + var run_exits = 0 + var client_closed = False + var server_closed = False + var session_closed = False + var client_close_requested = False + var server_close_requested = False + var run0_done = False + var run1_done = False + var run2_done = False + var run3_done = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def out1(cmd: str) -> bytes: + if cmd == "report0": + return b"out0-a:" * 512 + if cmd == "report1": + return b"out1-a:" * 512 + if cmd == "report2": + return b"out2-a:" * 512 + return b"out3-a:" * 512 + + def out2(cmd: str) -> bytes: + if cmd == "report0": + return b"out0-b:" * 512 + if cmd == "report1": + return b"out1-b:" * 512 + if cmd == "report2": + return b"out2-b:" * 512 + return b"out3-b:" * 512 + + def err1(cmd: str) -> bytes: + if cmd == "report0": + return b"err0-a:" * 256 + if cmd == "report1": + return b"err1-a:" * 256 + if cmd == "report2": + return b"err2-a:" * 256 + return b"err3-a:" * 256 + + def err2(cmd: str) -> bytes: + if cmd == "report0": + return b"err0-b:" * 256 + if cmd == "report1": + return b"err1-b:" * 256 + if cmd == "report2": + return b"err2-b:" * 256 + return b"err3-b:" * 256 + + def full_out(cmd: str) -> bytes: + return out1(cmd) + out2(cmd) + + def full_err(cmd: str) -> bytes: + return err1(cmd) + err2(cmd) + + def exit_code(cmd: str) -> int: + if cmd == "report0": + return 30 + if cmd == "report1": + return 31 + if cmd == "report2": + return 32 + return 33 + + def maybe_finish(): + if done: + return + if run_exits == NUM_CMDS and server_channel_closes == NUM_CMDS and \ + client_closed and server_closed and session_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def mark_run_done(cmd: str): + if cmd == "report0": + if run0_done: + finish_error("RunCommand report0 completed more than once") + return + run0_done = True + return + if cmd == "report1": + if run1_done: + finish_error("RunCommand report1 completed more than once") + return + run1_done = True + return + if cmd == "report2": + if run2_done: + finish_error("RunCommand report2 completed more than once") + return + run2_done = True + return + if cmd == "report3": + if run3_done: + finish_error("RunCommand report3 completed more than once") + return + run3_done = True + return + finish_error("unexpected RunCommand completion for " + cmd) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for concurrent RunCommand test " + "(run_exits=" + str(run_exits) + + ", server_channel_closes=" + str(server_channel_closes) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def send_report(ch: ssh.ServerChannel, cmd: str, exit_first: bool): + if exit_first: + ch.send_exit_status(exit_code(cmd)) + ch.write(out1(cmd)) + ch.write_stderr(err1(cmd)) + ch.write(out2(cmd)) + ch.write_stderr(err2(cmd)) + if not exit_first: + ch.send_exit_status(exit_code(cmd)) + ch.send_eof() + ch.close() + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + if cmd != "report0" and cmd != "report1" and cmd != "report2" and cmd != "report3": + finish_error("unexpected exec command: " + cmd) + return + ch.accept_request() + if cmd == "report0" or cmd == "report2": + send_report(ch, cmd, True) + else: + send_report(ch, cmd, False) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + server_channel_closes += 1 + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + start_run(c, 0) + start_run(c, 1) + start_run(c, 2) + start_run(c, 3) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def start_run(c: ssh.Client, idx: int): + cmd = "report" + str(idx) + + def on_run_exit(ch: ssh.Channel, + code: int, + sig: ?str, + stdout: bytes, + stderr: bytes, + err: ?str): + if err is not None: + finish_error("unexpected RunCommand error for " + cmd + ": " + str(err)) + return + if stdout != full_out(cmd): + finish_error("unexpected stdout payload for " + cmd) + return + if stderr != full_err(cmd): + finish_error("unexpected stderr payload for " + cmd) + return + if code != exit_code(cmd): + finish_error("unexpected exit code for " + cmd + ": " + str(code)) + return + if sig is not None: + finish_error("unexpected exit signal for " + cmd + ": " + str(sig)) + return + mark_run_done(cmd) + if done: + return + run_exits += 1 + if run_exits == NUM_CMDS: + request_client_close() + maybe_finish() + + ssh.RunCommand(c, cmd, on_run_exit) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 56000 + (random.randint(0, 2000) + attempts * 157) % 2000 + server_close_requested = False + client_close_requested = False + server_channel_closes = 0 + run_exits = 0 + client_closed = False + server_closed = False + session_closed = False + run0_done = False + run1_done = False + run2_done = False + run3_done = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_concurrent_runcommand(t: testing.EnvT): + """RunCommand should isolate concurrent stderr and exit paths.""" + ConcurrentRunCommandTester(t) + + +actor ConcurrentRunCommandRejectTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + OK0_OUT = b"ok0-out" + OK0_ERR = b"ok0-err" + OK1_OUT = b"ok1-out" + OK1_ERR = b"ok1-err" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_channel_closes = 0 + var run_exits = 0 + var client_closed = False + var server_closed = False + var session_closed = False + var client_close_requested = False + var server_close_requested = False + var ok0_done = False + var deny_done = False + var ok1_done = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if run_exits == 3 and server_channel_closes == 3 and \ + client_closed and server_closed and session_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def mark_done(cmd: str): + if cmd == "ok0": + if ok0_done: + finish_error("ok0 completed more than once") + return + ok0_done = True + return + if cmd == "deny": + if deny_done: + finish_error("deny completed more than once") + return + deny_done = True + return + if cmd == "ok1": + if ok1_done: + finish_error("ok1 completed more than once") + return + ok1_done = True + return + finish_error("unexpected completion for " + cmd) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for concurrent RunCommand reject test " + "(run_exits=" + str(run_exits) + + ", server_channel_closes=" + str(server_channel_closes) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + if cmd == "deny": + ch.reject_request("denied") + ch.close() + return + if cmd == "ok0": + ch.accept_request() + ch.write(OK0_OUT) + ch.write_stderr(OK0_ERR) + ch.send_exit_status(40) + ch.send_eof() + ch.close() + return + if cmd == "ok1": + ch.accept_request() + ch.send_exit_status(41) + ch.write(OK1_OUT) + ch.write_stderr(OK1_ERR) + ch.send_eof() + ch.close() + return + finish_error("unexpected exec command: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + return + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + server_channel_closes += 1 + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + start_run(c, "ok0") + start_run(c, "deny") + start_run(c, "ok1") + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def start_run(c: ssh.Client, cmd: str): + def on_run_exit(ch: ssh.Channel, + code: int, + sig: ?str, + stdout: bytes, + stderr: bytes, + err: ?str): + if cmd == "deny": + if err is None: + finish_error("expected reject error for deny") + return + if sig is not None: + finish_error("unexpected signal on deny: " + str(sig)) + return + elif cmd == "ok0": + if err is not None: + finish_error("unexpected error for ok0: " + str(err)) + return + if stdout != OK0_OUT or stderr != OK0_ERR or code != 40: + finish_error("unexpected result for ok0") + return + if sig is not None: + finish_error("unexpected signal on ok0: " + str(sig)) + return + elif cmd == "ok1": + if err is not None: + finish_error("unexpected error for ok1: " + str(err)) + return + if stdout != OK1_OUT or stderr != OK1_ERR or code != 41: + finish_error("unexpected result for ok1") + return + if sig is not None: + finish_error("unexpected signal on ok1: " + str(sig)) + return + else: + finish_error("unexpected command in on_run_exit: " + cmd) + return + mark_done(cmd) + if done: + return + run_exits += 1 + if run_exits == 3: + request_client_close() + maybe_finish() + + ssh.RunCommand(c, cmd, on_run_exit) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 60000 + (random.randint(0, 2000) + attempts * 163) % 2000 + server_close_requested = False + client_close_requested = False + server_channel_closes = 0 + run_exits = 0 + client_closed = False + server_closed = False + session_closed = False + ok0_done = False + deny_done = False + ok1_done = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_concurrent_runcommand_reject(t: testing.EnvT): + """A rejected RunCommand should not poison sibling channels.""" + ConcurrentRunCommandRejectTester(t) + + +actor AuthTimeoutTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var client_done = False + var hostkey_seen = False + var server_closed = False + var session_closed = False + var client_close_requested = False + var server_close_requested = False + var timeout_grace = False + + def maybe_finish(): + if done: + return + if session_closed and server_closed: + request_client_close() + done = True + t.success() + return + if timeout_grace and hostkey_seen and client_done and server_closed: + done = True + t.success() + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + request_client_close() + request_server_close() + done = True + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + if not timeout_grace: + timeout_grace = True + request_client_close() + request_server_close() + after 0.4: on_timeout() + return + maybe_finish() + if done: + return + finish_error("timeout waiting for auth timeout test") + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + return + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + finish_error("unexpected auth callback") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + if reason == "SSH authentication timeout": + session_closed = True + request_server_close() + maybe_finish() + return + if timeout_grace and reason == "Server closed": + session_closed = True + maybe_finish() + return + if reason != "SSH authentication timeout": + finish_error("unexpected session close reason: " + reason) + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + log.info("hostkey state: " + state, None) + hostkey_seen = True + return + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + if hostkey_seen: + log.info("client terminal connect error: " + err, None) + client_done = True + if session_closed: + request_server_close() + maybe_finish() + return + finish_error("client connect error: " + err) + return + client = c + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_done = True + if session_closed: + request_server_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 56000 + (random.randint(0, 2000) + attempts * 157) % 2000 + server_close_requested = False + client_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + auth_timeout=1.0, + ) + + after 0: start_server() + after 2.4: on_timeout() + + +def _test_ssh_auth_timeout(t: testing.EnvT): + """Server auth timeout should still fire after attach.""" + if len(t.env.argv) > 3 and t.env.argv[3] == "stress": + t.skip("AuthTimeoutTester is deterministic-only in stress mode") + AuthTimeoutTester(t) + + +actor HostkeyWaitDisconnectTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var session: ?ssh.ServerSession = None + var hostkey_seen = False + var client_connect_count = 0 + var session_close_count = 0 + var server_close_count = 0 + var session_close_requested = False + var server_close_requested = False + + def maybe_finish(): + if done: + return + if hostkey_seen and client_connect_count == 1 and session_close_count == 1 and \ + server_close_count == 1: + done = True + t.success() + + def request_session_close(): + if session_close_requested: + return + if not hostkey_seen: + return + if session is None: + return + session_close_requested = True + sess = session + if sess is not None: + sess.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + request_server_close() + done = True + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for hostkey wait disconnect test " + "(hostkey_seen=" + str(hostkey_seen) + + ", client_connect_count=" + str(client_connect_count) + + ", session_close_count=" + str(session_close_count) + + ", server_close_count=" + str(server_close_count) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_close_count += 1 + if server_close_count > 1: + finish_error("server close callback called more than once") + return + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session = sess + request_session_close() + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + finish_error("unexpected auth callback") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_close_count += 1 + if session_close_count > 1: + finish_error("session close callback called more than once") + return + request_server_close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + log.info("hostkey state: " + state, None) + if hostkey_seen: + finish_error("hostkey callback called more than once") + return + hostkey_seen = True + request_session_close() + + def on_client_connect(c: ssh.Client, err: ?str): + client_connect_count += 1 + if client_connect_count > 1: + finish_error("client connect callback called more than once") + return + if err is None: + finish_error("client unexpectedly connected") + return + maybe_finish() + + def on_client_close(c: ssh.Client, reason: str): + finish_error("unexpected client close callback: " + reason) + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 36000 + (random.randint(0, 20000) + attempts * 191) % 20000 + server_close_requested = False + session_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + auth_timeout=5.0, + ) + + after 0: start_server() + after 8.0: on_timeout() + + +def _test_ssh_hostkey_wait_disconnect(t: testing.EnvT): + """Client should fail promptly if the server closes in hostkey wait.""" + HostkeyWaitDisconnectTester(t) + + +actor CloseCallbackReentryTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + REPLY = b"close-ok" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var session: ?ssh.ServerSession = None + var client_channel: ?ssh.Channel = None + var saw_reply = False + var client_stdout_eof_count = 0 + var client_stderr_eof_count = 0 + var client_exit_count = 0 + var client_channel_close_count = 0 + var client_close_count = 0 + var server_channel_close_count = 0 + var session_close_count = 0 + var server_close_count = 0 + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if saw_reply and client_stdout_eof_count == 1 and client_stderr_eof_count == 1 and \ + client_exit_count == 1 and client_channel_close_count == 1 and \ + client_close_count == 1 and server_channel_close_count == 1 and \ + session_close_count == 1 and server_close_count == 1: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for close callback reentry test " + "(reply=" + str(saw_reply) + + ", client_stdout_eof_count=" + str(client_stdout_eof_count) + + ", client_stderr_eof_count=" + str(client_stderr_eof_count) + + ", client_exit_count=" + str(client_exit_count) + + ", client_channel_close_count=" + str(client_channel_close_count) + + ", client_close_count=" + str(client_close_count) + + ", server_channel_close_count=" + str(server_channel_close_count) + + ", session_close_count=" + str(session_close_count) + + ", server_close_count=" + str(server_close_count) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_close_count += 1 + if server_close_count > 1: + finish_error("server close callback called more than once") + return + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + session = sess + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + if cmd != "reentry": + finish_error("unexpected exec command: " + cmd) + return + session = sess + ch.accept_request() + ch.write(REPLY) + ch.send_exit_status(7) + ch.send_eof() + ch.close() + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_close_count += 1 + if session_close_count > 1: + finish_error("session close callback called more than once") + return + sess.close() + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server received unexpected data") + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server received unexpected stderr") + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + log.info("server channel closed: " + reason, None) + server_channel_close_count += 1 + if server_channel_close_count > 1: + finish_error("server channel close callback called more than once") + return + ch.write(b"after-server-close") + ch.send_eof() + ch.close() + if session is not None: + session.close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_close_count += 1 + if client_close_count > 1: + finish_error("client close callback called more than once") + return + c.close() + maybe_finish() + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + client_channel = ch + ch.request_exec("reentry") + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is None: + client_stdout_eof_count += 1 + if client_stdout_eof_count > 1: + finish_error("client stdout EOF callback called more than once") + return + if client_channel_close_count > 0: + finish_error("client stdout EOF arrived after close callback") + return + maybe_finish() + return + if saw_reply: + finish_error("client received duplicate reply") + return + if data != REPLY: + finish_error("unexpected reply payload: " + str(data)) + return + saw_reply = True + + def ch_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("client received unexpected stderr") + return + client_stderr_eof_count += 1 + if client_stderr_eof_count > 1: + finish_error("client stderr EOF callback called more than once") + return + if client_channel_close_count > 0: + finish_error("client stderr EOF arrived after close callback") + return + maybe_finish() + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + client_exit_count += 1 + if client_exit_count > 1: + finish_error("client exit callback called more than once") + return + if client_channel_close_count > 0: + finish_error("client exit callback arrived after close callback") + return + if code != 7: + finish_error("unexpected exit code: " + str(code)) + return + if sig is not None: + finish_error("unexpected exit signal: " + str(sig)) + return + maybe_finish() + + def ch_close(ch: ssh.Channel, reason: str): + log.info("client channel closed: " + reason, None) + client_channel_close_count += 1 + if client_channel_close_count > 1: + finish_error("client channel close callback called more than once") + return + if client_stdout_eof_count != 1: + finish_error("client close callback ran before stdout EOF") + return + if client_stderr_eof_count != 1: + finish_error("client close callback ran before stderr EOF") + return + if client_exit_count != 1: + finish_error("client close callback ran before exit callback") + return + ch.write(b"after-client-close") + ch.send_eof() + ch.close() + request_client_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 54000 + (random.randint(0, 3000) + attempts * 163) % 3000 + server_close_requested = False + client_close_requested = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 5.0: on_timeout() + + +def _test_ssh_close_callback_reentry(t: testing.EnvT): + """Close callbacks should tolerate reentrant close calls.""" + CloseCallbackReentryTester(t) + + +actor KeyExchangeTimeoutTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var raw_client: ?net.TCPConnection = None + var raw_connected = False + var raw_bytes_in = 0 + var session_ready = False + var session_closed = False + var server_closed = False + var raw_close_requested = False + var server_close_requested = False + var timeout_grace = False + + def request_raw_close(): + if raw_close_requested: + return + raw_close_requested = True + if raw_client is not None: + raw_client.close(on_raw_close) + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if raw_connected and session_ready and session_closed and server_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_raw_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + if not timeout_grace: + timeout_grace = True + after 5.0: on_timeout() + return + finish_error( + "timeout waiting for key exchange timeout test " + "(raw_connected=" + str(raw_connected) + + ", session_ready=" + str(session_ready) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + + ", raw_bytes_in=" + str(raw_bytes_in) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_raw_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session_ready = True + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + finish_error("unexpected auth request on raw TCP idle client") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open on raw TCP idle client") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request on raw TCP idle client") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request on raw TCP idle client") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + if reason != "SSH key exchange timeout": + finish_error("unexpected session close reason: " + reason) + return + session_closed = True + request_raw_close() + request_server_close() + maybe_finish() + + def on_raw_connect(c: net.TCPConnection): + raw_client = c + raw_connected = True + + def on_raw_receive(c: net.TCPConnection, payload: bytes): + raw_bytes_in += len(payload) + + def on_raw_error(c: net.TCPConnection, errmsg: str): + finish_error("unexpected raw client error: " + errmsg) + + def on_raw_remote_close(c: net.TCPConnection): + return + + def on_raw_close(c: net.TCPConnection): + return + + def start_raw_client(): + raw_close_requested = False + raw_connected = False + raw_bytes_in = 0 + raw_client = net.TCPConnection( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + port, + on_raw_connect, + on_raw_receive, + on_raw_error, + on_raw_remote_close, + ) + + def start_server(): + attempts += 1 + port = _pick_test_port(35000, 25000, attempts, 179) + server_close_requested = False + server_closed = False + session_ready = False + session_closed = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + auth_timeout=0.2, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_key_exchange_timeout(t: testing.EnvT): + """Attached sessions should time out during key exchange.""" + KeyExchangeTimeoutTester(t) + + +actor ServerSurvivesGarbageTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + GARBAGE = b"not-ssh\r\nstill-not-ssh\r\n" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var raw_client: ?net.TCPConnection = None + var good_client: ?ssh.Client = None + var raw_connected = False + var raw_bytes_in = 0 + var raw_close_requested = False + var raw_session_seen = False + var raw_session_closed = False + var good_client_started = False + var good_client_connected = False + var good_client_closed = False + var good_session_seen = False + var good_session_closed = False + var server_close_requested = False + var server_closed = False + + def maybe_finish(): + if done: + return + if raw_connected and raw_session_seen and raw_session_closed and \ + good_client_connected and good_client_closed and good_session_seen and \ + good_session_closed and server_closed: + done = True + t.success() + + def request_raw_close(): + if raw_close_requested: + return + raw_close_requested = True + if raw_client is not None: + raw_client.close(on_raw_close) + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + request_raw_close() + if good_client is not None: + good_client.close() + request_server_close() + done = True + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for bad-peer isolation test " + "(raw_connected=" + str(raw_connected) + + ", raw_bytes_in=" + str(raw_bytes_in) + + ", raw_session_seen=" + str(raw_session_seen) + + ", raw_session_closed=" + str(raw_session_closed) + + ", good_client_started=" + str(good_client_started) + + ", good_client_connected=" + str(good_client_connected) + + ", good_client_closed=" + str(good_client_closed) + + ", good_session_seen=" + str(good_session_seen) + + ", good_session_closed=" + str(good_session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_raw_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + if not raw_session_seen: + raw_session_seen = True + return + if not good_session_seen: + good_session_seen = True + return + finish_error("unexpected extra session") + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if not good_client_started: + finish_error("raw peer unexpectedly reached auth") + return + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + finish_error("unexpected auth request for good client") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + if not raw_session_closed: + raw_session_closed = True + if good_client_started: + finish_error("raw session closed after good client started") + return + start_good_client() + maybe_finish() + return + if not good_session_closed: + good_session_closed = True + request_server_close() + maybe_finish() + return + finish_error("unexpected extra session close") + + def on_raw_connect(c: net.TCPConnection): + raw_client = c + raw_connected = True + c.write(GARBAGE) + after 0.05: request_raw_close() + + def on_raw_receive(c: net.TCPConnection, payload: bytes): + raw_bytes_in += len(payload) + + def on_raw_error(c: net.TCPConnection, errmsg: str): + finish_error("unexpected raw client error: " + errmsg) + + def on_raw_remote_close(c: net.TCPConnection): + return + + def on_raw_close(c: net.TCPConnection): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_good_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("good client connect error: " + err) + return + good_client_connected = True + c.close() + + def on_good_client_close(c: ssh.Client, reason: str): + good_client_closed = True + if good_session_closed: + request_server_close() + maybe_finish() + + def start_raw_client(): + raw_close_requested = False + raw_connected = False + raw_bytes_in = 0 + raw_client = net.TCPConnection( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + port, + on_raw_connect, + on_raw_receive, + on_raw_error, + on_raw_remote_close, + ) + + def start_good_client(): + if good_client_started: + return + good_client_started = True + good_client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_good_client_connect, + on_good_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 39000 + (random.randint(0, 12000) + attempts * 211) % 12000 + server_close_requested = False + server_closed = False + raw_session_seen = False + raw_session_closed = False + good_client_started = False + good_client_connected = False + good_client_closed = False + good_session_seen = False + good_session_closed = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_server_survives_garbage(t: testing.EnvT): + """One bad TCP peer should not poison the next SSH session.""" + ServerSurvivesGarbageTester(t) + + +actor RapidDisconnectChurnTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + ROUNDS = 25 + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var raw_client: ?net.TCPConnection = None + var good_client: ?ssh.Client = None + var rounds_started = 0 + var raw_connect_count = 0 + var raw_session_seen_count = 0 + var raw_session_closed_count = 0 + var raw_close_requested = False + var good_client_started = False + var good_client_connected = False + var good_client_closed = False + var good_session_seen = False + var good_session_closed = False + var server_close_requested = False + var server_closed = False + + def maybe_finish(): + if done: + return + if raw_connect_count == ROUNDS and raw_session_seen_count == ROUNDS and \ + raw_session_closed_count == ROUNDS and good_client_connected and \ + good_client_closed and good_session_seen and good_session_closed and \ + server_closed: + done = True + t.success() + + def request_raw_close(): + if raw_close_requested: + return + raw_close_requested = True + if raw_client is not None: + raw_client.close(on_raw_close) + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def finish_error(msg: str): + if done: + return + request_raw_close() + if good_client is not None: + good_client.close() + request_server_close() + done = True + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for rapid disconnect churn test " + "(rounds_started=" + str(rounds_started) + + ", raw_connect_count=" + str(raw_connect_count) + + ", raw_session_seen_count=" + str(raw_session_seen_count) + + ", raw_session_closed_count=" + str(raw_session_closed_count) + + ", good_client_started=" + str(good_client_started) + + ", good_client_connected=" + str(good_client_connected) + + ", good_client_closed=" + str(good_client_closed) + + ", good_session_seen=" + str(good_session_seen) + + ", good_session_closed=" + str(good_session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_next_raw() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + if good_client_started: + if good_session_seen: + finish_error("good session callback called more than once") + return + good_session_seen = True + return + raw_session_seen_count += 1 + if raw_session_seen_count > ROUNDS: + finish_error("too many raw sessions") + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if not good_client_started: + finish_error("raw churn unexpectedly reached auth") + return + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + finish_error("unexpected auth request for good client") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + if good_client_started and raw_session_closed_count == ROUNDS: + if good_session_closed: + finish_error("good session close callback called more than once") + return + good_session_closed = True + request_server_close() + maybe_finish() + return + raw_session_closed_count += 1 + if raw_session_closed_count > ROUNDS: + finish_error("too many raw session closes") + return + if raw_session_closed_count < ROUNDS: + start_next_raw() + return + start_good_client() + + def on_raw_connect(c: net.TCPConnection): + raw_client = c + raw_connect_count += 1 + after 0.01: request_raw_close() + + def on_raw_receive(c: net.TCPConnection, payload: bytes): + return + + def on_raw_error(c: net.TCPConnection, errmsg: str): + finish_error("unexpected raw client error: " + errmsg) + + def on_raw_remote_close(c: net.TCPConnection): + return + + def on_raw_close(c: net.TCPConnection): + return + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_good_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("good client connect error: " + err) + return + good_client_connected = True + c.close() + + def on_good_client_close(c: ssh.Client, reason: str): + good_client_closed = True + if good_session_closed: + request_server_close() + maybe_finish() + + def start_next_raw(): + if good_client_started: + finish_error("started raw churn after good client") + return + if rounds_started >= ROUNDS: + finish_error("started too many raw churn rounds") + return + rounds_started += 1 + raw_close_requested = False + raw_client = net.TCPConnection( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + port, + on_raw_connect, + on_raw_receive, + on_raw_error, + on_raw_remote_close, + ) + + def start_good_client(): + if good_client_started: + return + good_client_started = True + good_client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_good_client_connect, + on_good_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 39000 + (random.randint(0, 12000) + attempts * 223) % 12000 + rounds_started = 0 + raw_connect_count = 0 + raw_session_seen_count = 0 + raw_session_closed_count = 0 + raw_close_requested = False + good_client_started = False + good_client_connected = False + good_client_closed = False + good_session_seen = False + good_session_closed = False + server_close_requested = False + server_closed = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_rapid_disconnect_churn(t: testing.EnvT): + """Rapid raw disconnect churn should not poison the next SSH client.""" + RapidDisconnectChurnTester(t) + + +actor CloseRejectsLateWritesTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + START = b"start" + ACK = b"ack" + LATE = b"late-data" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var session: ?ssh.ServerSession = None + var client_channel: ?ssh.Channel = None + var server_buf = b"" + var ack_received = False + var shutdown_started = False + var client_closed = False + var server_channel_closed = False + var session_closed = False + var server_closed = False + var client_close_requested = False + var server_close_requested = False + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_finish(): + if done: + return + if ack_received and client_closed and server_channel_closed and \ + session_closed and server_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def late_write_burst(remaining: int): + if done or remaining <= 0: + return + if client_channel is not None: + client_channel.write(LATE) + after 0: late_write_burst(remaining - 1) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for late-write close test " + "(ack_received=" + str(ack_received) + + ", client_closed=" + str(client_closed) + + ", server_channel_closed=" + str(server_channel_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + session = sess + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "echo": + finish_error("unexpected subsystem request: " + name) + return + ch.accept_request() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + server_buf += data + if server_buf.find(LATE) >= 0: + finish_error("late write reached server after close started") + return + if not ack_received and server_buf.find(START) >= 0: + ch.write(ACK) + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server received unexpected stderr") + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + log.info("server channel closed: " + reason, None) + server_channel_closed = True + if session is not None: + session.close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + client_channel = ssh.Channel(c, ch_open, ch_out, ch_err, ch_exit, ch_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def ch_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("client channel open error: " + err) + return + client_channel = ch + ch.request_subsystem("echo") + ch.write(START) + + def ch_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + if data != ACK: + finish_error("unexpected client payload: " + str(data)) + return + if ack_received: + finish_error("client received duplicate ACK") + return + ack_received = True + if not shutdown_started: + shutdown_started = True + request_client_close() + after 0: late_write_burst(32) + + def ch_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("client received unexpected stderr") + + def ch_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch_close(ch: ssh.Channel, reason: str): + log.info("client channel closed: " + reason, None) + request_client_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 50000 + (random.randint(0, 3000) + attempts * 181) % 3000 + server_close_requested = False + client_close_requested = False + ack_received = False + shutdown_started = False + client_closed = False + server_channel_closed = False + session_closed = False + server_closed = False + server_buf = b"" + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 5.0: on_timeout() + + +def _test_ssh_close_rejects_late_writes(t: testing.EnvT): + """Channel writes should be ignored once parent close starts.""" + CloseRejectsLateWritesTester(t) + + +actor SessionLimitTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client1: ?ssh.Client = None + var client2: ?ssh.Client = None + var client3: ?ssh.Client = None + var client1_close_requested = False + var client3_close_requested = False + var server_close_requested = False + var client1_ready = False + var client1_closed = False + var client2_failed = False + var client3_started = False + var client3_ready = False + var client3_closed = False + var server_closed = False + var session_ready_count = 0 + var session_close_count = 0 + + def request_client1_close(): + if client1_close_requested: + return + client1_close_requested = True + if client1 is not None: + client1.close() + + def request_client3_close(): + if client3_close_requested: + return + client3_close_requested = True + if client3 is not None: + client3.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_start_client3(): + if done or client3_started: + return + if client2_failed and client1_closed and session_close_count >= 1: + client3_started = True + start_client3() + + def maybe_finish(): + if done: + return + if client1_ready and client1_closed and client2_failed and client3_ready and \ + client3_closed and server_closed and session_ready_count == 2 and \ + session_close_count == 2: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client1_close() + request_client3_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for session limit test " + "(client1_ready=" + str(client1_ready) + + ", client1_closed=" + str(client1_closed) + + ", client2_failed=" + str(client2_failed) + + ", client3_ready=" + str(client3_ready) + + ", client3_closed=" + str(client3_closed) + + ", session_ready_count=" + str(session_ready_count) + + ", session_close_count=" + str(session_close_count) + + ", server_closed=" + str(server_closed) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client1() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session_ready_count += 1 + if session_ready_count > 2: + finish_error("unexpected extra session became ready") + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open during session limit test") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request: " + name) + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_close_count += 1 + maybe_start_client3() + if client3_closed and session_close_count >= 2: + request_server_close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client1_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client 1 connect error: " + err) + return + client1 = c + client1_ready = True + start_client2() + + def on_client2_connect(c: ssh.Client, err: ?str): + if err is None: + finish_error("client 2 unexpectedly connected over session limit") + return + client2_failed = True + request_client1_close() + maybe_start_client3() + + def on_client3_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client 3 connect error after freeing session slot: " + err) + return + client3 = c + client3_ready = True + request_client3_close() + + def on_client1_close(c: ssh.Client, reason: str): + log.info("client 1 closed: " + reason, None) + client1_closed = True + maybe_start_client3() + maybe_finish() + + def on_client3_close(c: ssh.Client, reason: str): + log.info("client 3 closed: " + reason, None) + client3_closed = True + if session_close_count >= 2: + request_server_close() + maybe_finish() + + def on_client2_close(c: ssh.Client, reason: str): + log.info("client 2 closed: " + reason, None) + + def start_client1(): + client1_close_requested = False + client1 = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client1_connect, + on_client1_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_client2(): + client2 = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client2_connect, + on_client2_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_client3(): + client3_close_requested = False + client3 = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client3_connect, + on_client3_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 50000 + (random.randint(0, 3000) + attempts * 461) % 3000 + client1 = None + client2 = None + client3 = None + client1_close_requested = False + client3_close_requested = False + server_close_requested = False + client1_ready = False + client1_closed = False + client2_failed = False + client3_started = False + client3_ready = False + client3_closed = False + server_closed = False + session_ready_count = 0 + session_close_count = 0 + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + max_sessions=1, + ) + + after 0: start_server() + after 15.0: on_timeout() + + +def _test_ssh_session_limit(t: testing.EnvT): + """A rejected extra session should not poison the server.""" + SessionLimitTester(t) + + +actor ChannelLimitClientFactory(client: ssh.Client, + ch2_open_cb: action(ssh.Channel, ?str) -> None, + ch2_out_cb: action(ssh.Channel, ?bytes) -> None, + ch2_err_cb: action(ssh.Channel, ?bytes) -> None, + ch2_exit_cb: action(ssh.Channel, int, ?str) -> None, + ch2_close_cb: action(ssh.Channel, str) -> None, + ch3_open_cb: action(ssh.Channel, ?str) -> None, + ch3_out_cb: action(ssh.Channel, ?bytes) -> None, + ch3_err_cb: action(ssh.Channel, ?bytes) -> None, + ch3_exit_cb: action(ssh.Channel, int, ?str) -> None, + ch3_close_cb: action(ssh.Channel, str) -> None): + var _ch2: ?ssh.Channel = None + var _ch3: ?ssh.Channel = None + + action def open_ch2() -> None: + _ch2 = ssh.Channel(client, ch2_open_cb, ch2_out_cb, ch2_err_cb, ch2_exit_cb, ch2_close_cb) + + action def open_ch3() -> None: + _ch3 = ssh.Channel(client, ch3_open_cb, ch3_out_cb, ch3_err_cb, ch3_exit_cb, ch3_close_cb) + + +actor ChannelLimitTester(t: testing.EnvT): + log = logging.Logger(t.log_handler) + + MSG1 = b"one" + ACK1 = b"ack-one" + MSG3 = b"three" + ACK3 = b"ack-three" + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var session: ?ssh.ServerSession = None + var client: ?ssh.Client = None + var factory: ?ChannelLimitClientFactory = None + var ch1: ?ssh.Channel = None + var ch2: ?ssh.Channel = None + var ch3: ?ssh.Channel = None + var client_close_requested = False + var server_close_requested = False + var ch3_started = False + var ch1_acked = False + var ch2_failed = False + var ch1_closed = False + var ch3_acked = False + var ch3_closed = False + var client_closed = False + var session_closed = False + var server_closed = False + var server_channel_open_count = 0 + var server_channel_close_count = 0 + + def request_client_close(): + if client_close_requested: + return + client_close_requested = True + if client is not None: + client.close() + + def request_server_close(): + if server_close_requested: + return + server_close_requested = True + if server is not None: + server.close() + + def maybe_start_ch3(): + if done or ch3_started: + return + if not (ch2_failed and ch1_closed and server_channel_close_count >= 1): + return + if factory is not None: + ch3_started = True + factory.open_ch3() + else: + finish_error("client factory missing when reopening channel after limit rejection") + + def maybe_finish(): + if done: + return + if ch1_acked and ch2_failed and ch1_closed and ch3_acked and ch3_closed and \ + client_closed and session_closed and server_closed and \ + server_channel_open_count == 2 and server_channel_close_count == 2: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + request_client_close() + request_server_close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for channel limit test " + "(ch1_acked=" + str(ch1_acked) + + ", ch2_failed=" + str(ch2_failed) + + ", ch1_closed=" + str(ch1_closed) + + ", ch3_acked=" + str(ch3_acked) + + ", ch3_closed=" + str(ch3_closed) + + ", client_closed=" + str(client_closed) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + + ", server_channel_open_count=" + str(server_channel_open_count) + + ", server_channel_close_count=" + str(server_channel_close_count) + ")" + ) + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if attempts < 5: + request_server_close() + start_server() + return + finish_error("server listen error: " + err) + return + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session = sess + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + session = sess + server_channel_open_count += 1 + if server_channel_open_count > 2: + finish_error("unexpected extra server channel open") + return + ch = ssh.ServerChannel(sess, srv_on_data, srv_on_stderr, srv_on_close) + sess.accept_channel(ch) + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request: " + cmd) + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + if name != "echo": + finish_error("unexpected subsystem request: " + name) + return + ch.accept_request() + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + request_server_close() + maybe_finish() + + def srv_on_data(ch: ssh.ServerChannel, data: ?bytes): + if data is None: + return + if data == MSG1: + ch.write(ACK1) + return + if data == MSG3: + ch.write(ACK3) + return + finish_error("unexpected server payload: " + str(data)) + + def srv_on_stderr(ch: ssh.ServerChannel, data: ?bytes): + if data is not None: + finish_error("server received unexpected stderr") + + def srv_on_close(ch: ssh.ServerChannel, reason: str): + log.info("server channel closed: " + reason, None) + server_channel_close_count += 1 + maybe_start_ch3() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client = c + factory = ChannelLimitClientFactory(c, ch2_open, ch2_out, ch2_err, ch2_exit, ch2_close, + ch3_open, ch3_out, ch3_err, ch3_exit, ch3_close) + ch1 = ssh.Channel(c, ch1_open, ch1_out, ch1_err, ch1_exit, ch1_close) + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + client_closed = True + if session_closed: + request_server_close() + maybe_finish() + + def ch1_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("channel 1 open error: " + err) + return + ch1 = ch + ch.request_subsystem("echo") + ch.write(MSG1) + + def ch1_out(ch: ssh.Channel, data: ?bytes): + if data is None: + return + if data != ACK1: + finish_error("channel 1 received unexpected payload: " + str(data)) + return + if ch1_acked: + finish_error("channel 1 received duplicate ACK") + return + ch1_acked = True + if factory is not None: + factory.open_ch2() + else: + finish_error("client factory missing for second channel open") + + def ch1_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel 1 received unexpected stderr") + + def ch1_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch1_close(ch: ssh.Channel, reason: str): + log.info("channel 1 closed: " + reason, None) + ch1_closed = True + maybe_start_ch3() + maybe_finish() + + def ch2_open(ch: ssh.Channel, err: ?str): + if err is None: + finish_error("channel 2 unexpectedly opened over channel limit") + return + ch2_failed = True + if ch1 is not None: + ch1.close() + maybe_start_ch3() + + def ch2_out(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel 2 received unexpected data") + + def ch2_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel 2 received unexpected stderr") + + def ch2_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch2_close(ch: ssh.Channel, reason: str): + return + + def ch3_open(ch: ssh.Channel, err: ?str): + if err is not None: + finish_error("channel 3 open error after freeing slot: " + err) + return + ch3 = ch + ch.request_subsystem("echo") + ch.write(MSG3) + + def ch3_out(ch: ssh.Channel, data: ?bytes): + if data is None: + return + if data != ACK3: + finish_error("channel 3 received unexpected payload: " + str(data)) + return + if ch3_acked: + finish_error("channel 3 received duplicate ACK") + return + ch3_acked = True + ch.close() + + def ch3_err(ch: ssh.Channel, data: ?bytes): + if data is not None: + finish_error("channel 3 received unexpected stderr") + + def ch3_exit(ch: ssh.Channel, code: int, sig: ?str): + return + + def ch3_close(ch: ssh.Channel, reason: str): + log.info("channel 3 closed: " + reason, None) + ch3_closed = True + request_client_close() + maybe_finish() + + def start_client(): + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = 50000 + (random.randint(0, 3000) + attempts * 503) % 3000 + session = None + client = None + factory = None + ch1 = None + ch2 = None + ch3 = None + client_close_requested = False + server_close_requested = False + ch3_started = False + ch1_acked = False + ch2_failed = False + ch1_closed = False + ch3_acked = False + ch3_closed = False + client_closed = False + session_closed = False + server_closed = False + server_channel_open_count = 0 + server_channel_close_count = 0 + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + max_channels_per_session=1, + ) + + after 0: start_server() + after 10.0: on_timeout() + + +def _test_ssh_channel_limit(t: testing.EnvT): + """A rejected extra channel should not poison the session.""" + ChannelLimitTester(t) + + +actor ServerCleanupTester(t: testing.EnvT): + t.require("gc_cleanup") + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var probe: ?ssh.Server = None + var server_listened = False + var server_collected = False + var probe_listened = False + var probe_closed = False + var gc_started = False + + def maybe_finish(): + if done: + return + if server_listened and server_collected and probe_listened and probe_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for GC-cleaned server " + "(server_listened=" + str(server_listened) + + ", server_collected=" + str(server_collected) + + ", probe_listened=" + str(probe_listened) + + ", probe_closed=" + str(probe_closed) + ")" + ) + + def drive_gc(): + if done: + return + blob = b"G" * (4 * 1024 * 1024) + work = 0 + for i in range(32768): + work += i + if len(blob) == work: + finish_error("unreachable GC drive path") + return + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + after 0.02: drive_gc() + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT: + start_server() + return + finish_error("server listen error: " + err) + return + server_listened = True + server = None + if not gc_started: + gc_started = True + after 0: drive_gc() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + if reason == "collected": + server_collected = True + if not probe_listened and probe is None: + start_probe() + maybe_finish() + + def on_probe_listen(s: ssh.Server, err: ?str): + probe = s + if err is not None: + log.info("probe listen error: " + err, None) + probe = None + if _is_retryable_listen_error(err): + after 0.1: start_probe() + return + finish_error("probe listen error: " + err) + return + probe_listened = True + s.close() + + def on_probe_close(s: ssh.Server, reason: str): + log.info("probe closed: " + reason, None) + probe_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + finish_error("unexpected session on GC-cleaned server test") + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + finish_error("unexpected auth request on GC-cleaned server test") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open on GC-cleaned server test") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request on GC-cleaned server test") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request on GC-cleaned server test") + + def on_session_close(sess: ssh.ServerSession, reason: str): + finish_error("unexpected session close on GC-cleaned server test") + + def start_probe(): + if done or probe_listened or probe is not None: + return + probe = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_probe_listen, + on_probe_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + def start_server(): + attempts += 1 + port = _pick_test_port(35000, 25000, attempts, 541) + server = None + probe = None + server_listened = False + server_collected = False + probe_listened = False + probe_closed = False + gc_started = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 60.0: on_timeout() + + +def _test_ssh_server_gc_cleanup(t: testing.EnvT): + """An unreachable server actor should eventually close itself.""" + ServerCleanupTester(t) + + +actor ClientCleanupTester(t: testing.EnvT): + t.require("gc_cleanup") + log = logging.Logger(t.log_handler) + + var done = False + var attempts = 0 + var port: int = 0 + var server: ?ssh.Server = None + var client: ?ssh.Client = None + var server_close_requested = False + var server_listened = False + var client_connected = False + var session_opened = False + var session_closed = False + var server_closed = False + var gc_started = False + + def maybe_finish(): + if done: + return + if server_listened and client_connected and session_opened and session_closed and server_closed: + done = True + t.success() + + def finish_error(msg: str): + if done: + return + done = True + log.info("test error: " + msg, None) + if not server_close_requested and server is not None: + server_close_requested = True + server.close() + t.error(Exception(msg)) + + def on_timeout(): + if done: + return + finish_error( + "timeout waiting for GC-cleaned client " + "(server_listened=" + str(server_listened) + + ", client_connected=" + str(client_connected) + + ", session_opened=" + str(session_opened) + + ", session_closed=" + str(session_closed) + + ", server_closed=" + str(server_closed) + ")" + ) + + def drive_gc(): + if done: + return + blob = b"G" * (4 * 1024 * 1024) + work = 0 + for i in range(32768): + work += i + if len(blob) == work: + finish_error("unreachable GC drive path") + return + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + acton.rts.gc(t.env.syscap) + after 0.02: drive_gc() + + def on_server_listen(s: ssh.Server, err: ?str): + server = s + if err is not None: + log.info("server listen error: " + err, None) + if _is_retryable_listen_error(err) and attempts < LISTEN_RETRY_LIMIT: + if not server_close_requested and server is not None: + server_close_requested = True + server.close() + start_server() + return + finish_error("server listen error: " + err) + return + server_listened = True + start_client() + + def on_server_close(s: ssh.Server, reason: str): + log.info("server closed: " + reason, None) + server_closed = True + maybe_finish() + + def on_session(sess: ssh.ServerSession): + session_opened = True + + def on_auth(sess: ssh.ServerSession, req: ssh.AuthRequest): + if req.method == "password" and req.user == "user" and req.password == "pass": + sess.accept_auth() + else: + sess.reject_auth("invalid credentials") + + def on_channel_open(sess: ssh.ServerSession): + finish_error("unexpected channel open on GC-cleaned client test") + + def on_exec(sess: ssh.ServerSession, ch: ssh.ServerChannel, cmd: str): + finish_error("unexpected exec request on GC-cleaned client test") + + def on_subsystem(sess: ssh.ServerSession, ch: ssh.ServerChannel, name: str): + finish_error("unexpected subsystem request on GC-cleaned client test") + + def on_session_close(sess: ssh.ServerSession, reason: str): + log.info("session closed: " + reason, None) + session_closed = True + if not server_close_requested and server is not None: + server_close_requested = True + server.close() + maybe_finish() + + def on_hostkey(c: ssh.Client, state: str, info: ssh.HostKeyInfo): + c.accept_hostkey() + + def on_client_connect(c: ssh.Client, err: ?str): + if err is not None: + finish_error("client connect error: " + err) + return + client_connected = True + client = None + if not gc_started: + gc_started = True + after 0: drive_gc() + + def on_client_close(c: ssh.Client, reason: str): + log.info("client closed: " + reason, None) + return + + def start_client(): + client = None + client_connected = False + session_opened = False + session_closed = False + gc_started = False + client = ssh.Client( + net.TCPConnectCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + "user", + on_client_connect, + on_client_close, + on_hostkey, + password="pass", + port=u16(port), + known_hosts="/tmp/acton_ssh_test_known_hosts", + ) + + def start_server(): + attempts += 1 + port = _pick_test_port(35000, 25000, attempts, 557) + client = None + server_close_requested = False + server_listened = False + client_connected = False + session_opened = False + session_closed = False + server_closed = False + server = ssh.Server( + net.TCPListenCap(net.TCPCap(net.NetCap(t.env.cap))), + "127.0.0.1", + u16(port), + on_server_listen, + on_server_close, + on_session, + on_auth, + on_channel_open, + on_exec, + on_subsystem, + on_session_close, + ) + + after 0: start_server() + after 20.0: on_timeout() + + +def _test_ssh_client_gc_cleanup(t: testing.EnvT): + """An unreachable client actor should eventually close itself.""" + ClientCleanupTester(t) + + +__modname__ = "test_ssh_server" + +__unit_tests: dict[str, testing.UnitTest] = { +} + +__simple_sync_tests: dict[str, testing.SimpleSyncTest] = { +} + +__sync_tests: dict[str, testing.SyncTest] = { +} + +__async_tests: dict[str, testing.AsyncTest] = { +} + +__env_tests: dict[str, testing.EnvTest] = { + "_test_ServerClientTester": testing.EnvTest(_test_ssh_server_subsystem, "_test_ServerClientTester", "", __modname__), + "_test_ClientCloseFlushTester": testing.EnvTest(_test_ssh_client_close_flush, "_test_ClientCloseFlushTester", "", __modname__), + "_test_ClientSessionCloseFlushTester": testing.EnvTest(_test_ssh_client_session_close_flush, "_test_ClientSessionCloseFlushTester", "", __modname__), + "_test_AuthRejectTester": testing.EnvTest(_test_ssh_auth_reject, "_test_AuthRejectTester", "", __modname__), + "_test_ServerSessionCloseFlushTester": testing.EnvTest(_test_ssh_server_session_close_flush, "_test_ServerSessionCloseFlushTester", "", __modname__), + "_test_ConcurrentChannelTester": testing.EnvTest(_test_ssh_concurrent_channels, "_test_ConcurrentChannelTester", "", __modname__), + "_test_ConcurrentChannelWriteTester": testing.EnvTest(_test_ssh_concurrent_channel_writes, "_test_ConcurrentChannelWriteTester", "", __modname__), + "_test_KeepaliveTrafficTester": testing.EnvTest(_test_ssh_keepalive_traffic, "_test_KeepaliveTrafficTester", "", __modname__), + "_test_CrossHandleOwnershipTester": testing.EnvTest(_test_ssh_cross_handle_ownership, "_test_CrossHandleOwnershipTester", "", __modname__), + "_test_UnsupportedRequestTester": testing.EnvTest(_test_ssh_unsupported_requests, "_test_UnsupportedRequestTester", "", __modname__), + "_test_RunCommandTimeoutTester": testing.EnvTest(_test_ssh_runcommand_timeout, "_test_RunCommandTimeoutTester", "", __modname__), + "_test_RunCommandExitStatusTester": testing.EnvTest(_test_ssh_runcommand_exit_status, "_test_RunCommandExitStatusTester", "", __modname__), + "_test_ConcurrentRunCommandTester": testing.EnvTest(_test_ssh_concurrent_runcommand, "_test_ConcurrentRunCommandTester", "", __modname__), + "_test_ConcurrentRunCommandRejectTester": testing.EnvTest(_test_ssh_concurrent_runcommand_reject, "_test_ConcurrentRunCommandRejectTester", "", __modname__), + "_test_AuthTimeoutTester": testing.EnvTest(_test_ssh_auth_timeout, "_test_AuthTimeoutTester", "", __modname__), + "_test_HostkeyWaitDisconnectTester": testing.EnvTest(_test_ssh_hostkey_wait_disconnect, "_test_HostkeyWaitDisconnectTester", "", __modname__), + "_test_CloseCallbackReentryTester": testing.EnvTest(_test_ssh_close_callback_reentry, "_test_CloseCallbackReentryTester", "", __modname__), + "_test_KeyExchangeTimeoutTester": testing.EnvTest(_test_ssh_key_exchange_timeout, "_test_KeyExchangeTimeoutTester", "", __modname__), + "_test_ServerSurvivesGarbageTester": testing.EnvTest(_test_ssh_server_survives_garbage, "_test_ServerSurvivesGarbageTester", "", __modname__), + "_test_RapidDisconnectChurnTester": testing.EnvTest(_test_ssh_rapid_disconnect_churn, "_test_RapidDisconnectChurnTester", "", __modname__), + "_test_CloseRejectsLateWritesTester": testing.EnvTest(_test_ssh_close_rejects_late_writes, "_test_CloseRejectsLateWritesTester", "", __modname__), + "_test_SessionLimitTester": testing.EnvTest(_test_ssh_session_limit, "_test_SessionLimitTester", "", __modname__), + "_test_ChannelLimitTester": testing.EnvTest(_test_ssh_channel_limit, "_test_ChannelLimitTester", "", __modname__), + "_test_ServerCleanupTester": testing.EnvTest(_test_ssh_server_gc_cleanup, "_test_ServerCleanupTester", "", __modname__), + "_test_ClientCleanupTester": testing.EnvTest(_test_ssh_client_gc_cleanup, "_test_ClientCleanupTester", "", __modname__), +} + + +actor __test_main(env): + testing.test_runner(env, __unit_tests, __simple_sync_tests, __sync_tests, __async_tests, __env_tests)