Skip to content

Commit 02cc345

Browse files
committed
necessary components for parking. Can likely put astar and reed_shepp into some sort of util folder.
1 parent 31f2bcd commit 02cc345

3 files changed

Lines changed: 1256 additions & 0 deletions

File tree

GEMstack/onboard/planning/astar.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
""" generic A-Star path searching algorithm (https://github.com/jrialland/python-astar/blob/master/tests/basic/test_basic.py) """
2+
# @TODO make it so we return the best path as discussed in class
3+
# In class, we discussed that we should have a terminal function
4+
# which approximates the cost_to_go and returns the path to the node
5+
# with the least cost_to_go + terminal_cost
6+
7+
from abc import ABC, abstractmethod
8+
from typing import Callable, Dict, Iterable, Union, TypeVar, Generic
9+
from math import inf as infinity
10+
from operator import attrgetter
11+
import heapq
12+
13+
# introduce generic type
14+
T = TypeVar("T")
15+
16+
17+
################################################################################
18+
class SearchNode(Generic[T]):
19+
"""Representation of a search node"""
20+
21+
__slots__ = ("data", "gscore", "fscore", "tscore", "closed", "came_from", "in_openset", "cache")
22+
23+
def __init__(
24+
self, data: T, gscore: float = infinity, fscore: float = infinity, tscore:float = infinity
25+
) -> None:
26+
self.data = data
27+
self.gscore = gscore
28+
self.fscore = fscore
29+
self.tscore = tscore
30+
self.closed = False
31+
self.in_openset = False
32+
self.came_from: Union[None, SearchNode[T]] = None
33+
self.cache: Any = None
34+
35+
def __lt__(self, b: "SearchNode[T]") -> bool:
36+
"""Natural order is based on the fscore value & is used by heapq operations"""
37+
return self.fscore < b.fscore
38+
39+
40+
################################################################################
41+
class SearchNodeDict(Dict[T, SearchNode[T]]):
42+
"""A dict that returns a new SearchNode when a key is missing"""
43+
44+
def __missing__(self, k) -> SearchNode[T]:
45+
v = SearchNode(k)
46+
self.__setitem__(k, v)
47+
return v
48+
49+
50+
################################################################################
51+
SNType = TypeVar("SNType", bound=SearchNode)
52+
53+
54+
class OpenSet(Generic[SNType]):
55+
def __init__(self) -> None:
56+
self.heap: list[SNType] = []
57+
58+
def push(self, item: SNType) -> None:
59+
item.in_openset = True
60+
heapq.heappush(self.heap, item)
61+
62+
def pop(self) -> SNType:
63+
item = heapq.heappop(self.heap)
64+
item.in_openset = False
65+
return item
66+
67+
def remove(self, item: SNType) -> None:
68+
idx = self.heap.index(item)
69+
item.in_openset = False
70+
item = self.heap.pop()
71+
if idx < len(self.heap):
72+
self.heap[idx] = item
73+
# Fix heap invariants
74+
heapq._siftup(self.heap, idx)
75+
heapq._siftdown(self.heap, 0, idx)
76+
77+
def __len__(self) -> int:
78+
return len(self.heap)
79+
80+
81+
################################################################################*
82+
83+
84+
class AStar(ABC, Generic[T]):
85+
__slots__ = ()
86+
87+
@abstractmethod
88+
def heuristic_cost_estimate(self, current: T, goal: T) -> float:
89+
"""
90+
Computes the estimated (rough) distance between a node and the goal.
91+
The second parameter is always the goal.
92+
93+
This method must be implemented in a subclass.
94+
"""
95+
raise NotImplementedError
96+
97+
@abstractmethod
98+
def terminal_cost_estimate(self, current: T, goal: T) -> float:
99+
"""Computes the estimated distance between a node and the goal.
100+
This function is called after all iterations of A* have been run
101+
and is used to determine the closest node to the goal found so far.
102+
103+
This method must be implemented in a subclass.
104+
105+
Args:
106+
current (T): Current T
107+
goal (T): goal T
108+
109+
Returns:
110+
float: _description_
111+
"""
112+
raise NotImplementedError
113+
114+
def distance_between(self, n1: T, n2: T) -> float:
115+
"""
116+
Gives the real distance between two adjacent nodes n1 and n2 (i.e n2
117+
belongs to the list of n1's neighbors).
118+
n2 is guaranteed to belong to the list returned by the call to neighbors(n1).
119+
120+
This method (or "path_distance_between") must be implemented in a subclass.
121+
"""
122+
raise NotImplementedError
123+
124+
def path_distance_between(self, n1: SearchNode[T], n2: SearchNode[T]) -> float:
125+
"""
126+
Gives the real distance between the node n1 and its neighbor n2.
127+
n2 is guaranteed to belong to the list returned by the call to
128+
path_neighbors(n1).
129+
130+
Calls "distance_between"`by default.
131+
"""
132+
return self.distance_between(n1.data, n2.data)
133+
134+
def neighbors(self, node: T) -> Iterable[T]:
135+
"""
136+
For a given node, returns (or yields) the list of its neighbors.
137+
138+
This method (or "path_neighbors") must be implemented in a subclass.
139+
"""
140+
raise NotImplementedError
141+
142+
def path_neighbors(self, node: SearchNode[T]) -> Iterable[T]:
143+
"""
144+
For a given node, returns (or yields) the list of its reachable neighbors.
145+
Calls "neighbors" by default.
146+
"""
147+
return self.neighbors(node.data)
148+
149+
def _neighbors(self, current: SearchNode[T], search_nodes: SearchNodeDict[T]) -> Iterable[SearchNode]:
150+
return (search_nodes[n] for n in self.path_neighbors(current))
151+
152+
def is_goal_reached(self, current: T, goal: T) -> bool:
153+
"""
154+
Returns true when we can consider that 'current' is the goal.
155+
The default implementation simply compares `current == goal`, but this
156+
method can be overwritten in a subclass to provide more refined checks.
157+
"""
158+
return current == goal
159+
160+
def reconstruct_path(self, last: SearchNode, reversePath=False) -> Iterable[T]:
161+
def _gen():
162+
current = last
163+
while current:
164+
yield current.data
165+
current = current.came_from
166+
167+
if reversePath:
168+
return _gen()
169+
else:
170+
return reversed(list(_gen()))
171+
172+
def astar(
173+
self, start: T, goal: T, reversePath: bool = False, iterations: int = 5000
174+
) -> Union[Iterable[T], None]:
175+
if self.is_goal_reached(start, goal):
176+
return [start]
177+
178+
openSet: OpenSet[SearchNode[T]] = OpenSet()
179+
searchNodes: SearchNodeDict[T] = SearchNodeDict()
180+
startNode = searchNodes[start] = SearchNode(
181+
start, gscore=0.0, fscore=self.heuristic_cost_estimate(start, goal)
182+
)
183+
openSet.push(startNode)
184+
bestNode = startNode
185+
186+
iteration = 0
187+
188+
while openSet and iteration < iterations:
189+
current = openSet.pop()
190+
191+
if self.is_goal_reached(current.data, goal):
192+
return self.reconstruct_path(current, reversePath)
193+
194+
current.closed = True
195+
196+
for neighbor in self._neighbors(current, searchNodes):
197+
if neighbor.closed:
198+
continue
199+
200+
gscore = current.gscore + self.path_distance_between(current, neighbor)
201+
202+
if gscore >= neighbor.gscore:
203+
continue
204+
205+
fscore = gscore + self.heuristic_cost_estimate(
206+
neighbor.data, goal
207+
)
208+
tscore = self.terminal_cost_estimate(
209+
neighbor.data, goal
210+
)
211+
212+
# print(f"Checking node: {neighbor.data} with tscore {tscore}")
213+
if tscore < bestNode.tscore:
214+
# print(f"Found a better node: {neighbor.data} with tscore {tscore}")
215+
bestNode = neighbor
216+
217+
if neighbor.in_openset:
218+
if neighbor.fscore < fscore:
219+
# the new path to this node isn't better
220+
continue
221+
222+
# we have to remove the item from the heap, as its score has changed
223+
openSet.remove(neighbor)
224+
225+
# update the node
226+
neighbor.came_from = current
227+
neighbor.gscore = gscore
228+
neighbor.fscore = fscore
229+
neighbor.tscore = tscore
230+
231+
openSet.push(neighbor)
232+
233+
iteration += 1
234+
235+
# print("Warning: A* search failed to find a path")
236+
return self.reconstruct_path(bestNode, reversePath)
237+
238+
239+
################################################################################
240+
U = TypeVar("U")
241+
242+
243+
def find_path(
244+
start: U,
245+
goal: U,
246+
neighbors_fnct: Callable[[U], Iterable[U]],
247+
reversePath=False,
248+
heuristic_cost_estimate_fnct: Callable[[U, U], float] = lambda a, b: infinity,
249+
distance_between_fnct: Callable[[U, U], float] = lambda a, b: 1.0,
250+
is_goal_reached_fnct: Callable[[U, U], bool] = lambda a, b: a == b,
251+
) -> Union[Iterable[U], None]:
252+
"""A non-class version of the path finding algorithm"""
253+
254+
class FindPath(AStar):
255+
def heuristic_cost_estimate(self, current: U, goal: U) -> float:
256+
return heuristic_cost_estimate_fnct(current, goal) # type: ignore
257+
258+
def distance_between(self, n1: U, n2: U) -> float:
259+
return distance_between_fnct(n1, n2)
260+
261+
def neighbors(self, node) -> Iterable[U]:
262+
return neighbors_fnct(node) # type: ignore
263+
264+
def is_goal_reached(self, current: U, goal: U) -> bool:
265+
return is_goal_reached_fnct(current, goal)
266+
267+
return FindPath().astar(start, goal, reversePath)

0 commit comments

Comments
 (0)