Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 185 additions & 9 deletions src/client/pool/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ mod internal {
connector: M,
shared: Arc<Mutex<Shared<M::Response>>>,
events: Ev,
ready: Ready<M::Response>,
}

/// A builder to configure a `Cache`.
Expand Down Expand Up @@ -85,6 +86,12 @@ mod internal {
// todo: on_idle
}

#[derive(Debug)]
enum Ready<S> {
None,
Cached(S),
}

pub enum CacheFuture<M, Dst, Ev>
where
M: Service<Dst>,
Expand Down Expand Up @@ -147,6 +154,7 @@ mod internal {
Cache {
connector,
events: self.events,
ready: Ready::None,
shared: Arc::new(Mutex::new(Shared {
services: Vec::new(),
waiters: Vec::new(),
Expand All @@ -166,12 +174,19 @@ mod internal {
where
F: FnMut(&mut M::Response) -> bool,
{
let mut predicate = predicate;
if let Ready::Cached(svc) = &mut self.ready {
if !predicate(svc) {
self.ready = Ready::None;
}
}

self.shared.lock().unwrap().services.retain_mut(predicate);
}

/// Check whether this cache has no cached services.
pub fn is_empty(&self) -> bool {
self.shared.lock().unwrap().services.is_empty()
matches!(self.ready, Ready::None) && self.shared.lock().unwrap().services.is_empty()
}
}

Expand All @@ -187,23 +202,38 @@ mod internal {
type Future = CacheFuture<M, Dst, Ev>;

fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
if !self.shared.lock().unwrap().services.is_empty() {
Poll::Ready(Ok(()))
} else {
self.connector.poll_ready(cx)
match self.ready {
Ready::Cached(_) => return Poll::Ready(Ok(())),
Ready::None => {}
}

if let Some(svc) = self.shared.lock().unwrap().take() {
self.ready = Ready::Cached(svc);
return Poll::Ready(Ok(()));
}

self.connector.poll_ready(cx)
}

fn call(&mut self, target: Dst) -> Self::Future {
// 1. If already cached, easy!
let waiter = {
let mut locked = self.shared.lock().unwrap();
if let Some(found) = locked.take() {
match std::mem::replace(&mut self.ready, Ready::None) {
Ready::Cached(svc) => {
return CacheFuture::Cached {
svc: Some(Cached::new(found, Arc::downgrade(&self.shared))),
svc: Some(Cached::new(svc, Arc::downgrade(&self.shared))),
};
}
Ready::None => {
if let Some(svc) = self.shared.lock().unwrap().take() {
return CacheFuture::Cached {
svc: Some(Cached::new(svc, Arc::downgrade(&self.shared))),
};
}
}
}

let waiter = {
let mut locked = self.shared.lock().unwrap();
let (tx, rx) = oneshot::channel();
locked.waiters.push(tx);
rx
Expand All @@ -229,6 +259,20 @@ mod internal {
connector: self.connector.clone(),
events: self.events.clone(),
shared: self.shared.clone(),
ready: Ready::None,
}
}
}

impl<M, Dst, Ev> Drop for Cache<M, Dst, Ev>
where
M: Service<Dst>,
{
fn drop(&mut self) {
if let Ready::Cached(svc) = std::mem::replace(&mut self.ready, Ready::None) {
if let Ok(mut shared) = self.shared.lock() {
shared.put(svc);
}
}
}
}
Expand Down Expand Up @@ -440,6 +484,13 @@ mod events {

#[cfg(test)]
mod tests {
use std::convert::Infallible;
use std::sync::{
Arc, Mutex,
atomic::{AtomicUsize, Ordering},
};
use std::task::{self, Poll};

use futures_util::future;
use tower_service::Service;
use tower_test::assert_request_eq;
Expand Down Expand Up @@ -491,4 +542,129 @@ mod tests {
let cached = f.await.expect("call");
drop(cached);
}

#[tokio::test]
async fn clone_readiness_reserves_idle_service() {
let connector = StrictConnector::default();
let poll_ready_count = connector.poll_ready_count.clone();
let calls = connector.calls.clone();
let mut cache = super::builder().build(connector);

std::future::poll_fn(|cx| cache.poll_ready(cx))
.await
.unwrap();
let cached = cache.call(1).await.unwrap();
assert_eq!(*cached.inner(), 0);
drop(cached);

let mut a = cache.clone();
let mut b = cache.clone();

std::future::poll_fn(|cx| a.poll_ready(cx)).await.unwrap();
assert_eq!(poll_ready_count.load(Ordering::SeqCst), 1);
assert!(!a.is_empty());

std::future::poll_fn(|cx| b.poll_ready(cx)).await.unwrap();
assert_eq!(poll_ready_count.load(Ordering::SeqCst), 2);

let a_cached = a.call(10).await.unwrap();
assert_eq!(*a_cached.inner(), 0);

let b_cached = b.call(20).await.unwrap();
assert_eq!(*b_cached.inner(), 1);

assert_eq!(*calls.lock().unwrap(), vec![1, 20]);
}

#[tokio::test]
async fn dropped_ready_slot_returns_idle_service() {
let connector = StrictConnector::default();
let poll_ready_count = connector.poll_ready_count.clone();
let mut cache = super::builder().build(connector);

std::future::poll_fn(|cx| cache.poll_ready(cx))
.await
.unwrap();
let cached = cache.call(1).await.unwrap();
drop(cached);

let mut clone = cache.clone();
std::future::poll_fn(|cx| clone.poll_ready(cx))
.await
.unwrap();
drop(clone);

std::future::poll_fn(|cx| cache.poll_ready(cx))
.await
.unwrap();
assert_eq!(poll_ready_count.load(Ordering::SeqCst), 1);

let cached = cache.call(2).await.unwrap();
assert_eq!(*cached.inner(), 0);
}

#[tokio::test]
async fn retain_checks_ready_slot() {
let connector = StrictConnector::default();
let poll_ready_count = connector.poll_ready_count.clone();
let mut cache = super::builder().build(connector);

std::future::poll_fn(|cx| cache.poll_ready(cx))
.await
.unwrap();
let cached = cache.call(1).await.unwrap();
drop(cached);

std::future::poll_fn(|cx| cache.poll_ready(cx))
.await
.unwrap();
assert!(!cache.is_empty());

cache.retain(|svc| *svc != 0);
assert!(cache.is_empty());

std::future::poll_fn(|cx| cache.poll_ready(cx))
.await
.unwrap();
assert_eq!(poll_ready_count.load(Ordering::SeqCst), 2);
}

#[derive(Default)]
struct StrictConnector {
poll_ready_count: Arc<AtomicUsize>,
next: Arc<AtomicUsize>,
calls: Arc<Mutex<Vec<usize>>>,
ready: bool,
}

impl Clone for StrictConnector {
fn clone(&self) -> Self {
StrictConnector {
poll_ready_count: self.poll_ready_count.clone(),
next: self.next.clone(),
calls: self.calls.clone(),
ready: false,
}
}
}

impl Service<usize> for StrictConnector {
type Response = usize;
type Error = Infallible;
type Future = std::future::Ready<Result<usize, Infallible>>;

fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.ready = true;
self.poll_ready_count.fetch_add(1, Ordering::SeqCst);
Poll::Ready(Ok(()))
}

fn call(&mut self, target: usize) -> Self::Future {
assert!(self.ready, "connector called without poll_ready");
self.ready = false;
self.calls.lock().unwrap().push(target);
let id = self.next.fetch_add(1, Ordering::SeqCst);
std::future::ready(Ok(id))
}
}
}
Loading