From df8d31944130f19aab5955b02b61c55209c6cbc4 Mon Sep 17 00:00:00 2001 From: Vittorio Distefano Date: Sat, 21 Mar 2026 14:55:17 +0100 Subject: [PATCH] cache: deduplicate load_or_fetch by key --- src/routing/cache.rs | 8 +- src/routing/fetch.rs | 227 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 193 insertions(+), 42 deletions(-) diff --git a/src/routing/cache.rs b/src/routing/cache.rs index 0ef9c27..4034be3 100644 --- a/src/routing/cache.rs +++ b/src/routing/cache.rs @@ -4,10 +4,11 @@ use std::collections::HashMap; use std::mem::size_of; use std::ops::Deref; use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; use std::sync::OnceLock; use serde::{Deserialize, Serialize}; -use tokio::sync::{RwLock, RwLockReadGuard}; +use tokio::sync::{Mutex, RwLock, RwLockReadGuard}; use super::bbox::BoundingBox; use super::network::RoadNetwork; @@ -15,6 +16,7 @@ use super::network::RoadNetwork; pub const CACHE_VERSION: u32 = 5; static NETWORK_CACHE: OnceLock>> = OnceLock::new(); +static IN_FLIGHT_LOADS: OnceLock>>>> = OnceLock::new(); static CACHE_HITS: AtomicU64 = AtomicU64::new(0); static CACHE_MISSES: AtomicU64 = AtomicU64::new(0); @@ -22,6 +24,10 @@ pub(crate) fn cache() -> &'static RwLock> { NETWORK_CACHE.get_or_init(|| RwLock::new(HashMap::new())) } +pub(crate) fn in_flight_loads() -> &'static Mutex>>> { + IN_FLIGHT_LOADS.get_or_init(|| Mutex::new(HashMap::new())) +} + pub(crate) fn record_hit() { CACHE_HITS.fetch_add(1, Ordering::Relaxed); } diff --git a/src/routing/fetch.rs b/src/routing/fetch.rs index 9a2c929..eb2b515 100644 --- a/src/routing/fetch.rs +++ b/src/routing/fetch.rs @@ -1,15 +1,17 @@ //! Overpass API fetching and caching for road networks. use std::collections::HashMap; +use std::future::Future; use std::path::Path; +use std::sync::Arc; -use tokio::sync::mpsc::Sender; +use tokio::sync::{mpsc::Sender, Mutex, OwnedMutexGuard}; use tracing::{debug, info}; use super::bbox::BoundingBox; use super::cache::{ - cache, record_hit, record_miss, CachedEdge, CachedNetwork, CachedNode, NetworkRef, - CACHE_VERSION, + cache, in_flight_loads, record_hit, record_miss, CachedEdge, CachedNetwork, CachedNode, + NetworkRef, CACHE_VERSION, }; use super::config::{ConnectivityPolicy, NetworkConfig}; use super::coord::Coord; @@ -49,48 +51,36 @@ impl RoadNetwork { let _ = tx.send(RoutingProgress::CheckingCache { percent: 5 }).await; } - { - let mut cache_guard = cache().write().await; - if !cache_guard.contains_key(&cache_key) { - tokio::fs::create_dir_all(&config.cache_dir).await?; - let cache_path = config.cache_dir.join(format!("{}.json", cache_key)); - - let network = if tokio::fs::try_exists(&cache_path).await.unwrap_or(false) { - info!("Loading road network from file cache: {:?}", cache_path); - if let Some(tx) = progress { - let _ = tx.send(RoutingProgress::CheckingCache { percent: 8 }).await; - } - match Self::load_from_file(&cache_path, config).await { - Ok(n) => { - if let Some(tx) = progress { - let _ = tx - .send(RoutingProgress::BuildingGraph { percent: 50 }) - .await; - } - n - } - Err(e) => { - info!("File cache invalid ({}), downloading fresh", e); - let n = Self::fetch_from_api(bbox, config, progress).await?; - n.save_to_file(&cache_path).await?; - info!("Saved road network to file cache: {:?}", cache_path); - n + Self::load_or_insert(cache_key, async { + tokio::fs::create_dir_all(&config.cache_dir).await?; + let cache_path = config.cache_dir.join(format!("{}.json", bbox.cache_key())); + + if tokio::fs::try_exists(&cache_path).await.unwrap_or(false) { + info!("Loading road network from file cache: {:?}", cache_path); + if let Some(tx) = progress { + let _ = tx.send(RoutingProgress::CheckingCache { percent: 8 }).await; + } + match Self::load_from_file(&cache_path, config).await { + Ok(network) => { + if let Some(tx) = progress { + let _ = tx + .send(RoutingProgress::BuildingGraph { percent: 50 }) + .await; } + return Ok(network); } - } else { - info!("Downloading road network from Overpass API"); - let n = Self::fetch_from_api(bbox, config, progress).await?; - n.save_to_file(&cache_path).await?; - info!("Saved road network to file cache: {:?}", cache_path); - n - }; - - cache_guard.insert(cache_key.clone(), network); + Err(e) => info!("File cache invalid ({}), downloading fresh", e), + } + } else { + info!("Downloading road network from Overpass API"); } - } - let cache_guard = cache().read().await; - Ok(NetworkRef::new(cache_guard, cache_key)) + let network = Self::fetch_from_api(bbox, config, progress).await?; + network.save_to_file(&cache_path).await?; + info!("Saved road network to file cache: {:?}", cache_path); + Ok(network) + }) + .await } pub async fn fetch( @@ -434,6 +424,47 @@ out body;"#, Ok(()) } + + async fn load_or_insert(cache_key: String, load: F) -> Result + where + F: Future>, + { + if let Some(cached) = Self::get_cached_network(cache_key.clone()).await { + return Ok(cached); + } + + record_miss(); + + let (slot, _slot_guard) = acquire_in_flight_slot(&cache_key).await; + + if let Some(cached) = Self::get_cached_network(cache_key.clone()).await { + cleanup_in_flight_slot(&cache_key, &slot).await; + return Ok(cached); + } + + let network = load.await?; + + { + let mut cache_guard = cache().write().await; + cache_guard.entry(cache_key.clone()).or_insert(network); + } + + cleanup_in_flight_slot(&cache_key, &slot).await; + + Self::get_cached_network(cache_key).await.ok_or_else(|| { + RoutingError::Network("cached network disappeared after insertion".to_string()) + }) + } + + async fn get_cached_network(cache_key: String) -> Option { + let cache_guard = cache().read().await; + if cache_guard.contains_key(&cache_key) { + record_hit(); + Some(NetworkRef::new(cache_guard, cache_key)) + } else { + None + } + } } impl RoadNetwork { @@ -442,3 +473,117 @@ impl RoadNetwork { Self::load_or_fetch(bbox, &NetworkConfig::default(), None).await } } + +async fn acquire_in_flight_slot(cache_key: &str) -> (Arc>, OwnedMutexGuard<()>) { + let slot = { + let mut in_flight = in_flight_loads().lock().await; + in_flight + .entry(cache_key.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone() + }; + + let guard = slot.clone().lock_owned().await; + (slot, guard) +} + +async fn cleanup_in_flight_slot(cache_key: &str, slot: &Arc>) { + let mut in_flight = in_flight_loads().lock().await; + let should_remove = in_flight + .get(cache_key) + .map(|current| Arc::ptr_eq(current, slot) && Arc::strong_count(slot) == 2) + .unwrap_or(false); + + if should_remove { + in_flight.remove(cache_key); + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use std::time::{Duration, Instant}; + + use tokio::time::sleep; + + use super::*; + + fn test_network() -> RoadNetwork { + RoadNetwork::from_test_data(&[(0.0, 0.0), (0.0, 0.01)], &[(0, 1, 60.0, 1_000.0)]) + } + + async fn reset_test_state() { + RoadNetwork::clear_cache().await; + in_flight_loads().lock().await.clear(); + } + + #[tokio::test] + async fn load_or_insert_allows_different_keys_to_progress_concurrently() { + reset_test_state().await; + + let start = Instant::now(); + let first = async { + RoadNetwork::load_or_insert("region-a".to_string(), async { + sleep(Duration::from_millis(100)).await; + Ok(test_network()) + }) + .await + .map(|network| network.node_count()) + }; + let second = async { + RoadNetwork::load_or_insert("region-b".to_string(), async { + sleep(Duration::from_millis(100)).await; + Ok(test_network()) + }) + .await + .map(|network| network.node_count()) + }; + let (left, right) = tokio::join!(first, second); + left.expect("first load should succeed"); + right.expect("second load should succeed"); + + assert!( + start.elapsed() < Duration::from_millis(180), + "different keys should not serialize slow loads" + ); + } + + #[tokio::test] + async fn load_or_insert_deduplicates_same_key_work() { + reset_test_state().await; + + let loads = Arc::new(AtomicUsize::new(0)); + + let first = { + let loads = loads.clone(); + async move { + RoadNetwork::load_or_insert("region-a".to_string(), async move { + loads.fetch_add(1, Ordering::Relaxed); + sleep(Duration::from_millis(50)).await; + Ok(test_network()) + }) + .await + .map(|network| network.node_count()) + } + }; + let second = { + let loads = loads.clone(); + async move { + RoadNetwork::load_or_insert("region-a".to_string(), async move { + loads.fetch_add(1, Ordering::Relaxed); + sleep(Duration::from_millis(50)).await; + Ok(test_network()) + }) + .await + .map(|network| network.node_count()) + } + }; + + let (left, right) = tokio::join!(first, second); + left.expect("first load should succeed"); + right.expect("second load should succeed"); + + assert_eq!(loads.load(Ordering::Relaxed), 1); + } +}