diff --git a/src/directed/dfs.rs b/src/directed/dfs.rs index 15049889..86f0abf5 100644 --- a/src/directed/dfs.rs +++ b/src/directed/dfs.rs @@ -163,19 +163,21 @@ where type Item = N; fn next(&mut self) -> Option { - let n = self.to_see.pop()?; - if self.visited.contains(&n) { - return self.next(); - } - self.visited.insert(n.clone()); - let mut to_insert = Vec::new(); - for s in (self.successors)(&n) { - if !self.visited.contains(&s) { - to_insert.push(s.clone()); + loop { + let n = self.to_see.pop()?; + if self.visited.contains(&n) { + continue; + } + self.visited.insert(n.clone()); + let mut to_insert = Vec::new(); + for s in (self.successors)(&n) { + if !self.visited.contains(&s) { + to_insert.push(s.clone()); + } } + self.to_see.extend(to_insert.into_iter().rev()); + return Some(n); } - self.to_see.extend(to_insert.into_iter().rev()); - Some(n) } } diff --git a/src/directed/yen.rs b/src/directed/yen.rs index 57235442..61356152 100644 --- a/src/directed/yen.rs +++ b/src/directed/yen.rs @@ -108,6 +108,9 @@ where IN: IntoIterator, FS: FnMut(&N) -> bool, { + if k == 0 { + return vec![]; + } let Some((n, c)) = dijkstra_internal(start, &mut successors, &mut success) else { return vec![]; }; diff --git a/tests/dfs-reach.rs b/tests/dfs-reach.rs index 842a3c61..80386af7 100644 --- a/tests/dfs-reach.rs +++ b/tests/dfs-reach.rs @@ -14,3 +14,20 @@ fn issue_511_branches() { let it = dfs_reach(0, |&n| [n + 2, n + 5].into_iter().filter(|&x| x <= 10)); assert_eq!(vec![0, 2, 4, 6, 8, 10, 9, 7, 5], it.collect::>()); } + +/// Test that `dfs_reach` does not stack overflow when many duplicate nodes +/// pile up in the `to_see` stack (would previously recurse for each duplicate). +#[test] +fn no_stack_overflow_with_duplicates() { + // Each node has N successors all pointing to the same next node, creating + // many duplicates in to_see. With the old recursive implementation, the + // recursion depth could equal the number of duplicates, causing stack overflow. + let n = 200_usize; + // Node 0 -> [1, 1, 1, ...] (n copies of 1) + // Node k -> [k+1, k+1, k+1, ...] (n copies of k+1) for k < n + // Node n -> [] + let result: Vec = + dfs_reach(0usize, |&k| if k < n { vec![k + 1; n] } else { vec![] }).collect(); + let expected: Vec = (0..=n).collect(); + assert_eq!(result, expected); +} diff --git a/tests/yen.rs b/tests/yen.rs index 74010330..3d5c7432 100644 --- a/tests/yen.rs +++ b/tests/yen.rs @@ -177,3 +177,19 @@ fn multiple_equal_cost_paths() { assert_eq!(paths[0], (vec!['A', 'B', 'D'], 2)); assert_eq!(paths[1], (vec!['A', 'C', 'D'], 2)); } + +/// Test that k=0 returns an empty result without panicking. +#[test] +fn k_zero() { + let result = yen( + &'c', + |c| match c { + 'c' => vec![('d', 3)], + 'd' => vec![], + _ => panic!(""), + }, + |c| *c == 'd', + 0, + ); + assert!(result.is_empty()); +}