From a9052dc91d382227a738923af388cca4bf8cf7d1 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 29 Jun 2026 11:33:21 -0400 Subject: [PATCH] fix(pool): preserve Cache readiness with clones --- src/client/pool/cache.rs | 194 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 185 insertions(+), 9 deletions(-) diff --git a/src/client/pool/cache.rs b/src/client/pool/cache.rs index 699fe6f0..22a64fe1 100644 --- a/src/client/pool/cache.rs +++ b/src/client/pool/cache.rs @@ -54,6 +54,7 @@ mod internal { connector: M, shared: Arc>>, events: Ev, + ready: Ready, } /// A builder to configure a `Cache`. @@ -85,6 +86,12 @@ mod internal { // todo: on_idle } + #[derive(Debug)] + enum Ready { + None, + Cached(S), + } + pub enum CacheFuture where M: Service, @@ -147,6 +154,7 @@ mod internal { Cache { connector, events: self.events, + ready: Ready::None, shared: Arc::new(Mutex::new(Shared { services: Vec::new(), waiters: Vec::new(), @@ -166,12 +174,19 @@ mod internal { where F: FnMut(&mut M::Response) -> bool, { + let mut predicate = predicate; + if let Ready::Cached(svc) = &mut self.ready { + if !predicate(svc) { + self.ready = Ready::None; + } + } + self.shared.lock().unwrap().services.retain_mut(predicate); } /// Check whether this cache has no cached services. pub fn is_empty(&self) -> bool { - self.shared.lock().unwrap().services.is_empty() + matches!(self.ready, Ready::None) && self.shared.lock().unwrap().services.is_empty() } } @@ -187,23 +202,38 @@ mod internal { type Future = CacheFuture; fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { - if !self.shared.lock().unwrap().services.is_empty() { - Poll::Ready(Ok(())) - } else { - self.connector.poll_ready(cx) + match self.ready { + Ready::Cached(_) => return Poll::Ready(Ok(())), + Ready::None => {} } + + if let Some(svc) = self.shared.lock().unwrap().take() { + self.ready = Ready::Cached(svc); + return Poll::Ready(Ok(())); + } + + self.connector.poll_ready(cx) } fn call(&mut self, target: Dst) -> Self::Future { // 1. If already cached, easy! - let waiter = { - let mut locked = self.shared.lock().unwrap(); - if let Some(found) = locked.take() { + match std::mem::replace(&mut self.ready, Ready::None) { + Ready::Cached(svc) => { return CacheFuture::Cached { - svc: Some(Cached::new(found, Arc::downgrade(&self.shared))), + svc: Some(Cached::new(svc, Arc::downgrade(&self.shared))), }; } + Ready::None => { + if let Some(svc) = self.shared.lock().unwrap().take() { + return CacheFuture::Cached { + svc: Some(Cached::new(svc, Arc::downgrade(&self.shared))), + }; + } + } + } + let waiter = { + let mut locked = self.shared.lock().unwrap(); let (tx, rx) = oneshot::channel(); locked.waiters.push(tx); rx @@ -229,6 +259,20 @@ mod internal { connector: self.connector.clone(), events: self.events.clone(), shared: self.shared.clone(), + ready: Ready::None, + } + } + } + + impl Drop for Cache + where + M: Service, + { + fn drop(&mut self) { + if let Ready::Cached(svc) = std::mem::replace(&mut self.ready, Ready::None) { + if let Ok(mut shared) = self.shared.lock() { + shared.put(svc); + } } } } @@ -440,6 +484,13 @@ mod events { #[cfg(test)] mod tests { + use std::convert::Infallible; + use std::sync::{ + Arc, Mutex, + atomic::{AtomicUsize, Ordering}, + }; + use std::task::{self, Poll}; + use futures_util::future; use tower_service::Service; use tower_test::assert_request_eq; @@ -491,4 +542,129 @@ mod tests { let cached = f.await.expect("call"); drop(cached); } + + #[tokio::test] + async fn clone_readiness_reserves_idle_service() { + let connector = StrictConnector::default(); + let poll_ready_count = connector.poll_ready_count.clone(); + let calls = connector.calls.clone(); + let mut cache = super::builder().build(connector); + + std::future::poll_fn(|cx| cache.poll_ready(cx)) + .await + .unwrap(); + let cached = cache.call(1).await.unwrap(); + assert_eq!(*cached.inner(), 0); + drop(cached); + + let mut a = cache.clone(); + let mut b = cache.clone(); + + std::future::poll_fn(|cx| a.poll_ready(cx)).await.unwrap(); + assert_eq!(poll_ready_count.load(Ordering::SeqCst), 1); + assert!(!a.is_empty()); + + std::future::poll_fn(|cx| b.poll_ready(cx)).await.unwrap(); + assert_eq!(poll_ready_count.load(Ordering::SeqCst), 2); + + let a_cached = a.call(10).await.unwrap(); + assert_eq!(*a_cached.inner(), 0); + + let b_cached = b.call(20).await.unwrap(); + assert_eq!(*b_cached.inner(), 1); + + assert_eq!(*calls.lock().unwrap(), vec![1, 20]); + } + + #[tokio::test] + async fn dropped_ready_slot_returns_idle_service() { + let connector = StrictConnector::default(); + let poll_ready_count = connector.poll_ready_count.clone(); + let mut cache = super::builder().build(connector); + + std::future::poll_fn(|cx| cache.poll_ready(cx)) + .await + .unwrap(); + let cached = cache.call(1).await.unwrap(); + drop(cached); + + let mut clone = cache.clone(); + std::future::poll_fn(|cx| clone.poll_ready(cx)) + .await + .unwrap(); + drop(clone); + + std::future::poll_fn(|cx| cache.poll_ready(cx)) + .await + .unwrap(); + assert_eq!(poll_ready_count.load(Ordering::SeqCst), 1); + + let cached = cache.call(2).await.unwrap(); + assert_eq!(*cached.inner(), 0); + } + + #[tokio::test] + async fn retain_checks_ready_slot() { + let connector = StrictConnector::default(); + let poll_ready_count = connector.poll_ready_count.clone(); + let mut cache = super::builder().build(connector); + + std::future::poll_fn(|cx| cache.poll_ready(cx)) + .await + .unwrap(); + let cached = cache.call(1).await.unwrap(); + drop(cached); + + std::future::poll_fn(|cx| cache.poll_ready(cx)) + .await + .unwrap(); + assert!(!cache.is_empty()); + + cache.retain(|svc| *svc != 0); + assert!(cache.is_empty()); + + std::future::poll_fn(|cx| cache.poll_ready(cx)) + .await + .unwrap(); + assert_eq!(poll_ready_count.load(Ordering::SeqCst), 2); + } + + #[derive(Default)] + struct StrictConnector { + poll_ready_count: Arc, + next: Arc, + calls: Arc>>, + ready: bool, + } + + impl Clone for StrictConnector { + fn clone(&self) -> Self { + StrictConnector { + poll_ready_count: self.poll_ready_count.clone(), + next: self.next.clone(), + calls: self.calls.clone(), + ready: false, + } + } + } + + impl Service for StrictConnector { + type Response = usize; + type Error = Infallible; + type Future = std::future::Ready>; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll> { + self.ready = true; + self.poll_ready_count.fetch_add(1, Ordering::SeqCst); + Poll::Ready(Ok(())) + } + + fn call(&mut self, target: usize) -> Self::Future { + assert!(self.ready, "connector called without poll_ready"); + self.ready = false; + self.calls.lock().unwrap().push(target); + let id = self.next.fetch_add(1, Ordering::SeqCst); + std::future::ready(Ok(id)) + } + } }