|
| 1 | +""" generic A-Star path searching algorithm (https://github.com/jrialland/python-astar/blob/master/tests/basic/test_basic.py) """ |
| 2 | +# @TODO make it so we return the best path as discussed in class |
| 3 | +# In class, we discussed that we should have a terminal function |
| 4 | +# which approximates the cost_to_go and returns the path to the node |
| 5 | +# with the least cost_to_go + terminal_cost |
| 6 | + |
| 7 | +from abc import ABC, abstractmethod |
| 8 | +from typing import Callable, Dict, Iterable, Union, TypeVar, Generic |
| 9 | +from math import inf as infinity |
| 10 | +from operator import attrgetter |
| 11 | +import heapq |
| 12 | + |
| 13 | +# introduce generic type |
| 14 | +T = TypeVar("T") |
| 15 | + |
| 16 | + |
| 17 | +################################################################################ |
| 18 | +class SearchNode(Generic[T]): |
| 19 | + """Representation of a search node""" |
| 20 | + |
| 21 | + __slots__ = ("data", "gscore", "fscore", "closed", "came_from", "in_openset", "cache") |
| 22 | + |
| 23 | + def __init__( |
| 24 | + self, data: T, gscore: float = infinity, fscore: float = infinity |
| 25 | + ) -> None: |
| 26 | + self.data = data |
| 27 | + self.gscore = gscore |
| 28 | + self.fscore = fscore |
| 29 | + self.closed = False |
| 30 | + self.in_openset = False |
| 31 | + self.came_from: Union[None, SearchNode[T]] = None |
| 32 | + self.cache: Any = None |
| 33 | + |
| 34 | + def __lt__(self, b: "SearchNode[T]") -> bool: |
| 35 | + """Natural order is based on the fscore value & is used by heapq operations""" |
| 36 | + return self.fscore < b.fscore |
| 37 | + |
| 38 | + |
| 39 | +################################################################################ |
| 40 | +class SearchNodeDict(Dict[T, SearchNode[T]]): |
| 41 | + """A dict that returns a new SearchNode when a key is missing""" |
| 42 | + |
| 43 | + def __missing__(self, k) -> SearchNode[T]: |
| 44 | + v = SearchNode(k) |
| 45 | + self.__setitem__(k, v) |
| 46 | + return v |
| 47 | + |
| 48 | + |
| 49 | +################################################################################ |
| 50 | +SNType = TypeVar("SNType", bound=SearchNode) |
| 51 | + |
| 52 | + |
| 53 | +class OpenSet(Generic[SNType]): |
| 54 | + def __init__(self) -> None: |
| 55 | + self.heap: list[SNType] = [] |
| 56 | + |
| 57 | + def push(self, item: SNType) -> None: |
| 58 | + item.in_openset = True |
| 59 | + heapq.heappush(self.heap, item) |
| 60 | + |
| 61 | + def pop(self) -> SNType: |
| 62 | + item = heapq.heappop(self.heap) |
| 63 | + item.in_openset = False |
| 64 | + return item |
| 65 | + |
| 66 | + def remove(self, item: SNType) -> None: |
| 67 | + idx = self.heap.index(item) |
| 68 | + item.in_openset = False |
| 69 | + item = self.heap.pop() |
| 70 | + if idx < len(self.heap): |
| 71 | + self.heap[idx] = item |
| 72 | + # Fix heap invariants |
| 73 | + heapq._siftup(self.heap, idx) |
| 74 | + heapq._siftdown(self.heap, 0, idx) |
| 75 | + |
| 76 | + def __len__(self) -> int: |
| 77 | + return len(self.heap) |
| 78 | + |
| 79 | + |
| 80 | +################################################################################* |
| 81 | + |
| 82 | + |
| 83 | +class AStar(ABC, Generic[T]): |
| 84 | + __slots__ = () |
| 85 | + |
| 86 | + @abstractmethod |
| 87 | + def heuristic_cost_estimate(self, current: T, goal: T) -> float: |
| 88 | + """ |
| 89 | + Computes the estimated (rough) distance between a node and the goal. |
| 90 | + The second parameter is always the goal. |
| 91 | +
|
| 92 | + This method must be implemented in a subclass. |
| 93 | + """ |
| 94 | + raise NotImplementedError |
| 95 | + |
| 96 | + def distance_between(self, n1: T, n2: T) -> float: |
| 97 | + """ |
| 98 | + Gives the real distance between two adjacent nodes n1 and n2 (i.e n2 |
| 99 | + belongs to the list of n1's neighbors). |
| 100 | + n2 is guaranteed to belong to the list returned by the call to neighbors(n1). |
| 101 | +
|
| 102 | + This method (or "path_distance_between") must be implemented in a subclass. |
| 103 | + """ |
| 104 | + raise NotImplementedError |
| 105 | + |
| 106 | + def path_distance_between(self, n1: SearchNode[T], n2: SearchNode[T]) -> float: |
| 107 | + """ |
| 108 | + Gives the real distance between the node n1 and its neighbor n2. |
| 109 | + n2 is guaranteed to belong to the list returned by the call to |
| 110 | + path_neighbors(n1). |
| 111 | +
|
| 112 | + Calls "distance_between"`by default. |
| 113 | + """ |
| 114 | + return self.distance_between(n1.data, n2.data) |
| 115 | + |
| 116 | + def neighbors(self, node: T) -> Iterable[T]: |
| 117 | + """ |
| 118 | + For a given node, returns (or yields) the list of its neighbors. |
| 119 | +
|
| 120 | + This method (or "path_neighbors") must be implemented in a subclass. |
| 121 | + """ |
| 122 | + raise NotImplementedError |
| 123 | + |
| 124 | + def path_neighbors(self, node: SearchNode[T]) -> Iterable[T]: |
| 125 | + """ |
| 126 | + For a given node, returns (or yields) the list of its reachable neighbors. |
| 127 | + Calls "neighbors" by default. |
| 128 | + """ |
| 129 | + return self.neighbors(node.data) |
| 130 | + |
| 131 | + def _neighbors(self, current: SearchNode[T], search_nodes: SearchNodeDict[T]) -> Iterable[SearchNode]: |
| 132 | + return (search_nodes[n] for n in self.path_neighbors(current)) |
| 133 | + |
| 134 | + def is_goal_reached(self, current: T, goal: T) -> bool: |
| 135 | + """ |
| 136 | + Returns true when we can consider that 'current' is the goal. |
| 137 | + The default implementation simply compares `current == goal`, but this |
| 138 | + method can be overwritten in a subclass to provide more refined checks. |
| 139 | + """ |
| 140 | + return current == goal |
| 141 | + |
| 142 | + def reconstruct_path(self, last: SearchNode, reversePath=False) -> Iterable[T]: |
| 143 | + def _gen(): |
| 144 | + current = last |
| 145 | + while current: |
| 146 | + yield current.data |
| 147 | + current = current.came_from |
| 148 | + |
| 149 | + if reversePath: |
| 150 | + return _gen() |
| 151 | + else: |
| 152 | + return reversed(list(_gen())) |
| 153 | + |
| 154 | + def astar( |
| 155 | + self, start: T, goal: T, reversePath: bool = False |
| 156 | + ) -> Union[Iterable[T], None]: |
| 157 | + if self.is_goal_reached(start, goal): |
| 158 | + return [start] |
| 159 | + |
| 160 | + openSet: OpenSet[SearchNode[T]] = OpenSet() |
| 161 | + searchNodes: SearchNodeDict[T] = SearchNodeDict() |
| 162 | + startNode = searchNodes[start] = SearchNode( |
| 163 | + start, gscore=0.0, fscore=self.heuristic_cost_estimate(start, goal) |
| 164 | + ) |
| 165 | + openSet.push(startNode) |
| 166 | + |
| 167 | + while openSet: |
| 168 | + current = openSet.pop() |
| 169 | + |
| 170 | + if self.is_goal_reached(current.data, goal): |
| 171 | + return self.reconstruct_path(current, reversePath) |
| 172 | + |
| 173 | + current.closed = True |
| 174 | + |
| 175 | + for neighbor in self._neighbors(current, searchNodes): |
| 176 | + if neighbor.closed: |
| 177 | + continue |
| 178 | + |
| 179 | + gscore = current.gscore + self.path_distance_between(current, neighbor) |
| 180 | + |
| 181 | + if gscore >= neighbor.gscore: |
| 182 | + continue |
| 183 | + |
| 184 | + fscore = gscore + self.heuristic_cost_estimate( |
| 185 | + neighbor.data, goal |
| 186 | + ) |
| 187 | + |
| 188 | + if neighbor.in_openset: |
| 189 | + if neighbor.fscore < fscore: |
| 190 | + # the new path to this node isn't better |
| 191 | + continue |
| 192 | + |
| 193 | + # we have to remove the item from the heap, as its score has changed |
| 194 | + openSet.remove(neighbor) |
| 195 | + |
| 196 | + # update the node |
| 197 | + neighbor.came_from = current |
| 198 | + neighbor.gscore = gscore |
| 199 | + neighbor.fscore = fscore |
| 200 | + |
| 201 | + openSet.push(neighbor) |
| 202 | + |
| 203 | + return None |
| 204 | + |
| 205 | + |
| 206 | +################################################################################ |
| 207 | +U = TypeVar("U") |
| 208 | + |
| 209 | + |
| 210 | +def find_path( |
| 211 | + start: U, |
| 212 | + goal: U, |
| 213 | + neighbors_fnct: Callable[[U], Iterable[U]], |
| 214 | + reversePath=False, |
| 215 | + heuristic_cost_estimate_fnct: Callable[[U, U], float] = lambda a, b: infinity, |
| 216 | + distance_between_fnct: Callable[[U, U], float] = lambda a, b: 1.0, |
| 217 | + is_goal_reached_fnct: Callable[[U, U], bool] = lambda a, b: a == b, |
| 218 | +) -> Union[Iterable[U], None]: |
| 219 | + """A non-class version of the path finding algorithm""" |
| 220 | + |
| 221 | + class FindPath(AStar): |
| 222 | + def heuristic_cost_estimate(self, current: U, goal: U) -> float: |
| 223 | + return heuristic_cost_estimate_fnct(current, goal) # type: ignore |
| 224 | + |
| 225 | + def distance_between(self, n1: U, n2: U) -> float: |
| 226 | + return distance_between_fnct(n1, n2) |
| 227 | + |
| 228 | + def neighbors(self, node) -> Iterable[U]: |
| 229 | + return neighbors_fnct(node) # type: ignore |
| 230 | + |
| 231 | + def is_goal_reached(self, current: U, goal: U) -> bool: |
| 232 | + return is_goal_reached_fnct(current, goal) |
| 233 | + |
| 234 | + return FindPath().astar(start, goal, reversePath) |
0 commit comments