From 0591b5d16b2e22e2aeac3fc8ed6a213110342b27 Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Thu, 15 Jan 2026 15:45:50 +0100 Subject: [PATCH 1/9] feat(socket): initial scaffold for client/server ConnManager --- Cargo.toml | 1 + msg-socket/src/connection/backoff.rs | 6 +- msg-socket/src/req/conn_manager.rs | 224 ++++++++++++++++++++++++++- msg-socket/src/req/driver.rs | 15 +- msg-socket/src/req/mod.rs | 8 +- 5 files changed, 243 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9284321e..93d807d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -173,6 +173,7 @@ unused_rounding = "warn" use_self = "warn" useless_let_if_seq = "warn" zero_sized_map_values = "warn" +default_trait_access = "warn" # These are nursery lints which have findings. Allow them for now. Some are not # quite mature enough for use in our codebase and some we don't really want. diff --git a/msg-socket/src/connection/backoff.rs b/msg-socket/src/connection/backoff.rs index fbca13e3..002624dd 100644 --- a/msg-socket/src/connection/backoff.rs +++ b/msg-socket/src/connection/backoff.rs @@ -6,7 +6,7 @@ use std::{ }; use tokio::time::sleep; -use crate::ConnOptions; +use crate::ClientOptions; /// Helper trait alias for backoff streams. /// We define any stream that yields `Duration`s as a backoff @@ -41,8 +41,8 @@ impl ExponentialBackoff { } } -impl From<&ConnOptions> for ExponentialBackoff { - fn from(options: &ConnOptions) -> Self { +impl From<&ClientOptions> for ExponentialBackoff { + fn from(options: &ClientOptions) -> Self { Self::new(options.backoff_duration, options.retry_attempts) } } diff --git a/msg-socket/src/req/conn_manager.rs b/msg-socket/src/req/conn_manager.rs index 1a11197b..70d960f1 100644 --- a/msg-socket/src/req/conn_manager.rs +++ b/msg-socket/src/req/conn_manager.rs @@ -5,13 +5,14 @@ use std::{ task::{Context, Poll}, }; +use arc_swap::ArcSwap; use bytes::Bytes; use futures::{Future, FutureExt, SinkExt, StreamExt}; use msg_common::span::{EnterSpan as _, WithSpan}; use tokio_util::codec::Framed; use tracing::Instrument; -use crate::{ConnOptions, ConnectionState, ExponentialBackoff}; +use crate::{ClientOptions, ConnectionState, ExponentialBackoff}; use msg_transport::{Address, MeteredIo, Transport}; use msg_wire::{auth, reqrep}; @@ -33,10 +34,227 @@ pub(crate) type Conn = Framed, reqrep::Codec>; /// A connection controller that manages the connection to a server with an exponential backoff. pub(crate) type ConnCtl = ConnectionState, ExponentialBackoff, A>; +/// A connection manager for managing client OR server connections. +/// The type parameter `S` contains the connection state, including its "side" (client / server). +pub(crate) struct ConnectionManager +where + T: Transport, + A: Address, +{ + /// The connection state, including its "side" (client / server). + state: S, + /// The transport used for the connection. + transport: T, + /// Transport stats for metering IO. + transport_stats: Arc>, + + /// Connection manager tracing span. + span: tracing::Span, +} + +/// A client connection to a remote server. +pub(crate) struct ClientConnection +where + T: Transport, + A: Address, +{ + /// Options for the connection manager. + options: ClientOptions, + /// The address of the remote. + addr: A, + /// The connection task which handles the connection to the server. + conn_task: Option>>, + /// The transport controller, wrapped in a [`ConnectionState`] for backoff. + /// The [`Framed`] object can send and receive messages from the socket. + conn_ctl: ConnCtl, +} + +impl ClientConnection +where + T: Transport, + A: Address, +{ + /// Reset the connection state to inactive, so that it will be re-tried. + /// + /// This is done when the connection is closed or an error occurs. + #[inline] + pub(crate) fn reset_connection(&mut self) { + self.conn_ctl = ConnectionState::Inactive { + addr: self.addr.clone(), + backoff: ExponentialBackoff::from(&self.options), + }; + } + + /// Returns a mutable reference to the connection channel if it is active. + #[inline] + pub fn active_connection(&mut self) -> Option<&mut Conn> { + if let ConnectionState::Active { ref mut channel } = self.conn_ctl { + Some(channel) + } else { + None + } + } +} + +/// A local server connection. Manages the connection lifecycle: +/// - Accepting incoming connections. +/// - Handling established connections. +pub(crate) struct ServerConnection +where + T: Transport, + A: Address, +{ + /// The local address. + addr: A, + /// The accept task which handles accepting an incoming connection. + accept_task: Option>, + /// The inbound connection. + conn: Conn, +} + +impl ConnectionManager> +where + T: Transport, + A: Address, +{ + pub(crate) fn new( + options: ClientOptions, + transport: T, + addr: A, + conn_ctl: ConnCtl, + transport_stats: Arc>, + span: tracing::Span, + ) -> Self { + let conn = ClientConnection { options, addr, conn_task: None, conn_ctl }; + + Self { state: conn, transport, transport_stats, span } + } + + /// Start the connection task to the server, handling authentication if necessary. + /// The result will be polled by the driver and re-tried according to the backoff policy. + fn try_connect(&mut self) { + let connect = self.transport.connect(self.state.addr.clone()); + let token = self.state.options.auth_token.clone(); + + let task = async move { + let io = connect.await?; + + let Some(token) = token else { + return Ok(io); + }; + + authentication_handshake::(io, token).await + } + .in_current_span(); + + // FIX: coercion to BoxFuture for [`SpanExt::with_current_span`] + self.state.conn_task = Some(WithSpan::current(Box::pin(task))); + } + + /// Reset the connection state to inactive, so that it will be re-tried. + /// + /// This is done when the connection is closed or an error occurs. + #[inline] + pub(crate) fn reset_connection(&mut self) { + self.state.reset_connection(); + } + + /// Poll connection management logic: connection task, backoff, and retry logic. + /// Loops until the connection is active, then returns a mutable reference to the channel. + /// + /// Note: this is not a `Future` impl because we want to return a reference; doing it in + /// a `Future` would require lifetime headaches or unsafe code. + /// + /// Returns: + /// * `Poll::Ready(Some(&mut channel))` if the connection is active + /// * `Poll::Ready(None)` if we should terminate (max retries exceeded) + /// * `Poll::Pending` if we need to wait for backoff + #[allow(clippy::type_complexity)] + pub(crate) fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + // Poll the active connection task, if any + if let Some(ref mut conn_task) = self.state.conn_task { + if let Poll::Ready(result) = conn_task.poll_unpin(cx).enter() { + // As soon as the connection task finishes, set it to `None`. + // - If it was successful, set the connection to active + // - If it failed, it will be re-tried until the backoff limit is reached. + self.state.conn_task = None; + + match result.inner { + Ok(io) => { + tracing::info!("connected"); + + let metered = MeteredIo::new(io, self.transport_stats.clone()); + let framed = Framed::new(metered, reqrep::Codec::new()); + self.state.conn_ctl = ConnectionState::Active { channel: framed }; + } + Err(e) => { + tracing::error!(?e, "failed to connect"); + } + } + } + } + + // If the connection is inactive, try to connect to the server or poll the backoff + // timer if we're already trying to connect. + if let ConnectionState::Inactive { backoff, .. } = &mut self.state.conn_ctl { + let Poll::Ready(item) = backoff.poll_next_unpin(cx) else { + return Poll::Pending; + }; + + let _span = tracing::info_span!(parent: &self.span, "connect").entered(); + + if let Some(duration) = item { + if self.state.conn_task.is_none() { + tracing::debug!(backoff = ?duration, "trying connection"); + self.try_connect(); + } else { + tracing::debug!( + backoff = ?duration, + "not retrying as there is already a connection task" + ); + } + } else { + tracing::error!("exceeded maximum number of retries, terminating connection"); + return Poll::Ready(None); + } + } + + if let ConnectionState::Active { ref mut channel } = self.state.conn_ctl { + return Poll::Ready(Some(channel)); + } + } + } +} + +pub struct ServerOptions {} + +impl ConnectionManager> +where + T: Transport, + A: Address, +{ + pub(crate) fn new( + options: ServerOptions, + transport: T, + addr: A, + conn: Conn, + transport_stats: Arc>, + span: tracing::Span, + ) -> Self { + let conn = ServerConnection { addr, accept_task: None, conn }; + + Self { state: conn, transport, transport_stats, span } + } +} + /// Manages the connection lifecycle: connecting, reconnecting, and maintaining the connection. pub(crate) struct ConnManager, A: Address> { /// Options for the connection manager. - options: ConnOptions, + options: ClientOptions, /// The connection task which handles the connection to the server. conn_task: Option>>, /// The transport controller, wrapped in a [`ConnectionState`] for backoff. @@ -89,7 +307,7 @@ where A: Address, { pub(crate) fn new( - options: ConnOptions, + options: ClientOptions, transport: T, addr: A, conn_ctl: ConnCtl, diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index a87bd7a6..4da8b482 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -27,9 +27,22 @@ use msg_wire::{ reqrep, }; +/// Type state for a client connection. +struct ClientConnection +where + T: Transport, + A: Address, +{ + conn_manager: ConnManager, +} + /// The request socket driver. Endless future that drives /// the socket forward. -pub(crate) struct ReqDriver, A: Address> { +pub(crate) struct ReqDriver +where + T: Transport, + A: Address, +{ /// Options shared with the socket. pub(crate) options: Arc, /// State shared with the socket. diff --git a/msg-socket/src/req/mod.rs b/msg-socket/src/req/mod.rs index c70b9eaa..efeb76dd 100644 --- a/msg-socket/src/req/mod.rs +++ b/msg-socket/src/req/mod.rs @@ -68,7 +68,7 @@ impl SendCommand { /// Options for the connection manager. #[derive(Debug, Clone)] -pub struct ConnOptions { +pub struct ClientOptions { /// Optional authentication token. pub auth_token: Option, /// The backoff duration for the underlying transport on reconnections. @@ -77,7 +77,7 @@ pub struct ConnOptions { pub retry_attempts: Option, } -impl Default for ConnOptions { +impl Default for ClientOptions { fn default() -> Self { Self { auth_token: None, @@ -96,7 +96,7 @@ impl Default for ConnOptions { #[derive(Debug, Clone)] pub struct ReqOptions { /// Options for the connection manager. - pub conn: ConnOptions, + pub conn: ClientOptions, /// Timeout duration for requests. pub timeout: Duration, /// Wether to block on initial connection to the target. @@ -211,7 +211,7 @@ impl ReqOptions { impl Default for ReqOptions { fn default() -> Self { Self { - conn: ConnOptions::default(), + conn: ClientOptions::default(), timeout: Duration::from_secs(5), blocking_connect: false, min_compress_size: 8192, From 5551b0c46a8579107dc3f37a4b463f082a6c3fb4 Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Thu, 15 Jan 2026 16:32:55 +0100 Subject: [PATCH 2/9] feat(socket): wip integration --- msg-socket/src/req/conn_manager.rs | 12 ++++++ msg-socket/src/req/mod.rs | 12 +++--- msg-socket/src/req/socket.rs | 65 +++++++++++++++++++++++++++--- 3 files changed, 78 insertions(+), 11 deletions(-) diff --git a/msg-socket/src/req/conn_manager.rs b/msg-socket/src/req/conn_manager.rs index 70d960f1..2551ec64 100644 --- a/msg-socket/src/req/conn_manager.rs +++ b/msg-socket/src/req/conn_manager.rs @@ -52,6 +52,18 @@ where span: tracing::Span, } +impl ConnectionManager +where + T: Transport, + A: Address, +{ + /// Set the connection manager tracing span. + pub(crate) fn with_span(mut self, span: tracing::Span) -> Self { + self.span = span; + self + } +} + /// A client connection to a remote server. pub(crate) struct ClientConnection where diff --git a/msg-socket/src/req/mod.rs b/msg-socket/src/req/mod.rs index efeb76dd..263558f7 100644 --- a/msg-socket/src/req/mod.rs +++ b/msg-socket/src/req/mod.rs @@ -95,8 +95,8 @@ impl Default for ClientOptions { /// The request socket options. #[derive(Debug, Clone)] pub struct ReqOptions { - /// Options for the connection manager. - pub conn: ClientOptions, + /// Client options for the connection manager. + pub client: ClientOptions, /// Timeout duration for requests. pub timeout: Duration, /// Wether to block on initial connection to the target. @@ -151,7 +151,7 @@ impl ReqOptions { impl ReqOptions { /// Sets the authentication token for the socket. pub fn with_auth_token(mut self, auth_token: Bytes) -> Self { - self.conn.auth_token = Some(auth_token); + self.client.auth_token = Some(auth_token); self } @@ -169,7 +169,7 @@ impl ReqOptions { /// Sets the backoff duration for the socket. pub fn with_backoff_duration(mut self, backoff_duration: Duration) -> Self { - self.conn.backoff_duration = backoff_duration; + self.client.backoff_duration = backoff_duration; self } @@ -177,7 +177,7 @@ impl ReqOptions { /// /// If `None`, all connections will be retried indefinitely. pub fn with_retry_attempts(mut self, retry_attempts: usize) -> Self { - self.conn.retry_attempts = Some(retry_attempts); + self.client.retry_attempts = Some(retry_attempts); self } @@ -211,7 +211,7 @@ impl ReqOptions { impl Default for ReqOptions { fn default() -> Self { Self { - conn: ClientOptions::default(), + client: ClientOptions::default(), timeout: Duration::from_secs(5), blocking_connect: false, min_compress_size: 8192, diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 98f54462..f461dd4c 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -23,7 +23,7 @@ use crate::{ ConnectionState, DRIVER_ID, ExponentialBackoff, ReqMessage, SendCommand, req::{ SocketState, - conn_manager::{ConnCtl, ConnManager}, + conn_manager::{ClientConnection, ConnCtl, ConnManager, ConnectionManager}, driver::ReqDriver, stats::ReqStats, }, @@ -69,7 +69,7 @@ where // by the backend task as soon as the driver is spawned. let conn_state = ConnectionState::Inactive { addr, - backoff: ExponentialBackoff::from(&self.options.conn), + backoff: ExponentialBackoff::from(&self.options.client), }; self.spawn_driver(addr, transport, conn_state) @@ -158,15 +158,70 @@ where // by the backend task as soon as the driver is spawned. ConnectionState::Inactive { addr: endpoint.clone(), - backoff: ExponentialBackoff::from(&self.options.conn), + backoff: ExponentialBackoff::from(&self.options.client), } }; - self.spawn_driver(endpoint, transport, conn_state); + // Initialize client-side connection manager + let conn_manager = ConnectionManager::>::new( + self.options.client.clone(), + transport, + endpoint, + conn_state, + Arc::clone(&self.state.transport_stats), + tracing::Span::none(), + ); + + // Spawn the driver task + self.spawn(conn_manager); Ok(()) } + fn spawn(&mut self, mut conn_manager: ConnectionManager) { + // Initialize communication channels + let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE); + + let timeout_check_interval = tokio::time::interval(self.options.timeout / 10); + + // TODO: we should limit the amount of active outgoing requests, and that should be the + // capacity. If we do this, we'll never have to re-allocate. + let pending_requests = FxHashMap::default(); + + let id = DRIVER_ID.fetch_add(1, Ordering::Relaxed); + let span = tracing::info_span!(parent: None, "req_driver", id = format!("req-{}", id), addr = ?endpoint); + + // Set driver span + conn_manager = conn_manager.with_span(span); + + let linger_timer = self.options.write_buffer_linger.map(|duration| { + let mut timer = tokio::time::interval(duration); + timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + timer + }); + + // Create the socket backend + let driver: ReqDriver = ReqDriver { + options: Arc::clone(&self.options), + socket_state: self.state.clone(), + id_counter: 0, + from_socket, + conn_manager, + linger_timer, + pending_requests, + timeout_check_interval, + egress_queue: Default::default(), + compressor: self.compressor.clone(), + id, + span, + }; + + // Spawn the backend task + tokio::spawn(driver); + + self.to_driver = Some(to_driver); + } + /// Internal method to initialize and spawn the driver. fn spawn_driver(&mut self, endpoint: A, transport: T, conn_ctl: ConnCtl) { // Initialize communication channels @@ -189,7 +244,7 @@ where // Create connection manager let conn_manager = ConnManager::new( - self.options.conn.clone(), + self.options.client.clone(), transport, endpoint, conn_ctl, From f85bf3d302044d1ed81357e5a694e64f1e2b530d Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Fri, 16 Jan 2026 12:10:36 +0100 Subject: [PATCH 3/9] feat(socket): progress on new connection manager --- msg-socket/src/req/conn_manager.rs | 356 +++++++++++------------------ msg-socket/src/req/driver.rs | 23 +- msg-socket/src/req/socket.rs | 83 ++----- 3 files changed, 171 insertions(+), 291 deletions(-) diff --git a/msg-socket/src/req/conn_manager.rs b/msg-socket/src/req/conn_manager.rs index 2551ec64..eeaacecd 100644 --- a/msg-socket/src/req/conn_manager.rs +++ b/msg-socket/src/req/conn_manager.rs @@ -34,6 +34,28 @@ pub(crate) type Conn = Framed, reqrep::Codec>; /// A connection controller that manages the connection to a server with an exponential backoff. pub(crate) type ConnCtl = ConnectionState, ExponentialBackoff, A>; +/// Trait for interacting with the connection, regardless of its "side" (client or server). +pub(crate) trait ConnectionController +where + T: Transport, + A: Address, +{ + /// Polls the connection logic, and returns a mutable reference to the connection if it's ready. + fn poll( + &mut self, + transport: &mut T, + stats: &Arc>, + span: &tracing::Span, + cx: &mut Context<'_>, + ) -> Poll>>; + + /// Resets the connection controller. + fn reset(&mut self); + + /// Returns a mutable reference to the active connection, if it exists. + fn active_connection(&mut self) -> Option<&mut Conn>; +} + /// A connection manager for managing client OR server connections. /// The type parameter `S` contains the connection state, including its "side" (client / server). pub(crate) struct ConnectionManager @@ -81,25 +103,96 @@ where conn_ctl: ConnCtl, } -impl ClientConnection +impl ConnectionController for ClientConnection where T: Transport, A: Address, { + /// Poll connection management logic: connection task, backoff, and retry logic. + /// Loops until the connection is active, then returns a mutable reference to the channel. + /// + /// Note: this is not a `Future` impl because we want to return a reference; doing it in + /// a `Future` would require lifetime headaches or unsafe code. + /// + /// Returns: + /// * `Poll::Ready(Some(&mut channel))` if the connection is active + /// * `Poll::Ready(None)` if we should terminate (max retries exceeded) + /// * `Poll::Pending` if we need to wait for backoff + fn poll( + &mut self, + transport: &mut T, + stats: &Arc>, + span: &tracing::Span, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + // Poll the active connection task, if any + if let Some(ref mut conn_task) = self.conn_task { + if let Poll::Ready(result) = conn_task.poll_unpin(cx).enter() { + // As soon as the connection task finishes, set it to `None`. + // - If it was successful, set the connection to active + // - If it failed, it will be re-tried until the backoff limit is reached. + self.conn_task = None; + + match result.inner { + Ok(io) => { + tracing::info!("connected"); + + let metered = MeteredIo::new(io, stats.clone()); + let framed = Framed::new(metered, reqrep::Codec::new()); + self.conn_ctl = ConnectionState::Active { channel: framed }; + } + Err(e) => { + tracing::error!(?e, "failed to connect"); + } + } + } + } + + // If the connection is inactive, try to connect to the server or poll the backoff + // timer if we're already trying to connect. + if let ConnectionState::Inactive { backoff, .. } = &mut self.conn_ctl { + let Poll::Ready(item) = backoff.poll_next_unpin(cx) else { + return Poll::Pending; + }; + + let _span = tracing::info_span!(parent: span, "connect").entered(); + + if let Some(duration) = item { + if self.conn_task.is_none() { + tracing::debug!(backoff = ?duration, "trying connection"); + self.try_connect(transport); + } else { + tracing::debug!( + backoff = ?duration, + "not retrying as there is already a connection task" + ); + } + } else { + tracing::error!("exceeded maximum number of retries, terminating connection"); + return Poll::Ready(None); + } + } + + if let ConnectionState::Active { ref mut channel } = self.conn_ctl { + return Poll::Ready(Some(channel)); + } + } + } + /// Reset the connection state to inactive, so that it will be re-tried. /// /// This is done when the connection is closed or an error occurs. #[inline] - pub(crate) fn reset_connection(&mut self) { + fn reset(&mut self) { self.conn_ctl = ConnectionState::Inactive { addr: self.addr.clone(), backoff: ExponentialBackoff::from(&self.options), }; } - /// Returns a mutable reference to the connection channel if it is active. - #[inline] - pub fn active_connection(&mut self) -> Option<&mut Conn> { + /// Returns a mutable reference to the active connection, if it exists. + fn active_connection(&mut self) -> Option<&mut Conn> { if let ConnectionState::Active { ref mut channel } = self.conn_ctl { Some(channel) } else { @@ -108,6 +201,33 @@ where } } +impl ClientConnection +where + T: Transport, + A: Address, +{ + /// Start the connection task to the server, handling authentication if necessary. + /// The result will be polled by the driver and re-tried according to the backoff policy. + fn try_connect(&mut self, transport: &mut T) { + let connect = transport.connect(self.addr.clone()); + let token = self.options.auth_token.clone(); + + let task = async move { + let io = connect.await?; + + let Some(token) = token else { + return Ok(io); + }; + + authentication_handshake::(io, token).await + } + .in_current_span(); + + // FIX: coercion to BoxFuture for [`SpanExt::with_current_span`] + self.conn_task = Some(WithSpan::current(Box::pin(task))); + } +} + /// A local server connection. Manages the connection lifecycle: /// - Accepting incoming connections. /// - Handling established connections. @@ -162,83 +282,34 @@ where // FIX: coercion to BoxFuture for [`SpanExt::with_current_span`] self.state.conn_task = Some(WithSpan::current(Box::pin(task))); } +} +impl ConnectionManager +where + T: Transport, + A: Address, + C: ConnectionController, +{ /// Reset the connection state to inactive, so that it will be re-tried. /// /// This is done when the connection is closed or an error occurs. #[inline] pub(crate) fn reset_connection(&mut self) { - self.state.reset_connection(); + self.state.reset(); } - /// Poll connection management logic: connection task, backoff, and retry logic. - /// Loops until the connection is active, then returns a mutable reference to the channel. - /// - /// Note: this is not a `Future` impl because we want to return a reference; doing it in - /// a `Future` would require lifetime headaches or unsafe code. - /// - /// Returns: - /// * `Poll::Ready(Some(&mut channel))` if the connection is active - /// * `Poll::Ready(None)` if we should terminate (max retries exceeded) - /// * `Poll::Pending` if we need to wait for backoff + /// Poll the connection controller. #[allow(clippy::type_complexity)] pub(crate) fn poll( &mut self, cx: &mut Context<'_>, ) -> Poll>> { - loop { - // Poll the active connection task, if any - if let Some(ref mut conn_task) = self.state.conn_task { - if let Poll::Ready(result) = conn_task.poll_unpin(cx).enter() { - // As soon as the connection task finishes, set it to `None`. - // - If it was successful, set the connection to active - // - If it failed, it will be re-tried until the backoff limit is reached. - self.state.conn_task = None; - - match result.inner { - Ok(io) => { - tracing::info!("connected"); - - let metered = MeteredIo::new(io, self.transport_stats.clone()); - let framed = Framed::new(metered, reqrep::Codec::new()); - self.state.conn_ctl = ConnectionState::Active { channel: framed }; - } - Err(e) => { - tracing::error!(?e, "failed to connect"); - } - } - } - } - - // If the connection is inactive, try to connect to the server or poll the backoff - // timer if we're already trying to connect. - if let ConnectionState::Inactive { backoff, .. } = &mut self.state.conn_ctl { - let Poll::Ready(item) = backoff.poll_next_unpin(cx) else { - return Poll::Pending; - }; - - let _span = tracing::info_span!(parent: &self.span, "connect").entered(); - - if let Some(duration) = item { - if self.state.conn_task.is_none() { - tracing::debug!(backoff = ?duration, "trying connection"); - self.try_connect(); - } else { - tracing::debug!( - backoff = ?duration, - "not retrying as there is already a connection task" - ); - } - } else { - tracing::error!("exceeded maximum number of retries, terminating connection"); - return Poll::Ready(None); - } - } + self.state.poll(&mut self.transport, &self.transport_stats, &self.span, cx) + } - if let ConnectionState::Active { ref mut channel } = self.state.conn_ctl { - return Poll::Ready(Some(channel)); - } - } + /// Returns a mutable reference to the active connection, if it exists. + pub(crate) fn active_connection(&mut self) -> Option<&mut Conn> { + self.state.active_connection() } } @@ -263,26 +334,6 @@ where } } -/// Manages the connection lifecycle: connecting, reconnecting, and maintaining the connection. -pub(crate) struct ConnManager, A: Address> { - /// Options for the connection manager. - options: ClientOptions, - /// The connection task which handles the connection to the server. - conn_task: Option>>, - /// The transport controller, wrapped in a [`ConnectionState`] for backoff. - /// The [`Framed`] object can send and receive messages from the socket. - conn_ctl: ConnCtl, - /// The transport for this socket. - transport: T, - /// The address of the server. - addr: A, - /// Transport stats for metering IO. - transport_stats: Arc>, - - /// A span to use for connection-related logging. - span: tracing::Span, -} - /// Perform the authentication handshake with the server. #[tracing::instrument(skip_all, "auth", fields(token = ?token))] async fn authentication_handshake(mut io: T::Io, token: Bytes) -> Result @@ -312,132 +363,3 @@ where Err(e) => Err(io::Error::new(io::ErrorKind::PermissionDenied, e).into()), } } - -impl ConnManager -where - T: Transport, - A: Address, -{ - pub(crate) fn new( - options: ClientOptions, - transport: T, - addr: A, - conn_ctl: ConnCtl, - transport_stats: Arc>, - span: tracing::Span, - ) -> Self { - Self { options, conn_task: None, conn_ctl, transport, addr, transport_stats, span } - } - - /// Start the connection task to the server, handling authentication if necessary. - /// The result will be polled by the driver and re-tried according to the backoff policy. - fn try_connect(&mut self) { - let connect = self.transport.connect(self.addr.clone()); - let token = self.options.auth_token.clone(); - - let task = async move { - let io = connect.await?; - - let Some(token) = token else { - return Ok(io); - }; - - authentication_handshake::(io, token).await - } - .in_current_span(); - - // FIX: coercion to BoxFuture for [`SpanExt::with_current_span`] - self.conn_task = Some(WithSpan::current(Box::pin(task))); - } - - /// Reset the connection state to inactive, so that it will be re-tried. - /// - /// This is done when the connection is closed or an error occurs. - #[inline] - pub(crate) fn reset_connection(&mut self) { - self.conn_ctl = ConnectionState::Inactive { - addr: self.addr.clone(), - backoff: ExponentialBackoff::from(&self.options), - }; - } - - /// Returns a mutable reference to the connection channel if it is active. - #[inline] - pub(crate) fn active_connection(&mut self) -> Option<&mut Conn> { - if let ConnectionState::Active { ref mut channel } = self.conn_ctl { - Some(channel) - } else { - None - } - } - - /// Poll connection management logic: connection task, backoff, and retry logic. - /// Loops until the connection is active, then returns a mutable reference to the channel. - /// - /// Note: this is not a `Future` impl because we want to return a reference; doing it in - /// a `Future` would require lifetime headaches or unsafe code. - /// - /// Returns: - /// * `Poll::Ready(Some(&mut channel))` if the connection is active - /// * `Poll::Ready(None)` if we should terminate (max retries exceeded) - /// * `Poll::Pending` if we need to wait for backoff - #[allow(clippy::type_complexity)] - pub(crate) fn poll( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>> { - loop { - // Poll the active connection task, if any - if let Some(ref mut conn_task) = self.conn_task { - if let Poll::Ready(result) = conn_task.poll_unpin(cx).enter() { - // As soon as the connection task finishes, set it to `None`. - // - If it was successful, set the connection to active - // - If it failed, it will be re-tried until the backoff limit is reached. - self.conn_task = None; - - match result.inner { - Ok(io) => { - tracing::info!("connected"); - - let metered = MeteredIo::new(io, self.transport_stats.clone()); - let framed = Framed::new(metered, reqrep::Codec::new()); - self.conn_ctl = ConnectionState::Active { channel: framed }; - } - Err(e) => { - tracing::error!(?e, "failed to connect"); - } - } - } - } - - // If the connection is inactive, try to connect to the server or poll the backoff - // timer if we're already trying to connect. - if let ConnectionState::Inactive { backoff, .. } = &mut self.conn_ctl { - let Poll::Ready(item) = backoff.poll_next_unpin(cx) else { - return Poll::Pending; - }; - - let _span = tracing::info_span!(parent: &self.span, "connect").entered(); - - if let Some(duration) = item { - if self.conn_task.is_none() { - tracing::debug!(backoff = ?duration, "trying connection"); - self.try_connect(); - } else { - tracing::debug!( - backoff = ?duration, - "not retrying as there is already a connection task" - ); - } - } else { - tracing::error!("exceeded maximum number of retries, terminating connection"); - return Poll::Ready(None); - } - } - - if let ConnectionState::Active { ref mut channel } = self.conn_ctl { - return Poll::Ready(Some(channel)); - } - } - } -} diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index 4da8b482..74dd4878 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -17,7 +17,10 @@ use tokio::{ use super::{ReqError, ReqOptions}; use crate::{ SendCommand, - req::{SocketState, conn_manager::ConnManager}, + req::{ + SocketState, + conn_manager::{ConnectionController, ConnectionManager}, + }, }; use msg_common::span::{EnterSpan as _, SpanExt as _, WithSpan}; @@ -27,18 +30,9 @@ use msg_wire::{ reqrep, }; -/// Type state for a client connection. -struct ClientConnection -where - T: Transport, - A: Address, -{ - conn_manager: ConnManager, -} - /// The request socket driver. Endless future that drives /// the socket forward. -pub(crate) struct ReqDriver +pub(crate) struct ReqDriver where T: Transport, A: Address, @@ -52,7 +46,7 @@ where /// Commands from the socket. pub(crate) from_socket: mpsc::Receiver, /// Connection manager that handles connection lifecycle. - pub(crate) conn_manager: ConnManager, + pub(crate) conn_manager: ConnectionManager, /// The timer for the write buffer linger. pub(crate) linger_timer: Option, /// The outgoing message queue. @@ -79,7 +73,7 @@ pub(crate) struct PendingRequest { sender: oneshot::Sender>, } -impl ReqDriver +impl ReqDriver where T: Transport, A: Address, @@ -178,10 +172,11 @@ where } } -impl Future for ReqDriver +impl Future for ReqDriver where T: Transport, A: Address, + S: ConnectionController + Unpin, { type Output = (); diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index f461dd4c..270964f9 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -23,7 +23,7 @@ use crate::{ ConnectionState, DRIVER_ID, ExponentialBackoff, ReqMessage, SendCommand, req::{ SocketState, - conn_manager::{ClientConnection, ConnCtl, ConnManager, ConnectionManager}, + conn_manager::{ClientConnection, ConnectionController, ConnectionManager}, driver::ReqDriver, stats::ReqStats, }, @@ -72,7 +72,17 @@ where backoff: ExponentialBackoff::from(&self.options.client), }; - self.spawn_driver(addr, transport, conn_state) + // Initialize client-side connection manager + let conn_manager = ConnectionManager::>::new( + self.options.client.clone(), + transport, + addr.clone(), + conn_state, + Arc::clone(&self.state.transport_stats), + tracing::Span::none(), + ); + + self.spawn(addr, conn_manager) } } @@ -166,19 +176,23 @@ where let conn_manager = ConnectionManager::>::new( self.options.client.clone(), transport, - endpoint, + endpoint.clone(), conn_state, Arc::clone(&self.state.transport_stats), tracing::Span::none(), ); // Spawn the driver task - self.spawn(conn_manager); + self.spawn(endpoint, conn_manager); Ok(()) } - fn spawn(&mut self, mut conn_manager: ConnectionManager) { + fn spawn + Send + Unpin + 'static>( + &mut self, + addr: A, + mut conn_manager: ConnectionManager, + ) { // Initialize communication channels let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE); @@ -189,10 +203,11 @@ where let pending_requests = FxHashMap::default(); let id = DRIVER_ID.fetch_add(1, Ordering::Relaxed); - let span = tracing::info_span!(parent: None, "req_driver", id = format!("req-{}", id), addr = ?endpoint); + let span = + tracing::info_span!(parent: None, "req_driver", id = format!("req-{}", id), ?addr); // Set driver span - conn_manager = conn_manager.with_span(span); + conn_manager = conn_manager.with_span(span.clone()); let linger_timer = self.options.write_buffer_linger.map(|duration| { let mut timer = tokio::time::interval(duration); @@ -201,59 +216,7 @@ where }); // Create the socket backend - let driver: ReqDriver = ReqDriver { - options: Arc::clone(&self.options), - socket_state: self.state.clone(), - id_counter: 0, - from_socket, - conn_manager, - linger_timer, - pending_requests, - timeout_check_interval, - egress_queue: Default::default(), - compressor: self.compressor.clone(), - id, - span, - }; - - // Spawn the backend task - tokio::spawn(driver); - - self.to_driver = Some(to_driver); - } - - /// Internal method to initialize and spawn the driver. - fn spawn_driver(&mut self, endpoint: A, transport: T, conn_ctl: ConnCtl) { - // Initialize communication channels - let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE); - - let timeout_check_interval = tokio::time::interval(self.options.timeout / 10); - - // TODO: we should limit the amount of active outgoing requests, and that should be the - // capacity. If we do this, we'll never have to re-allocate. - let pending_requests = FxHashMap::default(); - - let id = DRIVER_ID.fetch_add(1, Ordering::Relaxed); - let span = tracing::info_span!(parent: None, "req_driver", id = format!("req-{}", id), addr = ?endpoint); - - let linger_timer = self.options.write_buffer_linger.map(|duration| { - let mut timer = tokio::time::interval(duration); - timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - timer - }); - - // Create connection manager - let conn_manager = ConnManager::new( - self.options.client.clone(), - transport, - endpoint, - conn_ctl, - Arc::clone(&self.state.transport_stats), - span.clone(), - ); - - // Create the socket backend - let driver: ReqDriver = ReqDriver { + let driver: ReqDriver = ReqDriver { options: Arc::clone(&self.options), socket_state: self.state.clone(), id_counter: 0, From 67651d3426d0b8f8f54734020d611eea93ebe01b Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Fri, 16 Jan 2026 12:39:37 +0100 Subject: [PATCH 4/9] feat(socket): add bind to ReqSocket --- msg-socket/src/req/conn_manager.rs | 80 +++++++++++++++++++++--------- msg-socket/src/req/socket.rs | 21 +++++++- 2 files changed, 77 insertions(+), 24 deletions(-) diff --git a/msg-socket/src/req/conn_manager.rs b/msg-socket/src/req/conn_manager.rs index eeaacecd..22d614b1 100644 --- a/msg-socket/src/req/conn_manager.rs +++ b/msg-socket/src/req/conn_manager.rs @@ -49,7 +49,7 @@ where cx: &mut Context<'_>, ) -> Poll>>; - /// Resets the connection controller. + /// Resets the connection controller. Will close any active connections. fn reset(&mut self); /// Returns a mutable reference to the active connection, if it exists. @@ -241,9 +241,43 @@ where /// The accept task which handles accepting an incoming connection. accept_task: Option>, /// The inbound connection. - conn: Conn, + conn: Option>, } +impl ConnectionController for ServerConnection +where + T: Transport, + A: Address, +{ + fn poll( + &mut self, + transport: &mut T, + stats: &Arc>, + span: &tracing::Span, + cx: &mut Context<'_>, + ) -> Poll>> { + todo!() + // If connection is active, return it + // + // If no connection BUT accept task is active, poll it + // + // If no connection AND no accept task, create a new accept task, disable listener? + } + + fn reset(&mut self) { + if let Some(ref mut conn) = self.conn.take() { + // FIXME: This doesn't actually close the underlying connection, it just drops it. + // To actually close it, we'd need to poll the close future. + let _ = conn.close(); + } + } + + fn active_connection(&mut self) -> Option<&mut Conn> { + self.conn.as_mut() + } +} + +// Client-side connection manager implementations. impl ConnectionManager> where T: Transport, @@ -284,6 +318,27 @@ where } } +pub struct ServerOptions {} + +impl ConnectionManager> +where + T: Transport, + A: Address, +{ + pub(crate) fn new( + options: ServerOptions, + transport: T, + addr: A, + transport_stats: Arc>, + span: tracing::Span, + ) -> Self { + let conn = ServerConnection { addr, accept_task: None, conn: None }; + + Self { state: conn, transport, transport_stats, span } + } +} + +// Generic connection manager implementations. impl ConnectionManager where T: Transport, @@ -313,27 +368,6 @@ where } } -pub struct ServerOptions {} - -impl ConnectionManager> -where - T: Transport, - A: Address, -{ - pub(crate) fn new( - options: ServerOptions, - transport: T, - addr: A, - conn: Conn, - transport_stats: Arc>, - span: tracing::Span, - ) -> Self { - let conn = ServerConnection { addr, accept_task: None, conn }; - - Self { state: conn, transport, transport_stats, span } - } -} - /// Perform the authentication handshake with the server. #[tracing::instrument(skip_all, "auth", fields(token = ?token))] async fn authentication_handshake(mut io: T::Io, token: Bytes) -> Result diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 270964f9..eeeb18d4 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -23,7 +23,10 @@ use crate::{ ConnectionState, DRIVER_ID, ExponentialBackoff, ReqMessage, SendCommand, req::{ SocketState, - conn_manager::{ClientConnection, ConnectionController, ConnectionManager}, + conn_manager::{ + ClientConnection, ConnectionController, ConnectionManager, ServerConnection, + ServerOptions, + }, driver::ReqDriver, stats::ReqStats, }, @@ -84,6 +87,22 @@ where self.spawn(addr, conn_manager) } + + pub async fn bind(&mut self, addr: SocketAddr) -> Result<(), ReqError> { + let transport = self.transport.take().expect("Transport has been moved"); + + // Initialize server-side connection manager + let conn_manager = ConnectionManager::>::new( + ServerOptions {}, + transport, + addr.clone(), + Arc::clone(&self.state.transport_stats), + tracing::Span::none(), + ); + + self.spawn(addr, conn_manager); + Ok(()) + } } impl ReqSocket From 6c7d2c27419550390bc8103053a1bb4c8c6f1366 Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Fri, 16 Jan 2026 14:56:02 +0100 Subject: [PATCH 5/9] feat(socket): server side ReqSocket done --- msg-socket/src/req/conn_manager.rs | 87 +++++++++++++++++++----------- msg-socket/src/req/mod.rs | 2 + msg-socket/src/req/socket.rs | 22 +++++--- 3 files changed, 73 insertions(+), 38 deletions(-) diff --git a/msg-socket/src/req/conn_manager.rs b/msg-socket/src/req/conn_manager.rs index 22d614b1..5000554c 100644 --- a/msg-socket/src/req/conn_manager.rs +++ b/msg-socket/src/req/conn_manager.rs @@ -8,13 +8,13 @@ use std::{ use arc_swap::ArcSwap; use bytes::Bytes; use futures::{Future, FutureExt, SinkExt, StreamExt}; -use msg_common::span::{EnterSpan as _, WithSpan}; +use msg_common::span::{EnterSpan as _, SpanExt, WithSpan}; use tokio_util::codec::Framed; use tracing::Instrument; use crate::{ClientOptions, ConnectionState, ExponentialBackoff}; -use msg_transport::{Address, MeteredIo, Transport}; +use msg_transport::{Address, MeteredIo, PeerAddress as _, Transport}; use msg_wire::{auth, reqrep}; /// A connection task that connects to a server and returns the underlying IO object. @@ -236,6 +236,9 @@ where T: Transport, A: Address, { + /// The server options. + #[allow(unused)] + options: ServerOptions, /// The local address. addr: A, /// The accept task which handles accepting an incoming connection. @@ -249,6 +252,10 @@ where T: Transport, A: Address, { + /// Poll the server-side connection controller. This will return: + /// - Poll::Ready(Some(conn)) if a connection is active. + /// - Poll::Pending if no connection is active and the accept task is pending. + /// - Poll::Ready(None) if no connection is active and the accept task is not pending. fn poll( &mut self, transport: &mut T, @@ -256,12 +263,46 @@ where span: &tracing::Span, cx: &mut Context<'_>, ) -> Poll>> { - todo!() - // If connection is active, return it - // - // If no connection BUT accept task is active, poll it - // - // If no connection AND no accept task, create a new accept task, disable listener? + let mut transport = Pin::new(transport); + loop { + // 1. If connection is active, return it + if self.conn.is_some() { + return Poll::Ready(self.conn.as_mut()); + } + + let _span = + tracing::info_span!(parent: span, "accept", local_addr = ?self.addr).entered(); + + // 2. If connection is not active, but we have an accept task, poll it. + if let Some(ref mut accept) = self.accept_task { + if let Poll::Ready(result) = accept.poll_unpin(cx).enter() { + match result.inner { + Ok(io) => { + tracing::debug!(peer_addr = ?io.peer_addr(), "Accepted connection"); + let metered = MeteredIo::new(io, stats.clone()); + let framed = Framed::new(metered, reqrep::Codec::new()); + + self.conn = Some(framed); + return Poll::Ready(self.conn.as_mut()); + } + Err(err) => { + tracing::error!("Accept error: {err:?}"); + self.accept_task = None; + } + } + } + } + + // 3. Create a new accept task + if let Poll::Ready(accept_task) = transport.as_mut().poll_accept(cx) { + self.accept_task = Some(accept_task.with_current_span()); + + // Continue to poll the accept task + continue; + } + + return Poll::Pending; + } } fn reset(&mut self) { @@ -295,27 +336,6 @@ where Self { state: conn, transport, transport_stats, span } } - - /// Start the connection task to the server, handling authentication if necessary. - /// The result will be polled by the driver and re-tried according to the backoff policy. - fn try_connect(&mut self) { - let connect = self.transport.connect(self.state.addr.clone()); - let token = self.state.options.auth_token.clone(); - - let task = async move { - let io = connect.await?; - - let Some(token) = token else { - return Ok(io); - }; - - authentication_handshake::(io, token).await - } - .in_current_span(); - - // FIX: coercion to BoxFuture for [`SpanExt::with_current_span`] - self.state.conn_task = Some(WithSpan::current(Box::pin(task))); - } } pub struct ServerOptions {} @@ -325,6 +345,7 @@ where T: Transport, A: Address, { + /// Create a new server-side connection manager. pub(crate) fn new( options: ServerOptions, transport: T, @@ -332,10 +353,16 @@ where transport_stats: Arc>, span: tracing::Span, ) -> Self { - let conn = ServerConnection { addr, accept_task: None, conn: None }; + debug_assert!(transport.local_addr().is_some(), "Transport must be bound"); + let conn = ServerConnection { options, addr, accept_task: None, conn: None }; Self { state: conn, transport, transport_stats, span } } + + /// Bind the socket to the given address. + pub(crate) async fn bind(&mut self, addr: A) -> Result<(), T::Error> { + self.transport.bind(addr).await + } } // Generic connection manager implementations. diff --git a/msg-socket/src/req/mod.rs b/msg-socket/src/req/mod.rs index 263558f7..d13b4c3b 100644 --- a/msg-socket/src/req/mod.rs +++ b/msg-socket/src/req/mod.rs @@ -45,6 +45,8 @@ pub enum ReqError { NoValidEndpoints, #[error("Failed to connect to the target endpoint: {0:?}")] Connect(Box), + #[error("Failed to bind to the socket address")] + Bind(Box), } /// A command to send a request message and wait for a response. diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index eeeb18d4..8660d384 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -79,7 +79,7 @@ where let conn_manager = ConnectionManager::>::new( self.options.client.clone(), transport, - addr.clone(), + addr, conn_state, Arc::clone(&self.state.transport_stats), tracing::Span::none(), @@ -88,17 +88,23 @@ where self.spawn(addr, conn_manager) } + /// Bind the socket to the given address. pub async fn bind(&mut self, addr: SocketAddr) -> Result<(), ReqError> { let transport = self.transport.take().expect("Transport has been moved"); // Initialize server-side connection manager - let conn_manager = ConnectionManager::>::new( - ServerOptions {}, - transport, - addr.clone(), - Arc::clone(&self.state.transport_stats), - tracing::Span::none(), - ); + let mut conn_manager = + ConnectionManager::>::new( + // TODO: Server options from config + ServerOptions {}, + transport, + addr, + Arc::clone(&self.state.transport_stats), + tracing::Span::none(), + ); + + // Bind the connection manager + conn_manager.bind(addr).await.map_err(|e| ReqError::Bind(e.into()))?; self.spawn(addr, conn_manager); Ok(()) From 40b70811f959fbdb36f4f3251370f30a9afaf295 Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Fri, 16 Jan 2026 14:59:07 +0100 Subject: [PATCH 6/9] chore(socket): clippy --- msg-socket/src/req/conn_manager.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/msg-socket/src/req/conn_manager.rs b/msg-socket/src/req/conn_manager.rs index 5000554c..5960ba85 100644 --- a/msg-socket/src/req/conn_manager.rs +++ b/msg-socket/src/req/conn_manager.rs @@ -41,6 +41,7 @@ where A: Address, { /// Polls the connection logic, and returns a mutable reference to the connection if it's ready. + #[allow(clippy::type_complexity)] fn poll( &mut self, transport: &mut T, @@ -306,10 +307,10 @@ where } fn reset(&mut self) { - if let Some(ref mut conn) = self.conn.take() { + if let Some(ref mut _conn) = self.conn.take() { // FIXME: This doesn't actually close the underlying connection, it just drops it. // To actually close it, we'd need to poll the close future. - let _ = conn.close(); + // let _ = _conn.close().await; } } From 487966a1a8d56e72f312220d3684dfade4d3fcaf Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Fri, 16 Jan 2026 16:21:32 +0100 Subject: [PATCH 7/9] feat(socket): genericize connection manager --- msg-socket/src/req/conn_manager.rs | 81 ++++++++++++++++++------------ msg-socket/src/req/driver.rs | 17 +++---- msg-socket/src/req/socket.rs | 38 ++++++++------ rustfmt.toml | 4 ++ 4 files changed, 82 insertions(+), 58 deletions(-) diff --git a/msg-socket/src/req/conn_manager.rs b/msg-socket/src/req/conn_manager.rs index 5960ba85..6fcffbe1 100644 --- a/msg-socket/src/req/conn_manager.rs +++ b/msg-socket/src/req/conn_manager.rs @@ -9,13 +9,15 @@ use arc_swap::ArcSwap; use bytes::Bytes; use futures::{Future, FutureExt, SinkExt, StreamExt}; use msg_common::span::{EnterSpan as _, SpanExt, WithSpan}; +use msg_transport::{Address, MeteredIo, PeerAddress as _, Transport}; +use msg_wire::auth; use tokio_util::codec::Framed; use tracing::Instrument; use crate::{ClientOptions, ConnectionState, ExponentialBackoff}; -use msg_transport::{Address, MeteredIo, PeerAddress as _, Transport}; -use msg_wire::{auth, reqrep}; +/// Type alias for a factory function that creates a codec. +type CodecFactory = Box C + Send>; /// A connection task that connects to a server and returns the underlying IO object. type ConnTask = Pin> + Send>>; @@ -29,13 +31,13 @@ type ConnTask = Pin> + Send>>; /// /// However, we don't use `poll_ready` here, and instead we flush every time we write a message to /// the framed buffer. -pub(crate) type Conn = Framed, reqrep::Codec>; +pub(crate) type Conn = Framed, C>; /// A connection controller that manages the connection to a server with an exponential backoff. -pub(crate) type ConnCtl = ConnectionState, ExponentialBackoff, A>; +pub(crate) type ConnCtl = ConnectionState, ExponentialBackoff, A>; /// Trait for interacting with the connection, regardless of its "side" (client or server). -pub(crate) trait ConnectionController +pub(crate) trait ConnectionController where T: Transport, A: Address, @@ -47,19 +49,20 @@ where transport: &mut T, stats: &Arc>, span: &tracing::Span, + make_codec: &impl Fn() -> C, cx: &mut Context<'_>, - ) -> Poll>>; + ) -> Poll>>; /// Resets the connection controller. Will close any active connections. fn reset(&mut self); /// Returns a mutable reference to the active connection, if it exists. - fn active_connection(&mut self) -> Option<&mut Conn>; + fn active_connection(&mut self) -> Option<&mut Conn>; } /// A connection manager for managing client OR server connections. /// The type parameter `S` contains the connection state, including its "side" (client / server). -pub(crate) struct ConnectionManager +pub(crate) struct ConnectionManager where T: Transport, A: Address, @@ -70,12 +73,14 @@ where transport: T, /// Transport stats for metering IO. transport_stats: Arc>, + /// Factory function for creating a codec. + make_codec: CodecFactory, /// Connection manager tracing span. span: tracing::Span, } -impl ConnectionManager +impl ConnectionManager where T: Transport, A: Address, @@ -87,8 +92,8 @@ where } } -/// A client connection to a remote server. -pub(crate) struct ClientConnection +/// A client connection to a remote server. Generic over transport, address type and codec. +pub(crate) struct ClientConnection where T: Transport, A: Address, @@ -101,10 +106,10 @@ where conn_task: Option>>, /// The transport controller, wrapped in a [`ConnectionState`] for backoff. /// The [`Framed`] object can send and receive messages from the socket. - conn_ctl: ConnCtl, + conn_ctl: ConnCtl, } -impl ConnectionController for ClientConnection +impl ConnectionController for ClientConnection where T: Transport, A: Address, @@ -124,8 +129,9 @@ where transport: &mut T, stats: &Arc>, span: &tracing::Span, + make_codec: &impl Fn() -> C, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { loop { // Poll the active connection task, if any if let Some(ref mut conn_task) = self.conn_task { @@ -140,7 +146,7 @@ where tracing::info!("connected"); let metered = MeteredIo::new(io, stats.clone()); - let framed = Framed::new(metered, reqrep::Codec::new()); + let framed = Framed::new(metered, make_codec()); self.conn_ctl = ConnectionState::Active { channel: framed }; } Err(e) => { @@ -193,7 +199,7 @@ where } /// Returns a mutable reference to the active connection, if it exists. - fn active_connection(&mut self) -> Option<&mut Conn> { + fn active_connection(&mut self) -> Option<&mut Conn> { if let ConnectionState::Active { ref mut channel } = self.conn_ctl { Some(channel) } else { @@ -202,7 +208,7 @@ where } } -impl ClientConnection +impl ClientConnection where T: Transport, A: Address, @@ -232,7 +238,7 @@ where /// A local server connection. Manages the connection lifecycle: /// - Accepting incoming connections. /// - Handling established connections. -pub(crate) struct ServerConnection +pub(crate) struct ServerConnection where T: Transport, A: Address, @@ -245,10 +251,10 @@ where /// The accept task which handles accepting an incoming connection. accept_task: Option>, /// The inbound connection. - conn: Option>, + conn: Option>, } -impl ConnectionController for ServerConnection +impl ConnectionController for ServerConnection where T: Transport, A: Address, @@ -262,8 +268,9 @@ where transport: &mut T, stats: &Arc>, span: &tracing::Span, + make_codec: &impl Fn() -> C, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { let mut transport = Pin::new(transport); loop { // 1. If connection is active, return it @@ -281,7 +288,7 @@ where Ok(io) => { tracing::debug!(peer_addr = ?io.peer_addr(), "Accepted connection"); let metered = MeteredIo::new(io, stats.clone()); - let framed = Framed::new(metered, reqrep::Codec::new()); + let framed = Framed::new(metered, make_codec()); self.conn = Some(framed); return Poll::Ready(self.conn.as_mut()); @@ -314,13 +321,13 @@ where } } - fn active_connection(&mut self) -> Option<&mut Conn> { + fn active_connection(&mut self) -> Option<&mut Conn> { self.conn.as_mut() } } // Client-side connection manager implementations. -impl ConnectionManager> +impl ConnectionManager, C> where T: Transport, A: Address, @@ -329,19 +336,20 @@ where options: ClientOptions, transport: T, addr: A, - conn_ctl: ConnCtl, + conn_ctl: ConnCtl, transport_stats: Arc>, + make_codec: CodecFactory, span: tracing::Span, ) -> Self { let conn = ClientConnection { options, addr, conn_task: None, conn_ctl }; - Self { state: conn, transport, transport_stats, span } + Self { state: conn, transport, transport_stats, make_codec, span } } } pub struct ServerOptions {} -impl ConnectionManager> +impl ConnectionManager, C> where T: Transport, A: Address, @@ -352,12 +360,13 @@ where transport: T, addr: A, transport_stats: Arc>, + make_codec: CodecFactory, span: tracing::Span, ) -> Self { debug_assert!(transport.local_addr().is_some(), "Transport must be bound"); let conn = ServerConnection { options, addr, accept_task: None, conn: None }; - Self { state: conn, transport, transport_stats, span } + Self { state: conn, transport, transport_stats, make_codec, span } } /// Bind the socket to the given address. @@ -367,11 +376,11 @@ where } // Generic connection manager implementations. -impl ConnectionManager +impl ConnectionManager where T: Transport, A: Address, - C: ConnectionController, + Ctr: ConnectionController, { /// Reset the connection state to inactive, so that it will be re-tried. /// @@ -386,12 +395,18 @@ where pub(crate) fn poll( &mut self, cx: &mut Context<'_>, - ) -> Poll>> { - self.state.poll(&mut self.transport, &self.transport_stats, &self.span, cx) + ) -> Poll>> { + self.state.poll( + &mut self.transport, + &self.transport_stats, + &self.span, + &self.make_codec, + cx, + ) } /// Returns a mutable reference to the active connection, if it exists. - pub(crate) fn active_connection(&mut self) -> Option<&mut Conn> { + pub(crate) fn active_connection(&mut self) -> Option<&mut Conn> { self.state.active_connection() } } diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index 74dd4878..cac9f5d5 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -8,6 +8,12 @@ use std::{ use bytes::Bytes; use futures::{Future, SinkExt, StreamExt}; +use msg_common::span::{EnterSpan as _, SpanExt as _, WithSpan}; +use msg_transport::{Address, Transport}; +use msg_wire::{ + compression::{Compressor, try_decompress_payload}, + reqrep, +}; use rustc_hash::FxHashMap; use tokio::{ sync::{mpsc, oneshot}, @@ -23,13 +29,6 @@ use crate::{ }, }; -use msg_common::span::{EnterSpan as _, SpanExt as _, WithSpan}; -use msg_transport::{Address, Transport}; -use msg_wire::{ - compression::{Compressor, try_decompress_payload}, - reqrep, -}; - /// The request socket driver. Endless future that drives /// the socket forward. pub(crate) struct ReqDriver @@ -46,7 +45,7 @@ where /// Commands from the socket. pub(crate) from_socket: mpsc::Receiver, /// Connection manager that handles connection lifecycle. - pub(crate) conn_manager: ConnectionManager, + pub(crate) conn_manager: ConnectionManager, /// The timer for the write buffer linger. pub(crate) linger_timer: Option, /// The outgoing message queue. @@ -176,7 +175,7 @@ impl Future for ReqDriver where T: Transport, A: Address, - S: ConnectionController + Unpin, + S: ConnectionController + Unpin, { type Output = (); diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 8660d384..9375192a 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -7,6 +7,12 @@ use std::{ use arc_swap::Guard; use bytes::Bytes; +use msg_common::span::WithSpan; +use msg_transport::{Address, MeteredIo, Transport}; +use msg_wire::{ + compression::Compressor, + reqrep::{self, Codec}, +}; use rustc_hash::FxHashMap; use tokio::{ net::{ToSocketAddrs, lookup_host}, @@ -14,10 +20,6 @@ use tokio::{ }; use tokio_util::codec::Framed; -use msg_common::span::WithSpan; -use msg_transport::{Address, MeteredIo, Transport}; -use msg_wire::{compression::Compressor, reqrep}; - use super::{DEFAULT_BUFFER_SIZE, ReqError, ReqOptions}; use crate::{ ConnectionState, DRIVER_ID, ExponentialBackoff, ReqMessage, SendCommand, @@ -76,14 +78,16 @@ where }; // Initialize client-side connection manager - let conn_manager = ConnectionManager::>::new( - self.options.client.clone(), - transport, - addr, - conn_state, - Arc::clone(&self.state.transport_stats), - tracing::Span::none(), - ); + let conn_manager = + ConnectionManager::, Codec>::new( + self.options.client.clone(), + transport, + addr, + conn_state, + Arc::clone(&self.state.transport_stats), + Box::new(|| Codec::new()), + tracing::Span::none(), + ); self.spawn(addr, conn_manager) } @@ -94,12 +98,13 @@ where // Initialize server-side connection manager let mut conn_manager = - ConnectionManager::>::new( + ConnectionManager::, Codec>::new( // TODO: Server options from config ServerOptions {}, transport, addr, Arc::clone(&self.state.transport_stats), + Box::new(|| Codec::new()), tracing::Span::none(), ); @@ -198,12 +203,13 @@ where }; // Initialize client-side connection manager - let conn_manager = ConnectionManager::>::new( + let conn_manager = ConnectionManager::, Codec>::new( self.options.client.clone(), transport, endpoint.clone(), conn_state, Arc::clone(&self.state.transport_stats), + Box::new(|| Codec::new()), tracing::Span::none(), ); @@ -213,10 +219,10 @@ where Ok(()) } - fn spawn + Send + Unpin + 'static>( + fn spawn + Send + Unpin + 'static>( &mut self, addr: A, - mut conn_manager: ConnectionManager, + mut conn_manager: ConnectionManager, ) { // Initialize communication channels let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE); diff --git a/rustfmt.toml b/rustfmt.toml index 68c3c930..e767bbd1 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,5 +1,9 @@ +# Import ordering +group_imports = "StdExternalCrate" reorder_imports = true imports_granularity = "Crate" + +# Other use_small_heuristics = "Max" comment_width = 100 wrap_comments = true From 26e6dcd59b287fdf100d8d16d47b81bd6472fc1a Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Fri, 16 Jan 2026 16:24:55 +0100 Subject: [PATCH 8/9] refactor(socket): move connection manager to connection module --- .../{req/conn_manager.rs => connection/manager.rs} | 0 msg-socket/src/connection/mod.rs | 3 +++ msg-socket/src/req/driver.rs | 6 ++---- msg-socket/src/req/mod.rs | 8 +++----- msg-socket/src/req/socket.rs | 11 +++-------- 5 files changed, 11 insertions(+), 17 deletions(-) rename msg-socket/src/{req/conn_manager.rs => connection/manager.rs} (100%) diff --git a/msg-socket/src/req/conn_manager.rs b/msg-socket/src/connection/manager.rs similarity index 100% rename from msg-socket/src/req/conn_manager.rs rename to msg-socket/src/connection/manager.rs diff --git a/msg-socket/src/connection/mod.rs b/msg-socket/src/connection/mod.rs index e16d88bb..d8a8ba5b 100644 --- a/msg-socket/src/connection/mod.rs +++ b/msg-socket/src/connection/mod.rs @@ -3,3 +3,6 @@ pub use state::ConnectionState; pub mod backoff; pub use backoff::{Backoff, ExponentialBackoff}; + +mod manager; +pub use manager::*; diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index cac9f5d5..b3aac64e 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -23,10 +23,8 @@ use tokio::{ use super::{ReqError, ReqOptions}; use crate::{ SendCommand, - req::{ - SocketState, - conn_manager::{ConnectionController, ConnectionManager}, - }, + connection::{ConnectionController, ConnectionManager}, + req::SocketState, }; /// The request socket driver. Endless future that drives diff --git a/msg-socket/src/req/mod.rs b/msg-socket/src/req/mod.rs index d13b4c3b..53e35d45 100644 --- a/msg-socket/src/req/mod.rs +++ b/msg-socket/src/req/mod.rs @@ -5,23 +5,21 @@ use std::{ use arc_swap::ArcSwap; use bytes::Bytes; -use thiserror::Error; -use tokio::sync::oneshot; - use msg_common::{constants::KiB, span::WithSpan}; use msg_wire::{ compression::{CompressionType, Compressor}, reqrep, }; +use thiserror::Error; +use tokio::sync::oneshot; -mod conn_manager; mod driver; mod socket; mod stats; pub use socket::*; +use stats::ReqStats; use crate::{Profile, stats::SocketStats}; -use stats::ReqStats; /// The default buffer size for the socket. const DEFAULT_BUFFER_SIZE: usize = 1024; diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 9375192a..118b636a 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -23,15 +23,10 @@ use tokio_util::codec::Framed; use super::{DEFAULT_BUFFER_SIZE, ReqError, ReqOptions}; use crate::{ ConnectionState, DRIVER_ID, ExponentialBackoff, ReqMessage, SendCommand, - req::{ - SocketState, - conn_manager::{ - ClientConnection, ConnectionController, ConnectionManager, ServerConnection, - ServerOptions, - }, - driver::ReqDriver, - stats::ReqStats, + connection::{ + ClientConnection, ConnectionController, ConnectionManager, ServerConnection, ServerOptions, }, + req::{SocketState, driver::ReqDriver, stats::ReqStats}, stats::SocketStats, }; From 27e07a6ad817623957b0c02a2dc51ccc9decb1af Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Fri, 16 Jan 2026 16:58:03 +0100 Subject: [PATCH 9/9] feat(socket): inbound authentication on ConnectionManager --- msg-socket/src/connection/manager.rs | 86 +++++++++++++++++++++++++--- 1 file changed, 77 insertions(+), 9 deletions(-) diff --git a/msg-socket/src/connection/manager.rs b/msg-socket/src/connection/manager.rs index 6fcffbe1..6d079132 100644 --- a/msg-socket/src/connection/manager.rs +++ b/msg-socket/src/connection/manager.rs @@ -14,13 +14,14 @@ use msg_wire::auth; use tokio_util::codec::Framed; use tracing::Instrument; -use crate::{ClientOptions, ConnectionState, ExponentialBackoff}; +use crate::{Authenticator, ClientOptions, ConnectionState, ExponentialBackoff}; /// Type alias for a factory function that creates a codec. type CodecFactory = Box C + Send>; -/// A connection task that connects to a server and returns the underlying IO object. -type ConnTask = Pin> + Send>>; +/// A connection setup task that connects to a server or handles a connection from a client and +/// returns the underlying IO object. +type ConnSetup = Pin> + Send>>; /// A connection from the transport to a server. /// @@ -103,7 +104,7 @@ where /// The address of the remote. addr: A, /// The connection task which handles the connection to the server. - conn_task: Option>>, + conn_task: Option>>, /// The transport controller, wrapped in a [`ConnectionState`] for backoff. /// The [`Framed`] object can send and receive messages from the socket. conn_ctl: ConnCtl, @@ -226,7 +227,7 @@ where return Ok(io); }; - authentication_handshake::(io, token).await + outbound_handshake::(io, token).await } .in_current_span(); @@ -248,8 +249,10 @@ where options: ServerOptions, /// The local address. addr: A, + /// The optional authenticator. + authenticator: Option>, /// The accept task which handles accepting an incoming connection. - accept_task: Option>, + accept_task: Option>>, /// The inbound connection. conn: Option>, } @@ -303,7 +306,22 @@ where // 3. Create a new accept task if let Poll::Ready(accept_task) = transport.as_mut().poll_accept(cx) { - self.accept_task = Some(accept_task.with_current_span()); + // NOTE: Compiler needs some help here. + let task: ConnSetup = + // If we have an authenticator, create a task that performs the inbound handshake. + if let Some(ref authenticator) = self.authenticator { + let authenticator = authenticator.clone(); + Box::pin(async move { + let io = accept_task.await?; + + inbound_handshake::(io, &authenticator).await + }) + } else { + // Otherwise just accept the connection as-is. + Box::pin(async move { accept_task.await }) + }; + + self.accept_task = Some(task.with_current_span()); // Continue to poll the accept task continue; @@ -359,12 +377,13 @@ where options: ServerOptions, transport: T, addr: A, + authenticator: Option>, transport_stats: Arc>, make_codec: CodecFactory, span: tracing::Span, ) -> Self { debug_assert!(transport.local_addr().is_some(), "Transport must be bound"); - let conn = ServerConnection { options, addr, accept_task: None, conn: None }; + let conn = ServerConnection { options, addr, authenticator, accept_task: None, conn: None }; Self { state: conn, transport, transport_stats, make_codec, span } } @@ -413,7 +432,7 @@ where /// Perform the authentication handshake with the server. #[tracing::instrument(skip_all, "auth", fields(token = ?token))] -async fn authentication_handshake(mut io: T::Io, token: Bytes) -> Result +async fn outbound_handshake(mut io: T::Io, token: Bytes) -> Result where T: Transport, A: Address, @@ -440,3 +459,52 @@ where Err(e) => Err(io::Error::new(io::ErrorKind::PermissionDenied, e).into()), } } + +/// Perform the authentication handshake with the client +#[tracing::instrument(skip_all, "auth")] +async fn inbound_handshake( + mut io: T::Io, + authenticator: &Arc, +) -> Result +where + T: Transport, + A: Address, +{ + let mut conn = Framed::new(&mut io, auth::Codec::new_server()); + + // Wait for the response + let Some(res) = conn.next().await else { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "connection closed").into()); + }; + match res { + Ok(auth::Message::Auth(token)) => { + tracing::debug!(?token, "auth received"); + // If authentication fails, send a reject message and close the connection + if !authenticator.authenticate(&token) { + conn.send(auth::Message::Reject).await?; + conn.close().await?; + return Err(abort().into()) + } + + // Send ack + conn.send(auth::Message::Ack).await?; + + return Ok(io); + } + Ok(msg) => { + tracing::debug!(?msg, "unexpected message during authentication"); + conn.send(auth::Message::Reject).await?; + conn.close().await?; + Err(abort().into()) + } + Err(e) => { + tracing::error!(?e, "error during authentication"); + Err(abort().into()) + } + } +} + +// Helper function for an abort error +fn abort() -> io::Error { + io::Error::new(io::ErrorKind::Other, "authentication failed, connection aborted") +}