diff --git a/src/lib.rs b/src/lib.rs index e1b0dab..be1a3ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ pub mod routing; pub use geometry::{decode_polyline, encode_polyline, EncodedSegment}; pub use routing::{ - haversine_distance, BBoxError, BoundingBox, CacheStats, Coord, CoordError, NetworkConfig, - NetworkRef, Objective, RoadNetwork, RouteResult, RoutingError, RoutingProgress, RoutingResult, - SnappedCoord, SpeedProfile, TravelTimeMatrix, UNREACHABLE, + haversine_distance, BBoxError, BoundingBox, CacheStats, ConnectivityPolicy, Coord, CoordError, + NetworkConfig, NetworkRef, Objective, RoadNetwork, RouteResult, RoutingError, RoutingProgress, + RoutingResult, SnappedCoord, SpeedProfile, TravelTimeMatrix, UNREACHABLE, }; diff --git a/src/routing/config.rs b/src/routing/config.rs index 231a3d4..c16abee 100644 --- a/src/routing/config.rs +++ b/src/routing/config.rs @@ -3,6 +3,12 @@ use std::path::PathBuf; use std::time::Duration; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectivityPolicy { + KeepAll, + LargestStronglyConnectedComponent, +} + #[derive(Debug, Clone)] pub struct SpeedProfile { pub motorway: f64, @@ -92,6 +98,7 @@ pub struct NetworkConfig { pub connect_timeout: Duration, pub read_timeout: Duration, pub speed_profile: SpeedProfile, + pub connectivity_policy: ConnectivityPolicy, pub highway_types: Vec<&'static str>, } @@ -103,6 +110,7 @@ impl Default for NetworkConfig { connect_timeout: Duration::from_secs(30), read_timeout: Duration::from_secs(180), speed_profile: SpeedProfile::default(), + connectivity_policy: ConnectivityPolicy::KeepAll, highway_types: vec![ "motorway", "trunk", @@ -148,6 +156,11 @@ impl NetworkConfig { self } + pub fn connectivity_policy(mut self, policy: ConnectivityPolicy) -> Self { + self.connectivity_policy = policy; + self + } + pub fn highway_types(mut self, types: Vec<&'static str>) -> Self { self.highway_types = types; self diff --git a/src/routing/fetch.rs b/src/routing/fetch.rs index 5a1ac8c..9a2c929 100644 --- a/src/routing/fetch.rs +++ b/src/routing/fetch.rs @@ -11,7 +11,7 @@ use super::cache::{ cache, record_hit, record_miss, CachedEdge, CachedNetwork, CachedNode, NetworkRef, CACHE_VERSION, }; -use super::config::NetworkConfig; +use super::config::{ConnectivityPolicy, NetworkConfig}; use super::coord::Coord; use super::error::RoutingError; use super::network::{EdgeData, RoadNetwork}; @@ -60,7 +60,7 @@ impl RoadNetwork { if let Some(tx) = progress { let _ = tx.send(RoutingProgress::CheckingCache { percent: 8 }).await; } - match Self::load_from_file(&cache_path).await { + match Self::load_from_file(&cache_path, config).await { Ok(n) => { if let Some(tx) = progress { let _ = tx @@ -314,19 +314,30 @@ out body;"#, way_count ); - // Filter to largest strongly connected component to ensure all nodes are reachable let scc_count = network.strongly_connected_components(); - if scc_count > 1 { - info!( - "Road network has {} SCCs, filtering to largest component", - scc_count - ); - network.filter_to_largest_scc(); - info!( - "After SCC filter: {} nodes, {} edges", - network.node_count(), - network.edge_count() - ); + match config.connectivity_policy { + ConnectivityPolicy::KeepAll => { + if scc_count > 1 { + info!( + "Road network has {} SCCs, preserving all components by configuration", + scc_count + ); + } + } + ConnectivityPolicy::LargestStronglyConnectedComponent => { + if scc_count > 1 { + info!( + "Road network has {} SCCs, filtering to largest component", + scc_count + ); + network.filter_to_largest_scc(); + info!( + "After SCC filter: {} nodes, {} edges", + network.node_count(), + network.edge_count() + ); + } + } } network.build_spatial_index(); @@ -334,7 +345,7 @@ out body;"#, Ok(network) } - async fn load_from_file(path: &Path) -> Result { + async fn load_from_file(path: &Path, config: &NetworkConfig) -> Result { let data = tokio::fs::read_to_string(path).await?; let cached: CachedNetwork = match serde_json::from_str(&data) { @@ -365,19 +376,30 @@ out body;"#, network.add_edge_by_index(edge.from, edge.to, edge.travel_time_s, edge.distance_m); } - // Filter to largest SCC (cached networks from older versions may not be filtered) let scc_count = network.strongly_connected_components(); - if scc_count > 1 { - info!( - "Cached network has {} SCCs, filtering to largest component", - scc_count - ); - network.filter_to_largest_scc(); - info!( - "After SCC filter: {} nodes, {} edges", - network.node_count(), - network.edge_count() - ); + match config.connectivity_policy { + ConnectivityPolicy::KeepAll => { + if scc_count > 1 { + info!( + "Cached network has {} SCCs, preserving all components by configuration", + scc_count + ); + } + } + ConnectivityPolicy::LargestStronglyConnectedComponent => { + if scc_count > 1 { + info!( + "Cached network has {} SCCs, filtering to largest component", + scc_count + ); + network.filter_to_largest_scc(); + info!( + "After SCC filter: {} nodes, {} edges", + network.node_count(), + network.edge_count() + ); + } + } } network.build_spatial_index(); diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 0cc40eb..1a2fd0a 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -17,7 +17,7 @@ mod spatial; pub use bbox::BoundingBox; pub use cache::{CacheStats, NetworkRef}; -pub use config::{NetworkConfig, SpeedProfile}; +pub use config::{ConnectivityPolicy, NetworkConfig, SpeedProfile}; pub use coord::Coord; pub use error::{BBoxError, CoordError, RoutingError}; pub use matrix::{TravelTimeMatrix, UNREACHABLE}; diff --git a/tests/integration.rs b/tests/integration.rs index 3788579..e2937a1 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -4,8 +4,9 @@ use std::path::PathBuf; use std::time::Duration; use solverforge_maps::{ - decode_polyline, encode_polyline, haversine_distance, BBoxError, BoundingBox, Coord, - CoordError, NetworkConfig, RoadNetwork, RouteResult, RoutingError, SpeedProfile, UNREACHABLE, + decode_polyline, encode_polyline, haversine_distance, BBoxError, BoundingBox, + ConnectivityPolicy, Coord, CoordError, NetworkConfig, RoadNetwork, RouteResult, RoutingError, + SpeedProfile, UNREACHABLE, }; mod types { @@ -140,11 +141,16 @@ mod types { let config = NetworkConfig::new() .overpass_url("https://custom.api/interpreter") .cache_dir("/tmp/cache") - .connect_timeout(Duration::from_secs(60)); + .connect_timeout(Duration::from_secs(60)) + .connectivity_policy(ConnectivityPolicy::LargestStronglyConnectedComponent); assert_eq!(config.overpass_url, "https://custom.api/interpreter"); assert_eq!(config.cache_dir, PathBuf::from("/tmp/cache")); assert_eq!(config.connect_timeout, Duration::from_secs(60)); + assert_eq!( + config.connectivity_policy, + ConnectivityPolicy::LargestStronglyConnectedComponent + ); } #[test] @@ -156,6 +162,12 @@ mod types { let maxspeed_mps = profile.speed_mps(Some("50"), "motorway"); assert!((maxspeed_mps - 13.889).abs() < 0.1); } + + #[test] + fn default_connectivity_policy_keeps_all_components() { + let config = NetworkConfig::default(); + assert_eq!(config.connectivity_policy, ConnectivityPolicy::KeepAll); + } } } @@ -242,6 +254,27 @@ mod routing { assert_eq!(network.strongly_connected_components(), 0); assert!((network.largest_component_fraction() - 0.0).abs() < f64::EPSILON); } + + #[test] + fn largest_scc_filter_is_opt_in() { + let mut network = RoadNetwork::from_test_data( + &[(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)], + &[ + (0, 1, 10.0, 100.0), + (1, 0, 10.0, 100.0), + (2, 3, 10.0, 100.0), + ], + ); + + assert_eq!(network.node_count(), 4); + assert_eq!(network.strongly_connected_components(), 3); + + network.filter_to_largest_scc(); + + assert_eq!(network.node_count(), 2); + assert_eq!(network.edge_count(), 2); + assert_eq!(network.strongly_connected_components(), 1); + } } mod route_simplify {