diff --git a/Cargo.lock b/Cargo.lock index 2e9f76f..30d4dad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1443,6 +1443,17 @@ dependencies = [ "serde", ] +[[package]] +name = "async-lock" +version = "3.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f7f2596bd5b78a9fec8088ccd89180d7f9f55b94b0576823bbbdc72ee8311" +dependencies = [ + "event-listener", + "event-listener-strategy", + "pin-project-lite", +] + [[package]] name = "async-recursion" version = "1.1.1" @@ -3077,6 +3088,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", +] + [[package]] name = "eyre" version = "0.6.12" @@ -5313,6 +5334,26 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "moka" +version = "0.12.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046" +dependencies = [ + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "event-listener", + "futures-util", + "parking_lot", + "portable-atomic", + "smallvec", + "tagptr", + "uuid 1.17.0", +] + [[package]] name = "multer" version = "3.1.0" @@ -8921,6 +8962,7 @@ dependencies = [ "base64 0.22.1", "clap", "futures", + "moka", "rain-math-float", "rain_orderbook_app_settings", "rain_orderbook_bindings", @@ -9282,6 +9324,12 @@ dependencies = [ "libc", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tap" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index 267ac21..ec2986d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ sqlx = { version = "0.8.6", features = ["runtime-tokio", "sqlite", "migrate"] } argon2 = "0.5.3" base64 = "0.22.1" clap = { version = "4.5.58", features = ["derive"] } +moka = { version = "0.12", features = ["future"] } toml = "0.8" reqwest = { version = "0.13.2", features = ["json"] } rain_orderbook_js_api = { path = "lib/rain.orderbook/crates/js_api", default-features = false } diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..9a589ec --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,187 @@ +use moka::future::Cache; +use std::future::Future; +use std::sync::Arc; +use std::time::Duration; + +pub(crate) struct AppCache(Cache) +where + K: std::hash::Hash + Eq + Send + Sync + 'static, + V: Clone + Send + Sync + 'static; + +impl AppCache +where + K: std::hash::Hash + Eq + Send + Sync + 'static, + V: Clone + Send + Sync + 'static, +{ + pub(crate) fn new(max_capacity: u64, ttl: Duration) -> Self { + Self( + Cache::builder() + .max_capacity(max_capacity) + .time_to_live(ttl) + .build(), + ) + } + + pub(crate) async fn get(&self, key: &K) -> Option { + self.0.get(key).await + } + + pub(crate) async fn insert(&self, key: K, value: V) { + self.0.insert(key, value).await + } + + pub(crate) async fn get_or_try_insert(&self, key: K, fetch: F) -> Result> + where + F: FnOnce() -> Fut, + Fut: Future>, + E: Send + Sync + 'static, + { + self.0.try_get_with(key, async move { fetch().await }).await + } + + pub(crate) fn invalidate_all(&self) { + self.0.invalidate_all() + } +} + +trait Invalidatable: Send + Sync { + fn invalidate_all(&self); +} + +impl Invalidatable for Cache +where + K: std::hash::Hash + Eq + Send + Sync + 'static, + V: Clone + Send + Sync + 'static, +{ + fn invalidate_all(&self) { + Cache::invalidate_all(self) + } +} + +pub(crate) struct CacheGroup { + caches: Vec>, +} + +impl CacheGroup { + pub(crate) fn new() -> Self { + Self { caches: Vec::new() } + } + + pub(crate) fn register(&mut self, cache: &AppCache) + where + K: std::hash::Hash + Eq + Send + Sync + 'static, + V: Clone + Send + Sync + 'static, + { + self.caches.push(Arc::new(cache.0.clone())); + } + + pub(crate) fn invalidate_all(&self) { + for cache in &self.caches { + cache.invalidate_all(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[rocket::async_test] + async fn test_app_cache_insert_and_get() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + cache.insert("key", 42).await; + assert_eq!(cache.get(&"key").await, Some(42)); + } + + #[rocket::async_test] + async fn test_app_cache_get_returns_none_for_missing_key() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + assert!(cache.get(&"missing").await.is_none()); + } + + #[rocket::async_test] + async fn test_app_cache_invalidate_all_clears_entries() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + cache.insert("a", 1).await; + cache.insert("b", 2).await; + cache.invalidate_all(); + tokio::task::yield_now().await; + assert!(cache.get(&"a").await.is_none()); + assert!(cache.get(&"b").await.is_none()); + } + + #[rocket::async_test] + async fn test_get_or_try_insert_calls_fetch_on_miss() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + let result: Result> = + cache.get_or_try_insert("key", || async { Ok(42) }).await; + assert_eq!(result.unwrap(), 42); + assert_eq!(cache.get(&"key").await, Some(42)); + } + + #[rocket::async_test] + async fn test_get_or_try_insert_returns_cached_on_hit() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + cache.insert("key", 42).await; + let result: Result> = cache + .get_or_try_insert("key", || async { panic!("fetch should not be called") }) + .await; + assert_eq!(result.unwrap(), 42); + } + + #[rocket::async_test] + async fn test_get_or_try_insert_does_not_cache_errors() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + let result: Result> = cache + .get_or_try_insert("key", || async { Err("fail".to_string()) }) + .await; + assert!(result.is_err()); + assert!(cache.get(&"key").await.is_none()); + } + + #[rocket::async_test] + async fn test_get_or_try_insert_coalesces_concurrent_misses() { + let cache: Arc> = + Arc::new(AppCache::new(10, Duration::from_secs(60))); + let fetch_count = Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let mut tasks = tokio::task::JoinSet::new(); + + for _ in 0..10 { + let cache = cache.clone(); + let fetch_count = fetch_count.clone(); + tasks.spawn(async move { + cache + .get_or_try_insert("key".to_string(), || async move { + fetch_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + tokio::time::sleep(Duration::from_millis(25)).await; + Ok::<_, String>(42) + }) + .await + }); + } + + while let Some(result) = tasks.join_next().await { + assert_eq!(result.unwrap().unwrap(), 42); + } + + assert_eq!(fetch_count.load(std::sync::atomic::Ordering::SeqCst), 1); + assert_eq!(cache.get(&"key".to_string()).await, Some(42)); + } + + #[rocket::async_test] + async fn test_cache_group_invalidate_all_clears_registered_caches() { + let cache_a: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + let cache_b: AppCache = AppCache::new(10, Duration::from_secs(60)); + cache_a.insert("x", 10).await; + cache_b.insert(1, "hello".into()).await; + + let mut group = CacheGroup::new(); + group.register(&cache_a); + group.register(&cache_b); + group.invalidate_all(); + + tokio::task::yield_now().await; + assert!(cache_a.get(&"x").await.is_none()); + assert!(cache_b.get(&1).await.is_none()); + } +} diff --git a/src/main.rs b/src/main.rs index 4b6af71..a7fa8f8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ extern crate rocket; mod auth; +mod cache; mod catchers; mod cli; mod config;