diff --git a/src/routing/spatial.rs b/src/routing/spatial.rs index a9a2e25..d3f1da7 100644 --- a/src/routing/spatial.rs +++ b/src/routing/spatial.rs @@ -1,7 +1,7 @@ //! Zero-erasure spatial index implementation using K-D Trees. /// 2D point for spatial queries. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq)] pub struct Point2D { pub x: f64, pub y: f64, @@ -197,7 +197,7 @@ impl KdTree { } /// Segment in 2D space. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq)] pub struct Segment { pub from: Point2D, pub to: Point2D, @@ -244,88 +244,245 @@ impl Segment { } } +#[derive(Debug, Copy, Clone)] +struct BoundingBox2D { + min: Point2D, + max: Point2D, +} + +impl BoundingBox2D { + fn from_segment(segment: &Segment) -> Self { + Self { + min: Point2D::new( + segment.from.x.min(segment.to.x), + segment.from.y.min(segment.to.y), + ), + max: Point2D::new( + segment.from.x.max(segment.to.x), + segment.from.y.max(segment.to.y), + ), + } + } + + fn union(self, other: Self) -> Self { + Self { + min: Point2D::new(self.min.x.min(other.min.x), self.min.y.min(other.min.y)), + max: Point2D::new(self.max.x.max(other.max.x), self.max.y.max(other.max.y)), + } + } + + fn distance_squared_to_point(&self, point: &Point2D) -> f64 { + let dx = if point.x < self.min.x { + self.min.x - point.x + } else if point.x > self.max.x { + point.x - self.max.x + } else { + 0.0 + }; + let dy = if point.y < self.min.y { + self.min.y - point.y + } else if point.y > self.max.y { + point.y - self.max.y + } else { + 0.0 + }; + dx * dx + dy * dy + } +} + +struct SegmentNode { + segment: Segment, + data: T, + bounds: BoundingBox2D, + left: Option, + right: Option, +} + /// Spatial index for line segments. /// -/// Uses a K-D Tree on segment centroids, then refines with actual segment distance. +/// Uses a branch-and-bound K-D tree over segment centroids with exact subtree bounds. pub struct SegmentIndex { - /// K-D Tree indexed by segment centroid - tree: KdTree<(Segment, T)>, - /// All segments for brute-force refinement within candidate set - segments: Vec<(Segment, T)>, + nodes: Vec>, + root: Option, } -impl SegmentIndex { +impl SegmentIndex { /// Build a segment index from a list of segments with associated data. pub fn bulk_load(segments: Vec<(Segment, T)>) -> Self { - let items: Vec<(Point2D, (Segment, T))> = segments + let mut indexed: Vec<(usize, Point2D)> = segments .iter() - .cloned() - .map(|(seg, data)| (seg.centroid(), (seg, data))) + .enumerate() + .map(|(i, (segment, _))| (i, segment.centroid())) .collect(); - - Self { - tree: KdTree::from_items(items), - segments, - } + let mut items: Vec> = segments.into_iter().map(Some).collect(); + let mut nodes = Vec::with_capacity(items.len()); + let root = Self::build_nodes(&mut indexed, &mut items, 0, &mut nodes); + Self { nodes, root } } /// Find the nearest segment to a query point. /// /// Returns the segment, associated data, projected point on segment, and squared distance. pub fn nearest_segment(&self, query: &Point2D) -> Option<(&Segment, &T, Point2D, f64)> { - if self.segments.is_empty() { + let root = self.root?; + let mut best: Option<(usize, Point2D, f64)> = None; + self.search_nearest(root, query, 0, &mut best); + let (idx, projection, dist) = best?; + let node = &self.nodes[idx]; + Some((&node.segment, &node.data, projection, dist)) + } + + fn build_nodes( + indexed: &mut [(usize, Point2D)], + items: &mut [Option<(Segment, T)>], + depth: usize, + nodes: &mut Vec>, + ) -> Option { + if indexed.is_empty() { return None; } - // For small datasets, just brute force - if self.segments.len() <= 100 { - return self.brute_force_nearest(query); - } + let axis = depth % 2; + indexed.sort_by(|a, b| { + let va = if axis == 0 { a.1.x } else { a.1.y }; + let vb = if axis == 0 { b.1.x } else { b.1.y }; + va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal) + }); - // Use K-D tree to find candidates, then refine - // First, get the nearest centroid as a starting point - let nearest_centroid = self.tree.nearest_neighbor_with_distance(query)?; - - // The actual nearest segment might be different from the one with nearest centroid - // We need to search within a radius equal to our best distance so far - let (seg, data) = nearest_centroid.0; - let (proj, _) = seg.project_point(query); - let mut best_dist = query.distance_squared(&proj); - let mut best_result = (seg, data, proj, best_dist); - - // Check all segments (for correctness; can be optimized with spatial queries) - // Since K-D tree on centroids doesn't give us tight bounds, we do a full scan - // This is still faster than pure brute force for construction + multiple queries - for (seg, data) in &self.segments { - let (proj, _) = seg.project_point(query); - let dist = query.distance_squared(&proj); - if dist < best_dist { - best_dist = dist; - best_result = (seg, data, proj, dist); - } + let mid = indexed.len() / 2; + let (left_slice, rest) = indexed.split_at_mut(mid); + let (mid_item, right_slice) = rest.split_first_mut().expect("not empty"); + let left = Self::build_nodes(left_slice, items, depth + 1, nodes); + let right = Self::build_nodes(right_slice, items, depth + 1, nodes); + + let (segment, data) = items[mid_item.0].take().expect("item already taken"); + let bounds = BoundingBox2D::from_segment(&segment); + let idx = nodes.len(); + nodes.push(SegmentNode { + segment, + data, + bounds, + left, + right, + }); + + let mut bounds = BoundingBox2D::from_segment(&nodes[idx].segment); + if let Some(left_idx) = left { + bounds = bounds.union(nodes[left_idx].bounds); } + if let Some(right_idx) = right { + bounds = bounds.union(nodes[right_idx].bounds); + } + nodes[idx].bounds = bounds; - Some(best_result) + Some(idx) } - fn brute_force_nearest(&self, query: &Point2D) -> Option<(&Segment, &T, Point2D, f64)> { - let mut best: Option<(&Segment, &T, Point2D, f64)> = None; + fn search_nearest( + &self, + node_idx: usize, + query: &Point2D, + depth: usize, + best: &mut Option<(usize, Point2D, f64)>, + ) { + let node = &self.nodes[node_idx]; + let (projection, _) = node.segment.project_point(query); + let dist = query.distance_squared(&projection); - for (seg, data) in &self.segments { - let (proj, _) = seg.project_point(query); - let dist = query.distance_squared(&proj); + match best { + Some((_, _, best_dist)) if dist < *best_dist => { + *best = Some((node_idx, projection, dist)); + } + None => { + *best = Some((node_idx, projection, dist)); + } + _ => {} + } - match &best { - Some((_, _, _, best_dist)) if dist < *best_dist => { - best = Some((seg, data, proj, dist)); - } - None => { - best = Some((seg, data, proj, dist)); - } - _ => {} + let axis = depth % 2; + let centroid = node.segment.centroid(); + let query_val = if axis == 0 { query.x } else { query.y }; + let node_val = if axis == 0 { centroid.x } else { centroid.y }; + let (first, second) = if query_val <= node_val { + (node.left, node.right) + } else { + (node.right, node.left) + }; + + if let Some(child) = first { + self.search_child(child, query, depth + 1, best); + } + if let Some(child) = second { + self.search_child(child, query, depth + 1, best); + } + } + + fn search_child( + &self, + child_idx: usize, + query: &Point2D, + depth: usize, + best: &mut Option<(usize, Point2D, f64)>, + ) { + if let Some((_, _, best_dist)) = best { + let child_dist = self.nodes[child_idx] + .bounds + .distance_squared_to_point(query); + if child_dist > *best_dist { + return; } } - best + self.search_nearest(child_idx, query, depth, best); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn segment_index_returns_none_when_empty() { + let index: SegmentIndex = SegmentIndex::bulk_load(vec![]); + assert!(index.nearest_segment(&Point2D::new(0.0, 0.0)).is_none()); + } + + #[test] + fn segment_index_matches_bruteforce_on_large_input() { + let segments: Vec<(Segment, usize)> = (0..256) + .map(|i| { + let y = i as f64 * 10.0; + ( + Segment::new(Point2D::new(0.0, y), Point2D::new(5.0, y + 1.0)), + i, + ) + }) + .collect(); + let index = SegmentIndex::bulk_load(segments.clone()); + let query = Point2D::new(2.25, 1234.4); + + let indexed = index + .nearest_segment(&query) + .expect("expected nearest segment"); + + let brute_force = segments + .iter() + .map(|(segment, data)| { + let (projection, _) = segment.project_point(&query); + ( + *data, + projection, + query.distance_squared(&projection), + *segment, + ) + }) + .min_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal)) + .expect("expected brute-force result"); + + assert_eq!(*indexed.1, brute_force.0); + assert_eq!(*indexed.0, brute_force.3); + assert!((indexed.2.x - brute_force.1.x).abs() < 1e-9); + assert!((indexed.2.y - brute_force.1.y).abs() < 1e-9); + assert!((indexed.3 - brute_force.2).abs() < 1e-9); } }