From c1c57affb7e6a5878f702b889fc417997bbb3e33 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 29 Jun 2026 10:41:08 -0400 Subject: [PATCH] fix(pool): Singleton shared calls pass baton if one cancels --- src/client/pool/singleton.rs | 447 ++++++++++++++++++++++++++--------- 1 file changed, 335 insertions(+), 112 deletions(-) diff --git a/src/client/pool/singleton.rs b/src/client/pool/singleton.rs index fc6d172b..cc288579 100644 --- a/src/client/pool/singleton.rs +++ b/src/client/pool/singleton.rs @@ -23,10 +23,9 @@ use std::sync::{Arc, Mutex}; use std::task::{self, Poll}; -use tokio::sync::oneshot; use tower_service::Service; -use self::internal::{DitchGuard, SingletonError, SingletonFuture, State}; +use self::internal::{SingletonError, SingletonFuture, State}; type BoxError = Box; @@ -44,7 +43,7 @@ where M: Service, { mk_svc: M, - state: Arc>>, + state: Arc>>, } impl Singleton @@ -94,7 +93,7 @@ where M::Response: Clone, M::Error: Into, { - type Response = internal::Singled; + type Response = internal::Singled; type Error = SingletonError; type Future = SingletonFuture; @@ -113,18 +112,21 @@ where match *locked { State::Empty => { let fut = self.mk_svc.call(dst); - *locked = State::Making(Vec::new()); - SingletonFuture::Driving { - future: fut, - singleton: DitchGuard(Arc::downgrade(&self.state)), + let mut batch = internal::Batch::new(fut); + let id = batch.register_driver(); + *locked = State::Making(batch); + SingletonFuture::Participating { + id, + state: self.state.clone(), + rx: None, } } - State::Making(ref mut waiters) => { - let (tx, rx) = oneshot::channel(); - waiters.push(tx); - SingletonFuture::Waiting { - rx, - state: Arc::downgrade(&self.state), + State::Making(ref mut batch) => { + let (id, rx) = batch.register_waiter(); + SingletonFuture::Participating { + id, + state: self.state.clone(), + rx: Some(rx), } } State::Made(ref svc) => SingletonFuture::Made { @@ -148,47 +150,77 @@ where } // Holds some "pub" items that otherwise shouldn't be public. +/// Baton-passing implementation. +/// +/// While a singleton service is being made, one participating future is +/// responsible for driving that work. If that future is canceled, the work +/// should not be canceled for every other caller waiting on the same service. +/// Baton-passing lets another participant take over, so cancellation remains +/// local to the future that was dropped. mod internal { + use std::fmt; use std::future::Future; use std::pin::Pin; use std::sync::{Arc, Mutex, Weak}; - use std::task::{self, Poll, ready}; + use std::task::{self, Poll, Waker}; - use pin_project_lite::pin_project; use tokio::sync::oneshot; use tower_service::Service; use super::BoxError; - pin_project! { - #[project = SingletonFutureProj] - pub enum SingletonFuture { - Driving { - #[pin] - future: F, - singleton: DitchGuard, - }, - Waiting { - rx: oneshot::Receiver>, - state: Weak>>, - }, - Made { - svc: Option, - state: Weak>>, - }, - } + pub enum SingletonFuture { + Participating { + id: WaiterId, + state: Arc>>, + rx: Option>>, + }, + Made { + svc: Option, + state: Weak>>, + }, } + impl Unpin for SingletonFuture {} + // XXX: pub because of the enum SingletonFuture - #[derive(Debug)] - pub enum State { + pub enum State { Empty, - Making(Vec>>), + Making(Batch), Made(S), } + impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + State::Empty => f.write_str("Empty"), + State::Making(..) => f.write_str("Making"), + State::Made(svc) => f.debug_tuple("Made").field(svc).finish(), + } + } + } + // XXX: pub because of the enum SingletonFuture - pub struct DitchGuard(pub(super) Weak>>); + pub struct Batch { + future: Option>>, + next_id: WaiterId, + driver: Option, + waiters: Vec>, + } + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub struct WaiterId(usize); + + struct Driver { + id: WaiterId, + waker: Option, + } + + struct Waiter { + id: WaiterId, + waker: Option, + tx: oneshot::Sender>, + } /// A cached service returned from a [`Singleton`]. /// @@ -204,9 +236,9 @@ mod internal { /// code. The type is exposed in the documentation to show which methods /// can be publicly called. #[derive(Debug)] - pub struct Singled { + pub struct Singled { inner: S, - state: Weak>>, + state: Weak>>, } impl Future for SingletonFuture @@ -215,83 +247,77 @@ mod internal { E: Into, S: Clone, { - type Output = Result, SingletonError>; - - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - match self.project() { - SingletonFutureProj::Driving { future, singleton } => { - match ready!(future.poll(cx)) { - Ok(svc) => { - if let Some(state) = singleton.0.upgrade() { - let mut locked = state.lock().unwrap(); - match std::mem::replace(&mut *locked, State::Made(svc.clone())) { - State::Making(waiters) => { - for tx in waiters { - let _ = tx.send(Ok(svc.clone())); - } - } - State::Empty | State::Made(_) => { - // shouldn't happen! - unreachable!() - } - } + type Output = Result, SingletonError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + match &mut *self { + SingletonFuture::Participating { id, state, rx } => { + if let Some(receiver) = rx.as_mut() { + match Pin::new(receiver).poll(cx) { + Poll::Ready(Ok(Ok(svc))) => { + return Poll::Ready(Ok(Singled::new(svc, Arc::downgrade(state)))); + } + Poll::Ready(Ok(Err(err))) => { + return Poll::Ready(Err(SingletonError(err))); } - // take out of the DitchGuard so it doesn't treat as "ditched" - let state = std::mem::replace(&mut singleton.0, Weak::new()); - Poll::Ready(Ok(Singled::new(svc, state))) + Poll::Ready(Err(_canceled)) => { + *rx = None; + } + Poll::Pending => {} } - Err(e) => { - let e = box_error_into_shared(e.into()); - if let Some(state) = singleton.0.upgrade() { - let mut locked = state.lock().unwrap(); - singleton.0 = Weak::new(); - match std::mem::replace(&mut *locked, State::Empty) { - State::Making(waiters) => { - for tx in waiters { - let _ = tx.send(Err(e.clone())); - } - } - State::Empty | State::Made(_) => { - // shouldn't happen! - unreachable!() - } - } + } + + let state_weak = Arc::downgrade(state); + let mut locked = state.lock().unwrap(); + + match &mut *locked { + State::Making(batch) => match batch.poll(*id, cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(svc)) => { + batch.send_result(Ok(svc.clone())); + *locked = State::Made(svc.clone()); + Poll::Ready(Ok(Singled::new(svc, state_weak))) } - Poll::Ready(Err(SingletonError(e))) + Poll::Ready(Err(err)) => { + batch.send_result(Err(err.clone())); + *locked = State::Empty; + Poll::Ready(Err(SingletonError(err))) + } + }, + State::Made(svc) => Poll::Ready(Ok(Singled::new(svc.clone(), state_weak))), + State::Empty => { + unreachable!("singleton participant polled after making was canceled") } } } - SingletonFutureProj::Waiting { rx, state } => match ready!(Pin::new(rx).poll(cx)) { - Ok(Ok(svc)) => Poll::Ready(Ok(Singled::new(svc, state.clone()))), - Ok(Err(e)) => Poll::Ready(Err(SingletonError(e))), - Err(_canceled) => { - Poll::Ready(Err(SingletonError(box_error_into_shared(Canceled.into())))) - } - }, - SingletonFutureProj::Made { svc, state } => { + SingletonFuture::Made { svc, state } => { Poll::Ready(Ok(Singled::new(svc.take().unwrap(), state.clone()))) } } } } - impl Drop for DitchGuard { + impl Drop for SingletonFuture { fn drop(&mut self) { - if let Some(state) = self.0.upgrade() { + if let SingletonFuture::Participating { id, state, .. } = self { if let Ok(mut locked) = state.lock() { - *locked = State::Empty; + if let State::Making(batch) = &mut *locked { + if batch.remove(*id) { + *locked = State::Empty; + } + } } } } } - impl Singled { - fn new(inner: S, state: Weak>>) -> Self { + impl Singled { + fn new(inner: S, state: Weak>>) -> Self { Singled { inner, state } } } - impl Service for Singled + impl Service for Singled where S: Service, { @@ -317,9 +343,139 @@ mod internal { } } + impl Batch { + pub(super) fn new(future: F) -> Self { + Batch { + future: Some(Box::pin(future)), + next_id: WaiterId(0), + driver: None, + waiters: Vec::new(), + } + } + + pub(super) fn register_driver(&mut self) -> WaiterId { + let id = self.next_id; + self.next_id.0 += 1; + self.driver = Some(Driver { id, waker: None }); + id + } + + pub(super) fn register_waiter( + &mut self, + ) -> (WaiterId, oneshot::Receiver>) { + let id = self.next_id; + self.next_id.0 += 1; + let (tx, rx) = oneshot::channel(); + self.waiters.push(Waiter { + id, + waker: None, + tx, + }); + (id, rx) + } + + fn remove(&mut self, id: WaiterId) -> bool { + if let Some(pos) = self.waiters.iter().position(|waiter| waiter.id == id) { + self.waiters.swap_remove(pos); + return false; + } + + if self.driver.as_ref().is_some_and(|driver| driver.id == id) { + if let Some(waiter) = self.waiters.pop() { + let waker = waiter.waker; + self.driver = Some(Driver { + id: waiter.id, + waker, + }); + self.wake_driver(); + return false; + } + + self.driver = None; + self.future = None; + return true; + } + + false + } + + fn poll( + &mut self, + id: WaiterId, + cx: &mut task::Context<'_>, + ) -> Poll> + where + F: Future>, + E: Into, + S: Clone, + { + if !self.driver.as_ref().is_some_and(|driver| driver.id == id) { + self.store_waker(id, cx.waker()); + return Poll::Pending; + } + + let future = self.future.as_mut().expect("batch future missing"); + match future.as_mut().poll(cx) { + Poll::Pending => { + self.store_driver_waker(cx.waker()); + Poll::Pending + } + Poll::Ready(Ok(svc)) => { + self.future = None; + Poll::Ready(Ok(svc)) + } + Poll::Ready(Err(err)) => { + let err = box_error_into_shared(err.into()); + self.future = None; + Poll::Ready(Err(err)) + } + } + } + + fn send_result(&mut self, result: Result) + where + S: Clone, + { + for waiter in std::mem::take(&mut self.waiters) { + let _ = waiter.tx.send(result.clone()); + } + } + + fn store_waker(&mut self, id: WaiterId, waker: &Waker) { + if let Some(waiter) = self.waiters.iter_mut().find(|waiter| waiter.id == id) { + if waiter + .waker + .as_ref() + .is_none_or(|current| !current.will_wake(waker)) + { + waiter.waker = Some(waker.clone()); + } + } + } + + fn store_driver_waker(&mut self, waker: &Waker) { + if let Some(driver) = &mut self.driver { + if driver + .waker + .as_ref() + .is_none_or(|current| !current.will_wake(waker)) + { + driver.waker = Some(waker.clone()); + } + } + } + + fn wake_driver(&mut self) { + if let Some(driver) = &mut self.driver { + if let Some(waker) = driver.waker.take() { + waker.wake(); + } + } + } + } + // An opaque error type. By not exposing the type, nor being specifically - // Box, we can _change_ the type once we no longer need the Canceled - // error type. This will be possible with the refactor to baton passing. + // Box, we can change the inner representation later. #[derive(Debug)] pub struct SingletonError(pub(super) SharedError); @@ -346,17 +502,6 @@ mod internal { fn box_error_into_shared(error: BoxError) -> SharedError { error.into() } - - #[derive(Debug)] - struct Canceled; - - impl std::fmt::Display for Canceled { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("singleton connection canceled") - } - } - - impl std::error::Error for Canceled {} } #[cfg(test)] @@ -521,9 +666,8 @@ mod tests { assert!(std::ptr::addr_eq(src1, src3)); } - // TODO: this should be able to be improved with a cooperative baton refactor #[tokio::test] - async fn cancel_driver_cancels_all() { + async fn cancel_driver_hands_off_to_waiter() { let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>(); let mut singleton = Singleton::new(mock_svc); @@ -543,9 +687,88 @@ mod tests { let ((), send_response) = handle.next_request().await.unwrap(); send_response.send_response("svc"); - assert_eq!( - fut2.await.unwrap_err().0.to_string(), - "singleton connection canceled" - ); + fut2.await.unwrap(); + } + + #[tokio::test] + async fn cancel_driver_promotes_parked_waiter() { + let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>(); + let mut singleton = Singleton::new(mock_svc); + + std::future::poll_fn(|cx| singleton.poll_ready(cx)) + .await + .unwrap(); + let mut fut1 = singleton.call(()); + let fut2 = singleton.call(()); + + // Start the make future so dropping fut1 below exercises driver + // handoff during an in-flight connection attempt. + std::future::poll_fn(|cx| { + assert!(Pin::new(&mut fut1).poll(cx).is_pending()); + Poll::Ready(()) + }) + .await; + + // Poll the waiter once so it parks and stores a waker in the batch. + // This covers promotion of an already-parked waiter, not just a + // registered-but-never-polled waiter. + let mut waiter = tokio_test::task::spawn(fut2); + assert!(waiter.poll().is_pending()); + assert!(!waiter.is_woken()); + + // When the driver is dropped, the promoted parked waiter must be + // woken so it can take over driving the shared make future. + drop(fut1); + assert!(waiter.is_woken()); + + // Poll after promotion, before the maker responds, so the waiter + // actually takes over as driver and stores its own waker. + assert!(waiter.poll().is_pending()); + + let ((), send_response) = handle.next_request().await.unwrap(); + send_response.send_response("svc"); + + assert!(waiter.is_woken()); + match waiter.poll() { + Poll::Ready(Ok(_)) => {} + other => panic!("expected promoted waiter to complete, got {other:?}"), + } + } + + #[tokio::test] + async fn cancel_all_waiters_clears_singleton() { + let (mock_svc, _handle) = tower_test::mock::pair::<(), &'static str>(); + let mut singleton = Singleton::new(mock_svc); + + std::future::poll_fn(|cx| singleton.poll_ready(cx)) + .await + .unwrap(); + let fut1 = singleton.call(()); + let fut2 = singleton.call(()); + + drop(fut1); + drop(fut2); + + assert!(singleton.is_empty()); + } + + #[tokio::test] + async fn cancel_non_driver_waiter_does_not_block_others() { + let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>(); + let mut singleton = Singleton::new(mock_svc); + + std::future::poll_fn(|cx| singleton.poll_ready(cx)) + .await + .unwrap(); + let fut1 = singleton.call(()); + let fut2 = singleton.call(()); + let fut3 = singleton.call(()); + drop(fut2); + + let ((), send_response) = handle.next_request().await.unwrap(); + send_response.send_response("svc"); + + fut1.await.unwrap(); + fut3.await.unwrap(); } }