|
1 | 1 | //! Zero-erasure spatial index implementation using K-D Trees. |
2 | 2 |
|
3 | 3 | /// 2D point for spatial queries. |
4 | | -#[derive(Debug, Copy, Clone)] |
| 4 | +#[derive(Debug, Copy, Clone, PartialEq)] |
5 | 5 | pub struct Point2D { |
6 | 6 | pub x: f64, |
7 | 7 | pub y: f64, |
@@ -197,7 +197,7 @@ impl<T> KdTree<T> { |
197 | 197 | } |
198 | 198 |
|
199 | 199 | /// Segment in 2D space. |
200 | | -#[derive(Debug, Copy, Clone)] |
| 200 | +#[derive(Debug, Copy, Clone, PartialEq)] |
201 | 201 | pub struct Segment { |
202 | 202 | pub from: Point2D, |
203 | 203 | pub to: Point2D, |
@@ -244,88 +244,245 @@ impl Segment { |
244 | 244 | } |
245 | 245 | } |
246 | 246 |
|
| 247 | +#[derive(Debug, Copy, Clone)] |
| 248 | +struct BoundingBox2D { |
| 249 | + min: Point2D, |
| 250 | + max: Point2D, |
| 251 | +} |
| 252 | + |
| 253 | +impl BoundingBox2D { |
| 254 | + fn from_segment(segment: &Segment) -> Self { |
| 255 | + Self { |
| 256 | + min: Point2D::new( |
| 257 | + segment.from.x.min(segment.to.x), |
| 258 | + segment.from.y.min(segment.to.y), |
| 259 | + ), |
| 260 | + max: Point2D::new( |
| 261 | + segment.from.x.max(segment.to.x), |
| 262 | + segment.from.y.max(segment.to.y), |
| 263 | + ), |
| 264 | + } |
| 265 | + } |
| 266 | + |
| 267 | + fn union(self, other: Self) -> Self { |
| 268 | + Self { |
| 269 | + min: Point2D::new(self.min.x.min(other.min.x), self.min.y.min(other.min.y)), |
| 270 | + max: Point2D::new(self.max.x.max(other.max.x), self.max.y.max(other.max.y)), |
| 271 | + } |
| 272 | + } |
| 273 | + |
| 274 | + fn distance_squared_to_point(&self, point: &Point2D) -> f64 { |
| 275 | + let dx = if point.x < self.min.x { |
| 276 | + self.min.x - point.x |
| 277 | + } else if point.x > self.max.x { |
| 278 | + point.x - self.max.x |
| 279 | + } else { |
| 280 | + 0.0 |
| 281 | + }; |
| 282 | + let dy = if point.y < self.min.y { |
| 283 | + self.min.y - point.y |
| 284 | + } else if point.y > self.max.y { |
| 285 | + point.y - self.max.y |
| 286 | + } else { |
| 287 | + 0.0 |
| 288 | + }; |
| 289 | + dx * dx + dy * dy |
| 290 | + } |
| 291 | +} |
| 292 | + |
| 293 | +struct SegmentNode<T> { |
| 294 | + segment: Segment, |
| 295 | + data: T, |
| 296 | + bounds: BoundingBox2D, |
| 297 | + left: Option<usize>, |
| 298 | + right: Option<usize>, |
| 299 | +} |
| 300 | + |
247 | 301 | /// Spatial index for line segments. |
248 | 302 | /// |
249 | | -/// Uses a K-D Tree on segment centroids, then refines with actual segment distance. |
| 303 | +/// Uses a branch-and-bound K-D tree over segment centroids with exact subtree bounds. |
250 | 304 | pub struct SegmentIndex<T> { |
251 | | - /// K-D Tree indexed by segment centroid |
252 | | - tree: KdTree<(Segment, T)>, |
253 | | - /// All segments for brute-force refinement within candidate set |
254 | | - segments: Vec<(Segment, T)>, |
| 305 | + nodes: Vec<SegmentNode<T>>, |
| 306 | + root: Option<usize>, |
255 | 307 | } |
256 | 308 |
|
257 | | -impl<T: Clone> SegmentIndex<T> { |
| 309 | +impl<T> SegmentIndex<T> { |
258 | 310 | /// Build a segment index from a list of segments with associated data. |
259 | 311 | pub fn bulk_load(segments: Vec<(Segment, T)>) -> Self { |
260 | | - let items: Vec<(Point2D, (Segment, T))> = segments |
| 312 | + let mut indexed: Vec<(usize, Point2D)> = segments |
261 | 313 | .iter() |
262 | | - .cloned() |
263 | | - .map(|(seg, data)| (seg.centroid(), (seg, data))) |
| 314 | + .enumerate() |
| 315 | + .map(|(i, (segment, _))| (i, segment.centroid())) |
264 | 316 | .collect(); |
265 | | - |
266 | | - Self { |
267 | | - tree: KdTree::from_items(items), |
268 | | - segments, |
269 | | - } |
| 317 | + let mut items: Vec<Option<(Segment, T)>> = segments.into_iter().map(Some).collect(); |
| 318 | + let mut nodes = Vec::with_capacity(items.len()); |
| 319 | + let root = Self::build_nodes(&mut indexed, &mut items, 0, &mut nodes); |
| 320 | + Self { nodes, root } |
270 | 321 | } |
271 | 322 |
|
272 | 323 | /// Find the nearest segment to a query point. |
273 | 324 | /// |
274 | 325 | /// Returns the segment, associated data, projected point on segment, and squared distance. |
275 | 326 | pub fn nearest_segment(&self, query: &Point2D) -> Option<(&Segment, &T, Point2D, f64)> { |
276 | | - if self.segments.is_empty() { |
| 327 | + let root = self.root?; |
| 328 | + let mut best: Option<(usize, Point2D, f64)> = None; |
| 329 | + self.search_nearest(root, query, 0, &mut best); |
| 330 | + let (idx, projection, dist) = best?; |
| 331 | + let node = &self.nodes[idx]; |
| 332 | + Some((&node.segment, &node.data, projection, dist)) |
| 333 | + } |
| 334 | + |
| 335 | + fn build_nodes( |
| 336 | + indexed: &mut [(usize, Point2D)], |
| 337 | + items: &mut [Option<(Segment, T)>], |
| 338 | + depth: usize, |
| 339 | + nodes: &mut Vec<SegmentNode<T>>, |
| 340 | + ) -> Option<usize> { |
| 341 | + if indexed.is_empty() { |
277 | 342 | return None; |
278 | 343 | } |
279 | 344 |
|
280 | | - // For small datasets, just brute force |
281 | | - if self.segments.len() <= 100 { |
282 | | - return self.brute_force_nearest(query); |
283 | | - } |
| 345 | + let axis = depth % 2; |
| 346 | + indexed.sort_by(|a, b| { |
| 347 | + let va = if axis == 0 { a.1.x } else { a.1.y }; |
| 348 | + let vb = if axis == 0 { b.1.x } else { b.1.y }; |
| 349 | + va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal) |
| 350 | + }); |
284 | 351 |
|
285 | | - // Use K-D tree to find candidates, then refine |
286 | | - // First, get the nearest centroid as a starting point |
287 | | - let nearest_centroid = self.tree.nearest_neighbor_with_distance(query)?; |
288 | | - |
289 | | - // The actual nearest segment might be different from the one with nearest centroid |
290 | | - // We need to search within a radius equal to our best distance so far |
291 | | - let (seg, data) = nearest_centroid.0; |
292 | | - let (proj, _) = seg.project_point(query); |
293 | | - let mut best_dist = query.distance_squared(&proj); |
294 | | - let mut best_result = (seg, data, proj, best_dist); |
295 | | - |
296 | | - // Check all segments (for correctness; can be optimized with spatial queries) |
297 | | - // Since K-D tree on centroids doesn't give us tight bounds, we do a full scan |
298 | | - // This is still faster than pure brute force for construction + multiple queries |
299 | | - for (seg, data) in &self.segments { |
300 | | - let (proj, _) = seg.project_point(query); |
301 | | - let dist = query.distance_squared(&proj); |
302 | | - if dist < best_dist { |
303 | | - best_dist = dist; |
304 | | - best_result = (seg, data, proj, dist); |
305 | | - } |
| 352 | + let mid = indexed.len() / 2; |
| 353 | + let (left_slice, rest) = indexed.split_at_mut(mid); |
| 354 | + let (mid_item, right_slice) = rest.split_first_mut().expect("not empty"); |
| 355 | + let left = Self::build_nodes(left_slice, items, depth + 1, nodes); |
| 356 | + let right = Self::build_nodes(right_slice, items, depth + 1, nodes); |
| 357 | + |
| 358 | + let (segment, data) = items[mid_item.0].take().expect("item already taken"); |
| 359 | + let bounds = BoundingBox2D::from_segment(&segment); |
| 360 | + let idx = nodes.len(); |
| 361 | + nodes.push(SegmentNode { |
| 362 | + segment, |
| 363 | + data, |
| 364 | + bounds, |
| 365 | + left, |
| 366 | + right, |
| 367 | + }); |
| 368 | + |
| 369 | + let mut bounds = BoundingBox2D::from_segment(&nodes[idx].segment); |
| 370 | + if let Some(left_idx) = left { |
| 371 | + bounds = bounds.union(nodes[left_idx].bounds); |
306 | 372 | } |
| 373 | + if let Some(right_idx) = right { |
| 374 | + bounds = bounds.union(nodes[right_idx].bounds); |
| 375 | + } |
| 376 | + nodes[idx].bounds = bounds; |
307 | 377 |
|
308 | | - Some(best_result) |
| 378 | + Some(idx) |
309 | 379 | } |
310 | 380 |
|
311 | | - fn brute_force_nearest(&self, query: &Point2D) -> Option<(&Segment, &T, Point2D, f64)> { |
312 | | - let mut best: Option<(&Segment, &T, Point2D, f64)> = None; |
| 381 | + fn search_nearest( |
| 382 | + &self, |
| 383 | + node_idx: usize, |
| 384 | + query: &Point2D, |
| 385 | + depth: usize, |
| 386 | + best: &mut Option<(usize, Point2D, f64)>, |
| 387 | + ) { |
| 388 | + let node = &self.nodes[node_idx]; |
| 389 | + let (projection, _) = node.segment.project_point(query); |
| 390 | + let dist = query.distance_squared(&projection); |
313 | 391 |
|
314 | | - for (seg, data) in &self.segments { |
315 | | - let (proj, _) = seg.project_point(query); |
316 | | - let dist = query.distance_squared(&proj); |
| 392 | + match best { |
| 393 | + Some((_, _, best_dist)) if dist < *best_dist => { |
| 394 | + *best = Some((node_idx, projection, dist)); |
| 395 | + } |
| 396 | + None => { |
| 397 | + *best = Some((node_idx, projection, dist)); |
| 398 | + } |
| 399 | + _ => {} |
| 400 | + } |
317 | 401 |
|
318 | | - match &best { |
319 | | - Some((_, _, _, best_dist)) if dist < *best_dist => { |
320 | | - best = Some((seg, data, proj, dist)); |
321 | | - } |
322 | | - None => { |
323 | | - best = Some((seg, data, proj, dist)); |
324 | | - } |
325 | | - _ => {} |
| 402 | + let axis = depth % 2; |
| 403 | + let centroid = node.segment.centroid(); |
| 404 | + let query_val = if axis == 0 { query.x } else { query.y }; |
| 405 | + let node_val = if axis == 0 { centroid.x } else { centroid.y }; |
| 406 | + let (first, second) = if query_val <= node_val { |
| 407 | + (node.left, node.right) |
| 408 | + } else { |
| 409 | + (node.right, node.left) |
| 410 | + }; |
| 411 | + |
| 412 | + if let Some(child) = first { |
| 413 | + self.search_child(child, query, depth + 1, best); |
| 414 | + } |
| 415 | + if let Some(child) = second { |
| 416 | + self.search_child(child, query, depth + 1, best); |
| 417 | + } |
| 418 | + } |
| 419 | + |
| 420 | + fn search_child( |
| 421 | + &self, |
| 422 | + child_idx: usize, |
| 423 | + query: &Point2D, |
| 424 | + depth: usize, |
| 425 | + best: &mut Option<(usize, Point2D, f64)>, |
| 426 | + ) { |
| 427 | + if let Some((_, _, best_dist)) = best { |
| 428 | + let child_dist = self.nodes[child_idx] |
| 429 | + .bounds |
| 430 | + .distance_squared_to_point(query); |
| 431 | + if child_dist > *best_dist { |
| 432 | + return; |
326 | 433 | } |
327 | 434 | } |
328 | 435 |
|
329 | | - best |
| 436 | + self.search_nearest(child_idx, query, depth, best); |
| 437 | + } |
| 438 | +} |
| 439 | + |
| 440 | +#[cfg(test)] |
| 441 | +mod tests { |
| 442 | + use super::*; |
| 443 | + |
| 444 | + #[test] |
| 445 | + fn segment_index_returns_none_when_empty() { |
| 446 | + let index: SegmentIndex<usize> = SegmentIndex::bulk_load(vec![]); |
| 447 | + assert!(index.nearest_segment(&Point2D::new(0.0, 0.0)).is_none()); |
| 448 | + } |
| 449 | + |
| 450 | + #[test] |
| 451 | + fn segment_index_matches_bruteforce_on_large_input() { |
| 452 | + let segments: Vec<(Segment, usize)> = (0..256) |
| 453 | + .map(|i| { |
| 454 | + let y = i as f64 * 10.0; |
| 455 | + ( |
| 456 | + Segment::new(Point2D::new(0.0, y), Point2D::new(5.0, y + 1.0)), |
| 457 | + i, |
| 458 | + ) |
| 459 | + }) |
| 460 | + .collect(); |
| 461 | + let index = SegmentIndex::bulk_load(segments.clone()); |
| 462 | + let query = Point2D::new(2.25, 1234.4); |
| 463 | + |
| 464 | + let indexed = index |
| 465 | + .nearest_segment(&query) |
| 466 | + .expect("expected nearest segment"); |
| 467 | + |
| 468 | + let brute_force = segments |
| 469 | + .iter() |
| 470 | + .map(|(segment, data)| { |
| 471 | + let (projection, _) = segment.project_point(&query); |
| 472 | + ( |
| 473 | + *data, |
| 474 | + projection, |
| 475 | + query.distance_squared(&projection), |
| 476 | + *segment, |
| 477 | + ) |
| 478 | + }) |
| 479 | + .min_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal)) |
| 480 | + .expect("expected brute-force result"); |
| 481 | + |
| 482 | + assert_eq!(*indexed.1, brute_force.0); |
| 483 | + assert_eq!(*indexed.0, brute_force.3); |
| 484 | + assert!((indexed.2.x - brute_force.1.x).abs() < 1e-9); |
| 485 | + assert!((indexed.2.y - brute_force.1.y).abs() < 1e-9); |
| 486 | + assert!((indexed.3 - brute_force.2).abs() < 1e-9); |
330 | 487 | } |
331 | 488 | } |
0 commit comments