diff --git a/Cargo.toml b/Cargo.toml index a2132bc..0a7e2f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ rayon = "1.11" reqwest = { version = "0.13", features = ["json"] } serde = { version = "1", features = ["derive"] } serde_json = "1" -tokio = { version = "1", features = ["fs", "sync", "rt", "rt-multi-thread", "macros"] } +tokio = { version = "1", features = ["fs", "sync", "rt", "rt-multi-thread", "macros", "time"] } tracing = "0.1" utoipa = "5" diff --git a/src/routing/config.rs b/src/routing/config.rs index c16abee..22a813f 100644 --- a/src/routing/config.rs +++ b/src/routing/config.rs @@ -94,6 +94,9 @@ impl SpeedProfile { #[derive(Debug, Clone)] pub struct NetworkConfig { pub overpass_url: String, + pub overpass_endpoints: Vec, + pub overpass_max_retries: usize, + pub overpass_retry_backoff: Duration, pub cache_dir: PathBuf, pub connect_timeout: Duration, pub read_timeout: Duration, @@ -104,8 +107,12 @@ pub struct NetworkConfig { impl Default for NetworkConfig { fn default() -> Self { + let default_overpass = "https://overpass-api.de/api/interpreter".to_string(); Self { - overpass_url: "https://overpass-api.de/api/interpreter".to_string(), + overpass_url: default_overpass.clone(), + overpass_endpoints: vec![default_overpass], + overpass_max_retries: 2, + overpass_retry_backoff: Duration::from_secs(2), cache_dir: PathBuf::from(".osm_cache"), connect_timeout: Duration::from_secs(30), read_timeout: Duration::from_secs(180), @@ -132,7 +139,27 @@ impl NetworkConfig { } pub fn overpass_url(mut self, url: impl Into) -> Self { - self.overpass_url = url.into(); + let url = url.into(); + self.overpass_url = url.clone(); + self.overpass_endpoints = vec![url]; + self + } + + pub fn overpass_endpoints(mut self, urls: Vec) -> Self { + if let Some(primary) = urls.first().cloned() { + self.overpass_url = primary; + self.overpass_endpoints = urls; + } + self + } + + pub fn overpass_max_retries(mut self, retries: usize) -> Self { + self.overpass_max_retries = retries; + self + } + + pub fn overpass_retry_backoff(mut self, backoff: Duration) -> Self { + self.overpass_retry_backoff = backoff; self } diff --git a/src/routing/fetch.rs b/src/routing/fetch.rs index eb2b515..f27f3bc 100644 --- a/src/routing/fetch.rs +++ b/src/routing/fetch.rs @@ -4,8 +4,10 @@ use std::collections::HashMap; use std::future::Future; use std::path::Path; use std::sync::Arc; +use std::time::Duration; use tokio::sync::{mpsc::Sender, Mutex, OwnedMutexGuard}; +use tokio::time::sleep; use tracing::{debug, info}; use super::bbox::BoundingBox; @@ -131,8 +133,6 @@ out body;"#, .build() .map_err(|e| RoutingError::Network(e.to_string()))?; - info!("Sending request to Overpass API..."); - if let Some(tx) = progress { let _ = tx .send(RoutingProgress::DownloadingNetwork { @@ -142,36 +142,7 @@ out body;"#, .await; } - let response = client - .post(&config.overpass_url) - .body(query) - .header("Content-Type", "text/plain") - .send() - .await - .map_err(|e| RoutingError::Network(e.to_string()))?; - - info!("Received response: status={}", response.status()); - - if !response.status().is_success() { - return Err(RoutingError::Network(format!( - "Overpass API returned status {}", - response.status() - ))); - } - - if let Some(tx) = progress { - let _ = tx - .send(RoutingProgress::DownloadingNetwork { - percent: 25, - bytes: 0, - }) - .await; - } - - let bytes = response - .bytes() - .await - .map_err(|e| RoutingError::Network(e.to_string()))?; + let bytes = fetch_overpass_bytes(&client, &query, config, progress).await?; let bytes_len = bytes.len(); if let Some(tx) = progress { @@ -467,6 +438,125 @@ out body;"#, } } +async fn fetch_overpass_bytes( + client: &reqwest::Client, + query: &str, + config: &NetworkConfig, + progress: Option<&Sender>, +) -> Result, RoutingError> { + let endpoints = overpass_endpoints(config); + let mut failures = Vec::new(); + + for (endpoint_index, endpoint) in endpoints.iter().enumerate() { + for attempt in 0..=config.overpass_max_retries { + info!( + "Sending request to Overpass API endpoint {} attempt {}: {}", + endpoint_index + 1, + attempt + 1, + endpoint + ); + + let response = client + .post(endpoint) + .body(query.to_owned()) + .header("Content-Type", "text/plain") + .send() + .await; + + match response { + Ok(response) if response.status().is_success() => { + info!( + "Received successful Overpass response from {} with status {}", + endpoint, + response.status() + ); + + if let Some(tx) = progress { + let _ = tx + .send(RoutingProgress::DownloadingNetwork { + percent: 25, + bytes: 0, + }) + .await; + } + + return response + .bytes() + .await + .map(|bytes| bytes.to_vec()) + .map_err(|error| { + RoutingError::Network(format!( + "Overpass response body read failed from {} on attempt {}: {}", + endpoint, + attempt + 1, + error + )) + }); + } + Ok(response) => { + let status = response.status(); + failures.push(format!( + "{} attempt {} returned HTTP {}", + endpoint, + attempt + 1, + status + )); + + if is_retryable_status(status) && attempt < config.overpass_max_retries { + sleep(retry_backoff(config.overpass_retry_backoff, attempt)).await; + continue; + } + + break; + } + Err(error) => { + failures.push(format!( + "{} attempt {} failed: {}", + endpoint, + attempt + 1, + error + )); + + if is_retryable_error(&error) && attempt < config.overpass_max_retries { + sleep(retry_backoff(config.overpass_retry_backoff, attempt)).await; + continue; + } + + break; + } + } + } + } + + Err(RoutingError::Network(format!( + "Overpass fetch failed after trying {} endpoint(s): {}", + endpoints.len(), + failures.join("; ") + ))) +} + +fn overpass_endpoints(config: &NetworkConfig) -> Vec { + if config.overpass_endpoints.is_empty() { + vec![config.overpass_url.clone()] + } else { + config.overpass_endpoints.clone() + } +} + +fn retry_backoff(base: Duration, attempt: usize) -> Duration { + base.saturating_mul((attempt + 1) as u32) +} + +fn is_retryable_status(status: reqwest::StatusCode) -> bool { + status.is_server_error() + || status == reqwest::StatusCode::TOO_MANY_REQUESTS + || status == reqwest::StatusCode::REQUEST_TIMEOUT +} + +fn is_retryable_error(error: &reqwest::Error) -> bool { + error.is_timeout() || error.is_connect() || error.is_request() +} + impl RoadNetwork { #[doc(hidden)] pub async fn load_or_fetch_simple(bbox: &BoundingBox) -> Result { @@ -501,13 +591,17 @@ async fn cleanup_in_flight_slot(cache_key: &str, slot: &Arc>) { #[cfg(test)] mod tests { + use std::io::{Read, Write}; + use std::net::TcpListener; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; + use std::thread; use std::time::{Duration, Instant}; use tokio::time::sleep; use super::*; + use crate::routing::BoundingBox; fn test_network() -> RoadNetwork { RoadNetwork::from_test_data(&[(0.0, 0.0), (0.0, 0.01)], &[(0, 1, 60.0, 1_000.0)]) @@ -586,4 +680,94 @@ mod tests { assert_eq!(loads.load(Ordering::Relaxed), 1); } + + fn overpass_fixture_json() -> &'static str { + r#"{ + "elements": [ + {"type": "node", "id": 1, "lat": 39.95, "lon": -75.16}, + {"type": "node", "id": 2, "lat": 39.96, "lon": -75.17}, + {"type": "way", "id": 10, "nodes": [1, 2], "tags": {"highway": "residential"}} + ] + }"# + } + + fn spawn_overpass_server( + responses: Vec<(&'static str, &'static str)>, + ) -> (String, Arc, thread::JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind"); + let address = format!( + "http://{}/api/interpreter", + listener.local_addr().expect("listener addr") + ); + let requests = Arc::new(AtomicUsize::new(0)); + let served = requests.clone(); + + let handle = thread::spawn(move || { + for (status, body) in responses { + let (mut stream, _) = listener.accept().expect("connection should arrive"); + let mut buffer = [0_u8; 4096]; + let _ = stream.read(&mut buffer); + let response = format!( + "HTTP/1.1 {}\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{}", + status, + body.len(), + body + ); + stream + .write_all(response.as_bytes()) + .expect("response should write"); + served.fetch_add(1, Ordering::Relaxed); + } + }); + + (address, requests, handle) + } + + #[tokio::test] + async fn fetch_retries_same_endpoint_until_success() { + let (endpoint, requests, handle) = spawn_overpass_server(vec![ + ("429 Too Many Requests", r#"{"elements":[]}"#), + ("200 OK", overpass_fixture_json()), + ]); + + let bbox = BoundingBox::try_new(39.94, -75.18, 39.97, -75.15).expect("bbox should build"); + let config = NetworkConfig::default() + .overpass_url(endpoint) + .overpass_max_retries(1) + .overpass_retry_backoff(Duration::from_millis(1)); + + let network = RoadNetwork::fetch(&bbox, &config, None) + .await + .expect("fetch should succeed after retry"); + + assert_eq!(network.node_count(), 2); + assert_eq!(requests.load(Ordering::Relaxed), 2); + handle.join().expect("server should join"); + } + + #[tokio::test] + async fn fetch_falls_back_to_second_endpoint() { + let (primary, primary_requests, primary_handle) = + spawn_overpass_server(vec![("503 Service Unavailable", r#"{"elements":[]}"#)]); + let (secondary, secondary_requests, secondary_handle) = + spawn_overpass_server(vec![("200 OK", overpass_fixture_json())]); + + let bbox = BoundingBox::try_new(39.94, -75.18, 39.97, -75.15).expect("bbox should build"); + let config = NetworkConfig::default() + .overpass_endpoints(vec![primary, secondary]) + .overpass_max_retries(0) + .overpass_retry_backoff(Duration::from_millis(1)); + + let network = RoadNetwork::fetch(&bbox, &config, None) + .await + .expect("fetch should fall back to second endpoint"); + + assert_eq!(network.node_count(), 2); + assert_eq!(primary_requests.load(Ordering::Relaxed), 1); + assert_eq!(secondary_requests.load(Ordering::Relaxed), 1); + primary_handle.join().expect("primary server should join"); + secondary_handle + .join() + .expect("secondary server should join"); + } } diff --git a/tests/integration.rs b/tests/integration.rs index e2937a1..c983a52 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -133,6 +133,11 @@ mod types { config.overpass_url, "https://overpass-api.de/api/interpreter" ); + assert_eq!( + config.overpass_endpoints, + vec!["https://overpass-api.de/api/interpreter".to_string()] + ); + assert_eq!(config.overpass_max_retries, 2); assert_eq!(config.cache_dir, PathBuf::from(".osm_cache")); } @@ -140,11 +145,19 @@ mod types { fn builder_pattern() { let config = NetworkConfig::new() .overpass_url("https://custom.api/interpreter") + .overpass_max_retries(4) + .overpass_retry_backoff(Duration::from_secs(3)) .cache_dir("/tmp/cache") .connect_timeout(Duration::from_secs(60)) .connectivity_policy(ConnectivityPolicy::LargestStronglyConnectedComponent); assert_eq!(config.overpass_url, "https://custom.api/interpreter"); + assert_eq!( + config.overpass_endpoints, + vec!["https://custom.api/interpreter".to_string()] + ); + assert_eq!(config.overpass_max_retries, 4); + assert_eq!(config.overpass_retry_backoff, Duration::from_secs(3)); assert_eq!(config.cache_dir, PathBuf::from("/tmp/cache")); assert_eq!(config.connect_timeout, Duration::from_secs(60)); assert_eq!( @@ -153,6 +166,23 @@ mod types { ); } + #[test] + fn overpass_endpoint_pool_builder() { + let config = NetworkConfig::new().overpass_endpoints(vec![ + "https://a.example/api/interpreter".to_string(), + "https://b.example/api/interpreter".to_string(), + ]); + + assert_eq!(config.overpass_url, "https://a.example/api/interpreter"); + assert_eq!( + config.overpass_endpoints, + vec![ + "https://a.example/api/interpreter".to_string(), + "https://b.example/api/interpreter".to_string(), + ] + ); + } + #[test] fn speed_profile() { let profile = SpeedProfile::default();