diff --git a/Cargo.toml b/Cargo.toml index 9284321e..93d807d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -173,6 +173,7 @@ unused_rounding = "warn" use_self = "warn" useless_let_if_seq = "warn" zero_sized_map_values = "warn" +default_trait_access = "warn" # These are nursery lints which have findings. Allow them for now. Some are not # quite mature enough for use in our codebase and some we don't really want. diff --git a/msg-socket/src/connection/backoff.rs b/msg-socket/src/connection/backoff.rs index fbca13e3..002624dd 100644 --- a/msg-socket/src/connection/backoff.rs +++ b/msg-socket/src/connection/backoff.rs @@ -6,7 +6,7 @@ use std::{ }; use tokio::time::sleep; -use crate::ConnOptions; +use crate::ClientOptions; /// Helper trait alias for backoff streams. /// We define any stream that yields `Duration`s as a backoff @@ -41,8 +41,8 @@ impl ExponentialBackoff { } } -impl From<&ConnOptions> for ExponentialBackoff { - fn from(options: &ConnOptions) -> Self { +impl From<&ClientOptions> for ExponentialBackoff { + fn from(options: &ClientOptions) -> Self { Self::new(options.backoff_duration, options.retry_attempts) } } diff --git a/msg-socket/src/connection/manager.rs b/msg-socket/src/connection/manager.rs new file mode 100644 index 00000000..6d079132 --- /dev/null +++ b/msg-socket/src/connection/manager.rs @@ -0,0 +1,510 @@ +use std::{ + io, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use arc_swap::ArcSwap; +use bytes::Bytes; +use futures::{Future, FutureExt, SinkExt, StreamExt}; +use msg_common::span::{EnterSpan as _, SpanExt, WithSpan}; +use msg_transport::{Address, MeteredIo, PeerAddress as _, Transport}; +use msg_wire::auth; +use tokio_util::codec::Framed; +use tracing::Instrument; + +use crate::{Authenticator, ClientOptions, ConnectionState, ExponentialBackoff}; + +/// Type alias for a factory function that creates a codec. +type CodecFactory = Box C + Send>; + +/// A connection setup task that connects to a server or handles a connection from a client and +/// returns the underlying IO object. +type ConnSetup = Pin> + Send>>; + +/// A connection from the transport to a server. +/// +/// # Usage of Framed +/// [`Framed`] is used for encoding and decoding messages ("frames"). +/// Usually, [`Framed`] has its own internal buffering mechanism, that's respected +/// when calling `poll_ready` and configured by [`Framed::set_backpressure_boundary`]. +/// +/// However, we don't use `poll_ready` here, and instead we flush every time we write a message to +/// the framed buffer. +pub(crate) type Conn = Framed, C>; + +/// A connection controller that manages the connection to a server with an exponential backoff. +pub(crate) type ConnCtl = ConnectionState, ExponentialBackoff, A>; + +/// Trait for interacting with the connection, regardless of its "side" (client or server). +pub(crate) trait ConnectionController +where + T: Transport, + A: Address, +{ + /// Polls the connection logic, and returns a mutable reference to the connection if it's ready. + #[allow(clippy::type_complexity)] + fn poll( + &mut self, + transport: &mut T, + stats: &Arc>, + span: &tracing::Span, + make_codec: &impl Fn() -> C, + cx: &mut Context<'_>, + ) -> Poll>>; + + /// Resets the connection controller. Will close any active connections. + fn reset(&mut self); + + /// Returns a mutable reference to the active connection, if it exists. + fn active_connection(&mut self) -> Option<&mut Conn>; +} + +/// A connection manager for managing client OR server connections. +/// The type parameter `S` contains the connection state, including its "side" (client / server). +pub(crate) struct ConnectionManager +where + T: Transport, + A: Address, +{ + /// The connection state, including its "side" (client / server). + state: S, + /// The transport used for the connection. + transport: T, + /// Transport stats for metering IO. + transport_stats: Arc>, + /// Factory function for creating a codec. + make_codec: CodecFactory, + + /// Connection manager tracing span. + span: tracing::Span, +} + +impl ConnectionManager +where + T: Transport, + A: Address, +{ + /// Set the connection manager tracing span. + pub(crate) fn with_span(mut self, span: tracing::Span) -> Self { + self.span = span; + self + } +} + +/// A client connection to a remote server. Generic over transport, address type and codec. +pub(crate) struct ClientConnection +where + T: Transport, + A: Address, +{ + /// Options for the connection manager. + options: ClientOptions, + /// The address of the remote. + addr: A, + /// The connection task which handles the connection to the server. + conn_task: Option>>, + /// The transport controller, wrapped in a [`ConnectionState`] for backoff. + /// The [`Framed`] object can send and receive messages from the socket. + conn_ctl: ConnCtl, +} + +impl ConnectionController for ClientConnection +where + T: Transport, + A: Address, +{ + /// Poll connection management logic: connection task, backoff, and retry logic. + /// Loops until the connection is active, then returns a mutable reference to the channel. + /// + /// Note: this is not a `Future` impl because we want to return a reference; doing it in + /// a `Future` would require lifetime headaches or unsafe code. + /// + /// Returns: + /// * `Poll::Ready(Some(&mut channel))` if the connection is active + /// * `Poll::Ready(None)` if we should terminate (max retries exceeded) + /// * `Poll::Pending` if we need to wait for backoff + fn poll( + &mut self, + transport: &mut T, + stats: &Arc>, + span: &tracing::Span, + make_codec: &impl Fn() -> C, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + // Poll the active connection task, if any + if let Some(ref mut conn_task) = self.conn_task { + if let Poll::Ready(result) = conn_task.poll_unpin(cx).enter() { + // As soon as the connection task finishes, set it to `None`. + // - If it was successful, set the connection to active + // - If it failed, it will be re-tried until the backoff limit is reached. + self.conn_task = None; + + match result.inner { + Ok(io) => { + tracing::info!("connected"); + + let metered = MeteredIo::new(io, stats.clone()); + let framed = Framed::new(metered, make_codec()); + self.conn_ctl = ConnectionState::Active { channel: framed }; + } + Err(e) => { + tracing::error!(?e, "failed to connect"); + } + } + } + } + + // If the connection is inactive, try to connect to the server or poll the backoff + // timer if we're already trying to connect. + if let ConnectionState::Inactive { backoff, .. } = &mut self.conn_ctl { + let Poll::Ready(item) = backoff.poll_next_unpin(cx) else { + return Poll::Pending; + }; + + let _span = tracing::info_span!(parent: span, "connect").entered(); + + if let Some(duration) = item { + if self.conn_task.is_none() { + tracing::debug!(backoff = ?duration, "trying connection"); + self.try_connect(transport); + } else { + tracing::debug!( + backoff = ?duration, + "not retrying as there is already a connection task" + ); + } + } else { + tracing::error!("exceeded maximum number of retries, terminating connection"); + return Poll::Ready(None); + } + } + + if let ConnectionState::Active { ref mut channel } = self.conn_ctl { + return Poll::Ready(Some(channel)); + } + } + } + + /// Reset the connection state to inactive, so that it will be re-tried. + /// + /// This is done when the connection is closed or an error occurs. + #[inline] + fn reset(&mut self) { + self.conn_ctl = ConnectionState::Inactive { + addr: self.addr.clone(), + backoff: ExponentialBackoff::from(&self.options), + }; + } + + /// Returns a mutable reference to the active connection, if it exists. + fn active_connection(&mut self) -> Option<&mut Conn> { + if let ConnectionState::Active { ref mut channel } = self.conn_ctl { + Some(channel) + } else { + None + } + } +} + +impl ClientConnection +where + T: Transport, + A: Address, +{ + /// Start the connection task to the server, handling authentication if necessary. + /// The result will be polled by the driver and re-tried according to the backoff policy. + fn try_connect(&mut self, transport: &mut T) { + let connect = transport.connect(self.addr.clone()); + let token = self.options.auth_token.clone(); + + let task = async move { + let io = connect.await?; + + let Some(token) = token else { + return Ok(io); + }; + + outbound_handshake::(io, token).await + } + .in_current_span(); + + // FIX: coercion to BoxFuture for [`SpanExt::with_current_span`] + self.conn_task = Some(WithSpan::current(Box::pin(task))); + } +} + +/// A local server connection. Manages the connection lifecycle: +/// - Accepting incoming connections. +/// - Handling established connections. +pub(crate) struct ServerConnection +where + T: Transport, + A: Address, +{ + /// The server options. + #[allow(unused)] + options: ServerOptions, + /// The local address. + addr: A, + /// The optional authenticator. + authenticator: Option>, + /// The accept task which handles accepting an incoming connection. + accept_task: Option>>, + /// The inbound connection. + conn: Option>, +} + +impl ConnectionController for ServerConnection +where + T: Transport, + A: Address, +{ + /// Poll the server-side connection controller. This will return: + /// - Poll::Ready(Some(conn)) if a connection is active. + /// - Poll::Pending if no connection is active and the accept task is pending. + /// - Poll::Ready(None) if no connection is active and the accept task is not pending. + fn poll( + &mut self, + transport: &mut T, + stats: &Arc>, + span: &tracing::Span, + make_codec: &impl Fn() -> C, + cx: &mut Context<'_>, + ) -> Poll>> { + let mut transport = Pin::new(transport); + loop { + // 1. If connection is active, return it + if self.conn.is_some() { + return Poll::Ready(self.conn.as_mut()); + } + + let _span = + tracing::info_span!(parent: span, "accept", local_addr = ?self.addr).entered(); + + // 2. If connection is not active, but we have an accept task, poll it. + if let Some(ref mut accept) = self.accept_task { + if let Poll::Ready(result) = accept.poll_unpin(cx).enter() { + match result.inner { + Ok(io) => { + tracing::debug!(peer_addr = ?io.peer_addr(), "Accepted connection"); + let metered = MeteredIo::new(io, stats.clone()); + let framed = Framed::new(metered, make_codec()); + + self.conn = Some(framed); + return Poll::Ready(self.conn.as_mut()); + } + Err(err) => { + tracing::error!("Accept error: {err:?}"); + self.accept_task = None; + } + } + } + } + + // 3. Create a new accept task + if let Poll::Ready(accept_task) = transport.as_mut().poll_accept(cx) { + // NOTE: Compiler needs some help here. + let task: ConnSetup = + // If we have an authenticator, create a task that performs the inbound handshake. + if let Some(ref authenticator) = self.authenticator { + let authenticator = authenticator.clone(); + Box::pin(async move { + let io = accept_task.await?; + + inbound_handshake::(io, &authenticator).await + }) + } else { + // Otherwise just accept the connection as-is. + Box::pin(async move { accept_task.await }) + }; + + self.accept_task = Some(task.with_current_span()); + + // Continue to poll the accept task + continue; + } + + return Poll::Pending; + } + } + + fn reset(&mut self) { + if let Some(ref mut _conn) = self.conn.take() { + // FIXME: This doesn't actually close the underlying connection, it just drops it. + // To actually close it, we'd need to poll the close future. + // let _ = _conn.close().await; + } + } + + fn active_connection(&mut self) -> Option<&mut Conn> { + self.conn.as_mut() + } +} + +// Client-side connection manager implementations. +impl ConnectionManager, C> +where + T: Transport, + A: Address, +{ + pub(crate) fn new( + options: ClientOptions, + transport: T, + addr: A, + conn_ctl: ConnCtl, + transport_stats: Arc>, + make_codec: CodecFactory, + span: tracing::Span, + ) -> Self { + let conn = ClientConnection { options, addr, conn_task: None, conn_ctl }; + + Self { state: conn, transport, transport_stats, make_codec, span } + } +} + +pub struct ServerOptions {} + +impl ConnectionManager, C> +where + T: Transport, + A: Address, +{ + /// Create a new server-side connection manager. + pub(crate) fn new( + options: ServerOptions, + transport: T, + addr: A, + authenticator: Option>, + transport_stats: Arc>, + make_codec: CodecFactory, + span: tracing::Span, + ) -> Self { + debug_assert!(transport.local_addr().is_some(), "Transport must be bound"); + let conn = ServerConnection { options, addr, authenticator, accept_task: None, conn: None }; + + Self { state: conn, transport, transport_stats, make_codec, span } + } + + /// Bind the socket to the given address. + pub(crate) async fn bind(&mut self, addr: A) -> Result<(), T::Error> { + self.transport.bind(addr).await + } +} + +// Generic connection manager implementations. +impl ConnectionManager +where + T: Transport, + A: Address, + Ctr: ConnectionController, +{ + /// Reset the connection state to inactive, so that it will be re-tried. + /// + /// This is done when the connection is closed or an error occurs. + #[inline] + pub(crate) fn reset_connection(&mut self) { + self.state.reset(); + } + + /// Poll the connection controller. + #[allow(clippy::type_complexity)] + pub(crate) fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + self.state.poll( + &mut self.transport, + &self.transport_stats, + &self.span, + &self.make_codec, + cx, + ) + } + + /// Returns a mutable reference to the active connection, if it exists. + pub(crate) fn active_connection(&mut self) -> Option<&mut Conn> { + self.state.active_connection() + } +} + +/// Perform the authentication handshake with the server. +#[tracing::instrument(skip_all, "auth", fields(token = ?token))] +async fn outbound_handshake(mut io: T::Io, token: Bytes) -> Result +where + T: Transport, + A: Address, +{ + let mut conn = Framed::new(&mut io, auth::Codec::new_client()); + + conn.send(auth::Message::Auth(token)).await?; + tracing::debug!("sent auth, waiting ack from server"); + + // Wait for the response + let Some(res) = conn.next().await else { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "connection closed").into()); + }; + + match res { + Ok(auth::Message::Ack) => { + tracing::debug!("received ack"); + Ok(io) + } + Ok(msg) => { + tracing::error!(?msg, "unexpected ack result"); + Err(io::Error::new(io::ErrorKind::PermissionDenied, "rejected").into()) + } + Err(e) => Err(io::Error::new(io::ErrorKind::PermissionDenied, e).into()), + } +} + +/// Perform the authentication handshake with the client +#[tracing::instrument(skip_all, "auth")] +async fn inbound_handshake( + mut io: T::Io, + authenticator: &Arc, +) -> Result +where + T: Transport, + A: Address, +{ + let mut conn = Framed::new(&mut io, auth::Codec::new_server()); + + // Wait for the response + let Some(res) = conn.next().await else { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "connection closed").into()); + }; + match res { + Ok(auth::Message::Auth(token)) => { + tracing::debug!(?token, "auth received"); + // If authentication fails, send a reject message and close the connection + if !authenticator.authenticate(&token) { + conn.send(auth::Message::Reject).await?; + conn.close().await?; + return Err(abort().into()) + } + + // Send ack + conn.send(auth::Message::Ack).await?; + + return Ok(io); + } + Ok(msg) => { + tracing::debug!(?msg, "unexpected message during authentication"); + conn.send(auth::Message::Reject).await?; + conn.close().await?; + Err(abort().into()) + } + Err(e) => { + tracing::error!(?e, "error during authentication"); + Err(abort().into()) + } + } +} + +// Helper function for an abort error +fn abort() -> io::Error { + io::Error::new(io::ErrorKind::Other, "authentication failed, connection aborted") +} diff --git a/msg-socket/src/connection/mod.rs b/msg-socket/src/connection/mod.rs index e16d88bb..d8a8ba5b 100644 --- a/msg-socket/src/connection/mod.rs +++ b/msg-socket/src/connection/mod.rs @@ -3,3 +3,6 @@ pub use state::ConnectionState; pub mod backoff; pub use backoff::{Backoff, ExponentialBackoff}; + +mod manager; +pub use manager::*; diff --git a/msg-socket/src/req/conn_manager.rs b/msg-socket/src/req/conn_manager.rs deleted file mode 100644 index 1a11197b..00000000 --- a/msg-socket/src/req/conn_manager.rs +++ /dev/null @@ -1,213 +0,0 @@ -use std::{ - io, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use futures::{Future, FutureExt, SinkExt, StreamExt}; -use msg_common::span::{EnterSpan as _, WithSpan}; -use tokio_util::codec::Framed; -use tracing::Instrument; - -use crate::{ConnOptions, ConnectionState, ExponentialBackoff}; - -use msg_transport::{Address, MeteredIo, Transport}; -use msg_wire::{auth, reqrep}; - -/// A connection task that connects to a server and returns the underlying IO object. -type ConnTask = Pin> + Send>>; - -/// A connection from the transport to a server. -/// -/// # Usage of Framed -/// [`Framed`] is used for encoding and decoding messages ("frames"). -/// Usually, [`Framed`] has its own internal buffering mechanism, that's respected -/// when calling `poll_ready` and configured by [`Framed::set_backpressure_boundary`]. -/// -/// However, we don't use `poll_ready` here, and instead we flush every time we write a message to -/// the framed buffer. -pub(crate) type Conn = Framed, reqrep::Codec>; - -/// A connection controller that manages the connection to a server with an exponential backoff. -pub(crate) type ConnCtl = ConnectionState, ExponentialBackoff, A>; - -/// Manages the connection lifecycle: connecting, reconnecting, and maintaining the connection. -pub(crate) struct ConnManager, A: Address> { - /// Options for the connection manager. - options: ConnOptions, - /// The connection task which handles the connection to the server. - conn_task: Option>>, - /// The transport controller, wrapped in a [`ConnectionState`] for backoff. - /// The [`Framed`] object can send and receive messages from the socket. - conn_ctl: ConnCtl, - /// The transport for this socket. - transport: T, - /// The address of the server. - addr: A, - /// Transport stats for metering IO. - transport_stats: Arc>, - - /// A span to use for connection-related logging. - span: tracing::Span, -} - -/// Perform the authentication handshake with the server. -#[tracing::instrument(skip_all, "auth", fields(token = ?token))] -async fn authentication_handshake(mut io: T::Io, token: Bytes) -> Result -where - T: Transport, - A: Address, -{ - let mut conn = Framed::new(&mut io, auth::Codec::new_client()); - - conn.send(auth::Message::Auth(token)).await?; - tracing::debug!("sent auth, waiting ack from server"); - - // Wait for the response - let Some(res) = conn.next().await else { - return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "connection closed").into()); - }; - - match res { - Ok(auth::Message::Ack) => { - tracing::debug!("received ack"); - Ok(io) - } - Ok(msg) => { - tracing::error!(?msg, "unexpected ack result"); - Err(io::Error::new(io::ErrorKind::PermissionDenied, "rejected").into()) - } - Err(e) => Err(io::Error::new(io::ErrorKind::PermissionDenied, e).into()), - } -} - -impl ConnManager -where - T: Transport, - A: Address, -{ - pub(crate) fn new( - options: ConnOptions, - transport: T, - addr: A, - conn_ctl: ConnCtl, - transport_stats: Arc>, - span: tracing::Span, - ) -> Self { - Self { options, conn_task: None, conn_ctl, transport, addr, transport_stats, span } - } - - /// Start the connection task to the server, handling authentication if necessary. - /// The result will be polled by the driver and re-tried according to the backoff policy. - fn try_connect(&mut self) { - let connect = self.transport.connect(self.addr.clone()); - let token = self.options.auth_token.clone(); - - let task = async move { - let io = connect.await?; - - let Some(token) = token else { - return Ok(io); - }; - - authentication_handshake::(io, token).await - } - .in_current_span(); - - // FIX: coercion to BoxFuture for [`SpanExt::with_current_span`] - self.conn_task = Some(WithSpan::current(Box::pin(task))); - } - - /// Reset the connection state to inactive, so that it will be re-tried. - /// - /// This is done when the connection is closed or an error occurs. - #[inline] - pub(crate) fn reset_connection(&mut self) { - self.conn_ctl = ConnectionState::Inactive { - addr: self.addr.clone(), - backoff: ExponentialBackoff::from(&self.options), - }; - } - - /// Returns a mutable reference to the connection channel if it is active. - #[inline] - pub(crate) fn active_connection(&mut self) -> Option<&mut Conn> { - if let ConnectionState::Active { ref mut channel } = self.conn_ctl { - Some(channel) - } else { - None - } - } - - /// Poll connection management logic: connection task, backoff, and retry logic. - /// Loops until the connection is active, then returns a mutable reference to the channel. - /// - /// Note: this is not a `Future` impl because we want to return a reference; doing it in - /// a `Future` would require lifetime headaches or unsafe code. - /// - /// Returns: - /// * `Poll::Ready(Some(&mut channel))` if the connection is active - /// * `Poll::Ready(None)` if we should terminate (max retries exceeded) - /// * `Poll::Pending` if we need to wait for backoff - #[allow(clippy::type_complexity)] - pub(crate) fn poll( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>> { - loop { - // Poll the active connection task, if any - if let Some(ref mut conn_task) = self.conn_task { - if let Poll::Ready(result) = conn_task.poll_unpin(cx).enter() { - // As soon as the connection task finishes, set it to `None`. - // - If it was successful, set the connection to active - // - If it failed, it will be re-tried until the backoff limit is reached. - self.conn_task = None; - - match result.inner { - Ok(io) => { - tracing::info!("connected"); - - let metered = MeteredIo::new(io, self.transport_stats.clone()); - let framed = Framed::new(metered, reqrep::Codec::new()); - self.conn_ctl = ConnectionState::Active { channel: framed }; - } - Err(e) => { - tracing::error!(?e, "failed to connect"); - } - } - } - } - - // If the connection is inactive, try to connect to the server or poll the backoff - // timer if we're already trying to connect. - if let ConnectionState::Inactive { backoff, .. } = &mut self.conn_ctl { - let Poll::Ready(item) = backoff.poll_next_unpin(cx) else { - return Poll::Pending; - }; - - let _span = tracing::info_span!(parent: &self.span, "connect").entered(); - - if let Some(duration) = item { - if self.conn_task.is_none() { - tracing::debug!(backoff = ?duration, "trying connection"); - self.try_connect(); - } else { - tracing::debug!( - backoff = ?duration, - "not retrying as there is already a connection task" - ); - } - } else { - tracing::error!("exceeded maximum number of retries, terminating connection"); - return Poll::Ready(None); - } - } - - if let ConnectionState::Active { ref mut channel } = self.conn_ctl { - return Poll::Ready(Some(channel)); - } - } - } -} diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index a87bd7a6..b3aac64e 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -8,6 +8,12 @@ use std::{ use bytes::Bytes; use futures::{Future, SinkExt, StreamExt}; +use msg_common::span::{EnterSpan as _, SpanExt as _, WithSpan}; +use msg_transport::{Address, Transport}; +use msg_wire::{ + compression::{Compressor, try_decompress_payload}, + reqrep, +}; use rustc_hash::FxHashMap; use tokio::{ sync::{mpsc, oneshot}, @@ -17,19 +23,17 @@ use tokio::{ use super::{ReqError, ReqOptions}; use crate::{ SendCommand, - req::{SocketState, conn_manager::ConnManager}, -}; - -use msg_common::span::{EnterSpan as _, SpanExt as _, WithSpan}; -use msg_transport::{Address, Transport}; -use msg_wire::{ - compression::{Compressor, try_decompress_payload}, - reqrep, + connection::{ConnectionController, ConnectionManager}, + req::SocketState, }; /// The request socket driver. Endless future that drives /// the socket forward. -pub(crate) struct ReqDriver, A: Address> { +pub(crate) struct ReqDriver +where + T: Transport, + A: Address, +{ /// Options shared with the socket. pub(crate) options: Arc, /// State shared with the socket. @@ -39,7 +43,7 @@ pub(crate) struct ReqDriver, A: Address> { /// Commands from the socket. pub(crate) from_socket: mpsc::Receiver, /// Connection manager that handles connection lifecycle. - pub(crate) conn_manager: ConnManager, + pub(crate) conn_manager: ConnectionManager, /// The timer for the write buffer linger. pub(crate) linger_timer: Option, /// The outgoing message queue. @@ -66,7 +70,7 @@ pub(crate) struct PendingRequest { sender: oneshot::Sender>, } -impl ReqDriver +impl ReqDriver where T: Transport, A: Address, @@ -165,10 +169,11 @@ where } } -impl Future for ReqDriver +impl Future for ReqDriver where T: Transport, A: Address, + S: ConnectionController + Unpin, { type Output = (); diff --git a/msg-socket/src/req/mod.rs b/msg-socket/src/req/mod.rs index c70b9eaa..53e35d45 100644 --- a/msg-socket/src/req/mod.rs +++ b/msg-socket/src/req/mod.rs @@ -5,23 +5,21 @@ use std::{ use arc_swap::ArcSwap; use bytes::Bytes; -use thiserror::Error; -use tokio::sync::oneshot; - use msg_common::{constants::KiB, span::WithSpan}; use msg_wire::{ compression::{CompressionType, Compressor}, reqrep, }; +use thiserror::Error; +use tokio::sync::oneshot; -mod conn_manager; mod driver; mod socket; mod stats; pub use socket::*; +use stats::ReqStats; use crate::{Profile, stats::SocketStats}; -use stats::ReqStats; /// The default buffer size for the socket. const DEFAULT_BUFFER_SIZE: usize = 1024; @@ -45,6 +43,8 @@ pub enum ReqError { NoValidEndpoints, #[error("Failed to connect to the target endpoint: {0:?}")] Connect(Box), + #[error("Failed to bind to the socket address")] + Bind(Box), } /// A command to send a request message and wait for a response. @@ -68,7 +68,7 @@ impl SendCommand { /// Options for the connection manager. #[derive(Debug, Clone)] -pub struct ConnOptions { +pub struct ClientOptions { /// Optional authentication token. pub auth_token: Option, /// The backoff duration for the underlying transport on reconnections. @@ -77,7 +77,7 @@ pub struct ConnOptions { pub retry_attempts: Option, } -impl Default for ConnOptions { +impl Default for ClientOptions { fn default() -> Self { Self { auth_token: None, @@ -95,8 +95,8 @@ impl Default for ConnOptions { /// The request socket options. #[derive(Debug, Clone)] pub struct ReqOptions { - /// Options for the connection manager. - pub conn: ConnOptions, + /// Client options for the connection manager. + pub client: ClientOptions, /// Timeout duration for requests. pub timeout: Duration, /// Wether to block on initial connection to the target. @@ -151,7 +151,7 @@ impl ReqOptions { impl ReqOptions { /// Sets the authentication token for the socket. pub fn with_auth_token(mut self, auth_token: Bytes) -> Self { - self.conn.auth_token = Some(auth_token); + self.client.auth_token = Some(auth_token); self } @@ -169,7 +169,7 @@ impl ReqOptions { /// Sets the backoff duration for the socket. pub fn with_backoff_duration(mut self, backoff_duration: Duration) -> Self { - self.conn.backoff_duration = backoff_duration; + self.client.backoff_duration = backoff_duration; self } @@ -177,7 +177,7 @@ impl ReqOptions { /// /// If `None`, all connections will be retried indefinitely. pub fn with_retry_attempts(mut self, retry_attempts: usize) -> Self { - self.conn.retry_attempts = Some(retry_attempts); + self.client.retry_attempts = Some(retry_attempts); self } @@ -211,7 +211,7 @@ impl ReqOptions { impl Default for ReqOptions { fn default() -> Self { Self { - conn: ConnOptions::default(), + client: ClientOptions::default(), timeout: Duration::from_secs(5), blocking_connect: false, min_compress_size: 8192, diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 98f54462..118b636a 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -7,6 +7,12 @@ use std::{ use arc_swap::Guard; use bytes::Bytes; +use msg_common::span::WithSpan; +use msg_transport::{Address, MeteredIo, Transport}; +use msg_wire::{ + compression::Compressor, + reqrep::{self, Codec}, +}; use rustc_hash::FxHashMap; use tokio::{ net::{ToSocketAddrs, lookup_host}, @@ -14,19 +20,13 @@ use tokio::{ }; use tokio_util::codec::Framed; -use msg_common::span::WithSpan; -use msg_transport::{Address, MeteredIo, Transport}; -use msg_wire::{compression::Compressor, reqrep}; - use super::{DEFAULT_BUFFER_SIZE, ReqError, ReqOptions}; use crate::{ ConnectionState, DRIVER_ID, ExponentialBackoff, ReqMessage, SendCommand, - req::{ - SocketState, - conn_manager::{ConnCtl, ConnManager}, - driver::ReqDriver, - stats::ReqStats, + connection::{ + ClientConnection, ConnectionController, ConnectionManager, ServerConnection, ServerOptions, }, + req::{SocketState, driver::ReqDriver, stats::ReqStats}, stats::SocketStats, }; @@ -69,10 +69,45 @@ where // by the backend task as soon as the driver is spawned. let conn_state = ConnectionState::Inactive { addr, - backoff: ExponentialBackoff::from(&self.options.conn), + backoff: ExponentialBackoff::from(&self.options.client), }; - self.spawn_driver(addr, transport, conn_state) + // Initialize client-side connection manager + let conn_manager = + ConnectionManager::, Codec>::new( + self.options.client.clone(), + transport, + addr, + conn_state, + Arc::clone(&self.state.transport_stats), + Box::new(|| Codec::new()), + tracing::Span::none(), + ); + + self.spawn(addr, conn_manager) + } + + /// Bind the socket to the given address. + pub async fn bind(&mut self, addr: SocketAddr) -> Result<(), ReqError> { + let transport = self.transport.take().expect("Transport has been moved"); + + // Initialize server-side connection manager + let mut conn_manager = + ConnectionManager::, Codec>::new( + // TODO: Server options from config + ServerOptions {}, + transport, + addr, + Arc::clone(&self.state.transport_stats), + Box::new(|| Codec::new()), + tracing::Span::none(), + ); + + // Bind the connection manager + conn_manager.bind(addr).await.map_err(|e| ReqError::Bind(e.into()))?; + + self.spawn(addr, conn_manager); + Ok(()) } } @@ -158,17 +193,32 @@ where // by the backend task as soon as the driver is spawned. ConnectionState::Inactive { addr: endpoint.clone(), - backoff: ExponentialBackoff::from(&self.options.conn), + backoff: ExponentialBackoff::from(&self.options.client), } }; - self.spawn_driver(endpoint, transport, conn_state); + // Initialize client-side connection manager + let conn_manager = ConnectionManager::, Codec>::new( + self.options.client.clone(), + transport, + endpoint.clone(), + conn_state, + Arc::clone(&self.state.transport_stats), + Box::new(|| Codec::new()), + tracing::Span::none(), + ); + + // Spawn the driver task + self.spawn(endpoint, conn_manager); Ok(()) } - /// Internal method to initialize and spawn the driver. - fn spawn_driver(&mut self, endpoint: A, transport: T, conn_ctl: ConnCtl) { + fn spawn + Send + Unpin + 'static>( + &mut self, + addr: A, + mut conn_manager: ConnectionManager, + ) { // Initialize communication channels let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE); @@ -179,7 +229,11 @@ where let pending_requests = FxHashMap::default(); let id = DRIVER_ID.fetch_add(1, Ordering::Relaxed); - let span = tracing::info_span!(parent: None, "req_driver", id = format!("req-{}", id), addr = ?endpoint); + let span = + tracing::info_span!(parent: None, "req_driver", id = format!("req-{}", id), ?addr); + + // Set driver span + conn_manager = conn_manager.with_span(span.clone()); let linger_timer = self.options.write_buffer_linger.map(|duration| { let mut timer = tokio::time::interval(duration); @@ -187,18 +241,8 @@ where timer }); - // Create connection manager - let conn_manager = ConnManager::new( - self.options.conn.clone(), - transport, - endpoint, - conn_ctl, - Arc::clone(&self.state.transport_stats), - span.clone(), - ); - // Create the socket backend - let driver: ReqDriver = ReqDriver { + let driver: ReqDriver = ReqDriver { options: Arc::clone(&self.options), socket_state: self.state.clone(), id_counter: 0, diff --git a/rustfmt.toml b/rustfmt.toml index 68c3c930..e767bbd1 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,5 +1,9 @@ +# Import ordering +group_imports = "StdExternalCrate" reorder_imports = true imports_granularity = "Crate" + +# Other use_small_heuristics = "Max" comment_width = 100 wrap_comments = true