Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/routing/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,30 @@ use std::collections::HashMap;
use std::mem::size_of;
use std::ops::Deref;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::OnceLock;

use serde::{Deserialize, Serialize};
use tokio::sync::{RwLock, RwLockReadGuard};
use tokio::sync::{Mutex, RwLock, RwLockReadGuard};

use super::bbox::BoundingBox;
use super::network::RoadNetwork;

pub const CACHE_VERSION: u32 = 5;

static NETWORK_CACHE: OnceLock<RwLock<HashMap<String, RoadNetwork>>> = OnceLock::new();
static IN_FLIGHT_LOADS: OnceLock<Mutex<HashMap<String, Arc<Mutex<()>>>>> = OnceLock::new();
static CACHE_HITS: AtomicU64 = AtomicU64::new(0);
static CACHE_MISSES: AtomicU64 = AtomicU64::new(0);

pub(crate) fn cache() -> &'static RwLock<HashMap<String, RoadNetwork>> {
NETWORK_CACHE.get_or_init(|| RwLock::new(HashMap::new()))
}

pub(crate) fn in_flight_loads() -> &'static Mutex<HashMap<String, Arc<Mutex<()>>>> {
IN_FLIGHT_LOADS.get_or_init(|| Mutex::new(HashMap::new()))
}

pub(crate) fn record_hit() {
CACHE_HITS.fetch_add(1, Ordering::Relaxed);
}
Expand Down
227 changes: 186 additions & 41 deletions src/routing/fetch.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
//! Overpass API fetching and caching for road networks.

use std::collections::HashMap;
use std::future::Future;
use std::path::Path;
use std::sync::Arc;

use tokio::sync::mpsc::Sender;
use tokio::sync::{mpsc::Sender, Mutex, OwnedMutexGuard};
use tracing::{debug, info};

use super::bbox::BoundingBox;
use super::cache::{
cache, record_hit, record_miss, CachedEdge, CachedNetwork, CachedNode, NetworkRef,
CACHE_VERSION,
cache, in_flight_loads, record_hit, record_miss, CachedEdge, CachedNetwork, CachedNode,
NetworkRef, CACHE_VERSION,
};
use super::config::{ConnectivityPolicy, NetworkConfig};
use super::coord::Coord;
Expand Down Expand Up @@ -49,48 +51,36 @@ impl RoadNetwork {
let _ = tx.send(RoutingProgress::CheckingCache { percent: 5 }).await;
}

