Skip to content

Commit 91a86df

Browse files
authored
Merge pull request #18 from SolverForge/issue/5-make-nearest-segment-sublinear
spatial: make nearest-segment queries sublinear for large networks
2 parents e2d7b89 + ef81691 commit 91a86df

File tree

1 file changed

+214
-57
lines changed

1 file changed

+214
-57
lines changed

src/routing/spatial.rs

Lines changed: 214 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Zero-erasure spatial index implementation using K-D Trees.
22
33
/// 2D point for spatial queries.
4-
#[derive(Debug, Copy, Clone)]
4+
#[derive(Debug, Copy, Clone, PartialEq)]
55
pub struct Point2D {
66
pub x: f64,
77
pub y: f64,
@@ -197,7 +197,7 @@ impl<T> KdTree<T> {
197197
}
198198

199199
/// Segment in 2D space.
200-
#[derive(Debug, Copy, Clone)]
200+
#[derive(Debug, Copy, Clone, PartialEq)]
201201
pub struct Segment {
202202
pub from: Point2D,
203203
pub to: Point2D,
@@ -244,88 +244,245 @@ impl Segment {
244244
}
245245
}
246246

247+
#[derive(Debug, Copy, Clone)]
248+
struct BoundingBox2D {
249+
min: Point2D,
250+
max: Point2D,
251+
}
252+
253+
impl BoundingBox2D {
254+
fn from_segment(segment: &Segment) -> Self {
255+
Self {
256+
min: Point2D::new(
257+
segment.from.x.min(segment.to.x),
258+
segment.from.y.min(segment.to.y),
259+
),
260+
max: Point2D::new(
261+
segment.from.x.max(segment.to.x),
262+
segment.from.y.max(segment.to.y),
263+
),
264+
}
265+
}
266+
267+
fn union(self, other: Self) -> Self {
268+
Self {
269+
min: Point2D::new(self.min.x.min(other.min.x), self.min.y.min(other.min.y)),
270+
max: Point2D::new(self.max.x.max(other.max.x), self.max.y.max(other.max.y)),
271+
}
272+
}
273+
274+
fn distance_squared_to_point(&self, point: &Point2D) -> f64 {
275+
let dx = if point.x < self.min.x {
276+
self.min.x - point.x
277+
} else if point.x > self.max.x {
278+
point.x - self.max.x
279+
} else {
280+
0.0
281+
};
282+
let dy = if point.y < self.min.y {
283+
self.min.y - point.y
284+
} else if point.y > self.max.y {
285+
point.y - self.max.y
286+
} else {
287+
0.0
288+
};
289+
dx * dx + dy * dy
290+
}
291+
}
292+
293+
struct SegmentNode<T> {
294+
segment: Segment,
295+
data: T,
296+
bounds: BoundingBox2D,
297+
left: Option<usize>,
298+
right: Option<usize>,
299+
}
300+
247301
/// Spatial index for line segments.
248302
///
249-
/// Uses a K-D Tree on segment centroids, then refines with actual segment distance.
303+
/// Uses a branch-and-bound K-D tree over segment centroids with exact subtree bounds.
250304
pub struct SegmentIndex<T> {
251-
/// K-D Tree indexed by segment centroid
252-
tree: KdTree<(Segment, T)>,
253-
/// All segments for brute-force refinement within candidate set
254-
segments: Vec<(Segment, T)>,
305+
nodes: Vec<SegmentNode<T>>,
306+
root: Option<usize>,
255307
}
256308

257-
impl<T: Clone> SegmentIndex<T> {
309+
impl<T> SegmentIndex<T> {
258310
/// Build a segment index from a list of segments with associated data.
259311
pub fn bulk_load(segments: Vec<(Segment, T)>) -> Self {
260-
let items: Vec<(Point2D, (Segment, T))> = segments
312+
let mut indexed: Vec<(usize, Point2D)> = segments
261313
.iter()
262-
.cloned()
263-
.map(|(seg, data)| (seg.centroid(), (seg, data)))
314+
.enumerate()
315+
.map(|(i, (segment, _))| (i, segment.centroid()))
264316
.collect();
265-
266-
Self {
267-
tree: KdTree::from_items(items),
268-
segments,
269-
}
317+
let mut items: Vec<Option<(Segment, T)>> = segments.into_iter().map(Some).collect();
318+
let mut nodes = Vec::with_capacity(items.len());
319+
let root = Self::build_nodes(&mut indexed, &mut items, 0, &mut nodes);
320+
Self { nodes, root }
270321
}
271322

272323
/// Find the nearest segment to a query point.
273324
///
274325
/// Returns the segment, associated data, projected point on segment, and squared distance.
275326
pub fn nearest_segment(&self, query: &Point2D) -> Option<(&Segment, &T, Point2D, f64)> {
276-
if self.segments.is_empty() {
327+
let root = self.root?;
328+
let mut best: Option<(usize, Point2D, f64)> = None;
329+
self.search_nearest(root, query, 0, &mut best);
330+
let (idx, projection, dist) = best?;
331+
let node = &self.nodes[idx];
332+
Some((&node.segment, &node.data, projection, dist))
333+
}
334+
335+
fn build_nodes(
336+
indexed: &mut [(usize, Point2D)],
337+
items: &mut [Option<(Segment, T)>],
338+
depth: usize,
339+
nodes: &mut Vec<SegmentNode<T>>,
340+
) -> Option<usize> {
341+
if indexed.is_empty() {
277342
return None;
278343
}
279344

280-
// For small datasets, just brute force
281-
if self.segments.len() <= 100 {
282-
return self.brute_force_nearest(query);
283-
}
345+
let axis = depth % 2;
346+
indexed.sort_by(|a, b| {
347+
let va = if axis == 0 { a.1.x } else { a.1.y };
348+
let vb = if axis == 0 { b.1.x } else { b.1.y };
349+
va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
350+
});
284351

285-
// Use K-D tree to find candidates, then refine
286-
// First, get the nearest centroid as a starting point
287-
let nearest_centroid = self.tree.nearest_neighbor_with_distance(query)?;
288-
289-
// The actual nearest segment might be different from the one with nearest centroid
290-
// We need to search within a radius equal to our best distance so far
291-
let (seg, data) = nearest_centroid.0;
292-
let (proj, _) = seg.project_point(query);
293-
let mut best_dist = query.distance_squared(&proj);
294-
let mut best_result = (seg, data, proj, best_dist);
295-
296-
// Check all segments (for correctness; can be optimized with spatial queries)
297-
// Since K-D tree on centroids doesn't give us tight bounds, we do a full scan
298-
// This is still faster than pure brute force for construction + multiple queries
299-
for (seg, data) in &self.segments {
300-
let (proj, _) = seg.project_point(query);
301-
let dist = query.distance_squared(&proj);
302-
if dist < best_dist {
303-
best_dist = dist;
304-
best_result = (seg, data, proj, dist);
305-
}
352+
let mid = indexed.len() / 2;
353+
let (left_slice, rest) = indexed.split_at_mut(mid);
354+
let (mid_item, right_slice) = rest.split_first_mut().expect("not empty");
355+
let left = Self::build_nodes(left_slice, items, depth + 1, nodes);
356+
let right = Self::build_nodes(right_slice, items, depth + 1, nodes);
357+
358+
let (segment, data) = items[mid_item.0].take().expect("item already taken");
359+
let bounds = BoundingBox2D::from_segment(&segment);
360+
let idx = nodes.len();
361+
nodes.push(SegmentNode {
362+
segment,
363+
data,
364+
bounds,
365+
left,
366+
right,
367+
});
368+
369+
let mut bounds = BoundingBox2D::from_segment(&nodes[idx].segment);
370+
if let Some(left_idx) = left {
371+
bounds = bounds.union(nodes[left_idx].bounds);
306372
}
373+
if let Some(right_idx) = right {
374+
bounds = bounds.union(nodes[right_idx].bounds);
375+
}
376+
nodes[idx].bounds = bounds;
307377

308-
Some(best_result)
378+
Some(idx)
309379
}
310380

311-
fn brute_force_nearest(&self, query: &Point2D) -> Option<(&Segment, &T, Point2D, f64)> {
312-
let mut best: Option<(&Segment, &T, Point2D, f64)> = None;
381+
fn search_nearest(
382+
&self,
383+
node_idx: usize,
384+
query: &Point2D,
385+
depth: usize,
386+
best: &mut Option<(usize, Point2D, f64)>,
387+
) {
388+
let node = &self.nodes[node_idx];
389+
let (projection, _) = node.segment.project_point(query);
390+
let dist = query.distance_squared(&projection);
313391

314-
for (seg, data) in &self.segments {
315-
let (proj, _) = seg.project_point(query);
316-
let dist = query.distance_squared(&proj);
392+
match best {
393+
Some((_, _, best_dist)) if dist < *best_dist => {
394+
*best = Some((node_idx, projection, dist));
395+
}
396+
None => {
397+
*best = Some((node_idx, projection, dist));
398+
}
399+
_ => {}
400+
}
317401

318-
match &best {
319-
Some((_, _, _, best_dist)) if dist < *best_dist => {
320-
best = Some((seg, data, proj, dist));
321-
}
322-
None => {
323-
best = Some((seg, data, proj, dist));
324-
}
325-
_ => {}
402+
let axis = depth % 2;
403+
let centroid = node.segment.centroid();
404+
let query_val = if axis == 0 { query.x } else { query.y };
405+
let node_val = if axis == 0 { centroid.x } else { centroid.y };
406+
let (first, second) = if query_val <= node_val {
407+
(node.left, node.right)
408+
} else {
409+
(node.right, node.left)
410+
};
411+
412+
if let Some(child) = first {
413+
self.search_child(child, query, depth + 1, best);
414+
}
415+
if let Some(child) = second {
416+
self.search_child(child, query, depth + 1, best);
417+
}
418+
}
419+
420+
fn search_child(
421+
&self,
422+
child_idx: usize,
423+
query: &Point2D,
424+
depth: usize,
425+
best: &mut Option<(usize, Point2D, f64)>,
426+
) {
427+
if let Some((_, _, best_dist)) = best {
428+
let child_dist = self.nodes[child_idx]
429+
.bounds
430+
.distance_squared_to_point(query);
431+
if child_dist > *best_dist {
432+
return;
326433
}
327434
}
328435

329-
best
436+
self.search_nearest(child_idx, query, depth, best);
437+
}
438+
}
439+
440+
#[cfg(test)]
mod tests {
    use super::*;

    /// Linear-scan reference: returns `(data, projection, squared distance,
    /// segment)` for the segment closest to `query`.
    fn brute_force_nearest(
        segments: &[(Segment, usize)],
        query: &Point2D,
    ) -> (usize, Point2D, f64, Segment) {
        segments
            .iter()
            .map(|(segment, data)| {
                let (projection, _) = segment.project_point(query);
                (
                    *data,
                    projection,
                    query.distance_squared(&projection),
                    *segment,
                )
            })
            .min_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal))
            .expect("expected brute-force result")
    }

    #[test]
    fn segment_index_returns_none_when_empty() {
        let index: SegmentIndex<usize> = SegmentIndex::bulk_load(vec![]);
        assert!(index.nearest_segment(&Point2D::new(0.0, 0.0)).is_none());
    }

    #[test]
    fn segment_index_matches_bruteforce_on_large_input() {
        // 256 short segments stacked 10 units apart along y.
        let segments: Vec<(Segment, usize)> = (0..256)
            .map(|i| {
                let y = i as f64 * 10.0;
                (
                    Segment::new(Point2D::new(0.0, y), Point2D::new(5.0, y + 1.0)),
                    i,
                )
            })
            .collect();
        let index = SegmentIndex::bulk_load(segments.clone());

        // Queries inside the stack, on a segment, and outside both ends; each
        // has a single nearest segment, so the comparison cannot hit a tie.
        let queries = [
            Point2D::new(2.25, 1234.4),
            Point2D::new(2.5, 0.5),
            Point2D::new(-50.0, -50.0),
            Point2D::new(1000.0, 3000.0),
        ];

        for query in queries.iter() {
            let indexed = index
                .nearest_segment(query)
                .expect("expected nearest segment");
            let brute_force = brute_force_nearest(&segments, query);

            assert_eq!(*indexed.1, brute_force.0);
            assert_eq!(*indexed.0, brute_force.3);
            assert!((indexed.2.x - brute_force.1.x).abs() < 1e-9);
            assert!((indexed.2.y - brute_force.1.y).abs() < 1e-9);
            assert!((indexed.3 - brute_force.2).abs() < 1e-9);
        }
    }
}

0 commit comments

Comments
 (0)