{
let mut cache_guard = cache().write().await;
if !cache_guard.contains_key(&cache_key) {
tokio::fs::create_dir_all(&config.cache_dir).await?;
let cache_path = config.cache_dir.join(format!("{}.json", cache_key));

let network = if tokio::fs::try_exists(&cache_path).await.unwrap_or(false) {
info!("Loading road network from file cache: {:?}", cache_path);
if let Some(tx) = progress {
let _ = tx.send(RoutingProgress::CheckingCache { percent: 8 }).await;
}
match Self::load_from_file(&cache_path, config).await {
Ok(n) => {
if let Some(tx) = progress {
let _ = tx
.send(RoutingProgress::BuildingGraph { percent: 50 })
.await;
}
n
}
Err(e) => {
info!("File cache invalid ({}), downloading fresh", e);
let n = Self::fetch_from_api(bbox, config, progress).await?;
n.save_to_file(&cache_path).await?;
info!("Saved road network to file cache: {:?}", cache_path);
n
Self::load_or_insert(cache_key, async {
tokio::fs::create_dir_all(&config.cache_dir).await?;
let cache_path = config.cache_dir.join(format!("{}.json", bbox.cache_key()));

if tokio::fs::try_exists(&cache_path).await.unwrap_or(false) {
info!("Loading road network from file cache: {:?}", cache_path);
if let Some(tx) = progress {
let _ = tx.send(RoutingProgress::CheckingCache { percent: 8 }).await;
}
match Self::load_from_file(&cache_path, config).await {
Ok(network) => {
if let Some(tx) = progress {
let _ = tx
.send(RoutingProgress::BuildingGraph { percent: 50 })
.await;
}
return Ok(network);
}
} else {
info!("Downloading road network from Overpass API");
let n = Self::fetch_from_api(bbox, config, progress).await?;
n.save_to_file(&cache_path).await?;
info!("Saved road network to file cache: {:?}", cache_path);
n
};

cache_guard.insert(cache_key.clone(), network);
Err(e) => info!("File cache invalid ({}), downloading fresh", e),
}
} else {
info!("Downloading road network from Overpass API");
}
}

let cache_guard = cache().read().await;
Ok(NetworkRef::new(cache_guard, cache_key))
let network = Self::fetch_from_api(bbox, config, progress).await?;
network.save_to_file(&cache_path).await?;
info!("Saved road network to file cache: {:?}", cache_path);
Ok(network)
})
.await
}

pub async fn fetch(
Expand Down Expand Up @@ -434,6 +424,47 @@ out body;"#,

Ok(())
}

async fn load_or_insert<F>(cache_key: String, load: F) -> Result<NetworkRef, RoutingError>
where
F: Future<Output = Result<RoadNetwork, RoutingError>>,
{
if let Some(cached) = Self::get_cached_network(cache_key.clone()).await {
return Ok(cached);
}

record_miss();

let (slot, _slot_guard) = acquire_in_flight_slot(&cache_key).await;

if let Some(cached) = Self::get_cached_network(cache_key.clone()).await {
cleanup_in_flight_slot(&cache_key, &slot).await;
return Ok(cached);
}

let network = load.await?;

{
let mut cache_guard = cache().write().await;
cache_guard.entry(cache_key.clone()).or_insert(network);
}

cleanup_in_flight_slot(&cache_key, &slot).await;

Self::get_cached_network(cache_key).await.ok_or_else(|| {
RoutingError::Network("cached network disappeared after insertion".to_string())
})
}

async fn get_cached_network(cache_key: String) -> Option<NetworkRef> {
let cache_guard = cache().read().await;
if cache_guard.contains_key(&cache_key) {
record_hit();
Some(NetworkRef::new(cache_guard, cache_key))
} else {
None
}
}
}

impl RoadNetwork {
Expand All @@ -442,3 +473,117 @@ impl RoadNetwork {
Self::load_or_fetch(bbox, &NetworkConfig::default(), None).await
}
}

async fn acquire_in_flight_slot(cache_key: &str) -> (Arc<Mutex<()>>, OwnedMutexGuard<()>) {
let slot = {
let mut in_flight = in_flight_loads().lock().await;
in_flight
.entry(cache_key.to_string())
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone()
};

let guard = slot.clone().lock_owned().await;
(slot, guard)
}

async fn cleanup_in_flight_slot(cache_key: &str, slot: &Arc<Mutex<()>>) {
let mut in_flight = in_flight_loads().lock().await;
let should_remove = in_flight
.get(cache_key)
.map(|current| Arc::ptr_eq(current, slot) && Arc::strong_count(slot) == 2)
.unwrap_or(false);
Comment on lines +492 to +495
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Drop _slot_guard before pruning IN_FLIGHT_LOADS

At both cleanup call sites in load_or_insert(), _slot_guard and slot are still alive, so Arc::strong_count(slot) is at least 3 here (the map entry, the local slot, and the OwnedMutexGuard). That makes should_remove permanently false, so every distinct cache_key leaves an Arc<Mutex<()>> behind in IN_FLIGHT_LOADS even after the load completes. In a long-lived service that sees many one-off bounding boxes, this side map will grow without bound.

Useful? React with 👍 / 👎.


if should_remove {
in_flight.remove(cache_key);
}
}

#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use tokio::time::sleep;

use super::*;

fn test_network() -> RoadNetwork {
RoadNetwork::from_test_data(&[(0.0, 0.0), (0.0, 0.01)], &[(0, 1, 60.0, 1_000.0)])
}

async fn reset_test_state() {
RoadNetwork::clear_cache().await;
in_flight_loads().lock().await.clear();
}

#[tokio::test]
async fn load_or_insert_allows_different_keys_to_progress_concurrently() {
reset_test_state().await;

let start = Instant::now();
let first = async {
RoadNetwork::load_or_insert("region-a".to_string(), async {
sleep(Duration::from_millis(100)).await;
Ok(test_network())
})
.await
.map(|network| network.node_count())
};
let second = async {
RoadNetwork::load_or_insert("region-b".to_string(), async {
sleep(Duration::from_millis(100)).await;
Ok(test_network())
})
.await
.map(|network| network.node_count())
};
let (left, right) = tokio::join!(first, second);
left.expect("first load should succeed");
right.expect("second load should succeed");

assert!(
start.elapsed() < Duration::from_millis(180),
"different keys should not serialize slow loads"
);
}

#[tokio::test]
async fn load_or_insert_deduplicates_same_key_work() {
reset_test_state().await;

let loads = Arc::new(AtomicUsize::new(0));

let first = {
let loads = loads.clone();
async move {
RoadNetwork::load_or_insert("region-a".to_string(), async move {
loads.fetch_add(1, Ordering::Relaxed);
sleep(Duration::from_millis(50)).await;
Ok(test_network())
})
.await
.map(|network| network.node_count())
}
};
let second = {
let loads = loads.clone();
async move {
RoadNetwork::load_or_insert("region-a".to_string(), async move {
loads.fetch_add(1, Ordering::Relaxed);
sleep(Duration::from_millis(50)).await;
Ok(test_network())
})
.await
.map(|network| network.node_count())
}
};

let (left, right) = tokio::join!(first, second);
left.expect("first load should succeed");
right.expect("second load should succeed");

assert_eq!(loads.load(Ordering::Relaxed), 1);
}
}
Loading