1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
|
# -*- coding: utf-8 -*-
import heapq
import itertools
from math import floor, inf
from typing import Tuple, TypeVar, Generic, Callable, Iterator, List, Dict
# Type variable for graph nodes. Nodes are stored in sets and used as dict
# keys by AStarMixin.a_star, so they must be hashable.
Node = TypeVar("Node")


class AStarMixin(Generic[Node]):
    """Mixin adding a generic A* search over an implicitly defined graph.

    The graph is never materialised: callers describe it through the
    ``neighbours``, ``distance`` and ``heuristic`` callbacks passed to
    :meth:`a_star`.
    """

    def _gen_path(self, current: Node, came_from: Dict[Node, Node]) -> List[Node]:
        """Reconstruct the path from the search start to *current*.

        Follows ``came_from`` predecessor links back from *current* until a
        node with no recorded predecessor (the start node) is reached.

        Args:
            current: Node the returned path should end at.
            came_from: Maps each reached node to its predecessor on the best
                route found so far.

        Returns:
            The nodes from the start to *current*, both inclusive, in order.
        """
        path = [current]
        while current in came_from:
            current = came_from[current]
            path.append(current)
        # Collected end-to-start with O(1) appends and flipped once, instead of
        # prepending on every step (which was quadratic in the path length).
        path.reverse()
        return path

    def a_star(
        self,
        start: Node,
        end: Callable[[Node, Callable[[], List[Node]]], bool],
        neighbours: Callable[[Node], Iterator[Node]],
        distance: Callable[[Node, Node], int] = lambda a, b: 1,
        heuristic: Callable[[Node], int] = lambda a: 0,
    ) -> List[Node]:
        """Find a lowest-cost path from *start* to a node accepted by *end*.

        Args:
            start: Node to search from.
            end: Goal test, called as ``end(node, get_path)`` whenever a node
                is selected for expansion. ``get_path()`` lazily builds the path
                from *start* to ``node``, so it costs nothing unless called.
                Return True to finish the search at ``node``.
            neighbours: Returns the nodes reachable in one step from a node.
            distance: Cost of the step between two adjacent nodes (default 1).
            heuristic: Estimate of the remaining cost from a node to a goal.
                The default of 0 turns the search into Dijkstra's algorithm.
                The returned path is only guaranteed to be cheapest if the
                heuristic never overestimates (standard A* admissibility).

        Returns:
            The path from *start* to the accepted node, both inclusive.

        Raises:
            RuntimeError: If every reachable node has been expanded without
                *end* accepting one.
        """
        g_scores: Dict[Node, float] = {start: 0}
        f_scores: Dict[Node, float] = {start: heuristic(start)}
        came_from: Dict[Node, Node] = {}
        open_nodes = {start}
        # Priority queue of (f_score, insertion_order, node) entries. The
        # insertion counter breaks f-score ties, so nodes themselves never need
        # to be comparable. When a node's score improves, its old entry is left
        # in place and skipped when popped (lazy deletion).
        tie_breaker = itertools.count()
        frontier = [(f_scores[start], next(tie_breaker), start)]
        while frontier:
            f_score, _, current = heapq.heappop(frontier)
            if current not in open_nodes or f_score > f_scores[current]:
                # Stale entry: the node was already expanded, or a cheaper
                # route to it was queued after this entry.
                continue
            if end(current, lambda: self._gen_path(current, came_from)):
                return self._gen_path(current, came_from)
            open_nodes.remove(current)
            for neighbour in neighbours(current):
                g_score = g_scores[current] + distance(current, neighbour)
                if g_score < g_scores.get(neighbour, inf):
                    came_from[neighbour] = current
                    g_scores[neighbour] = g_score
                    f_scores[neighbour] = g_score + heuristic(neighbour)
                    # (Re)open the neighbour. This also re-expands a node that
                    # was already expanded if a cheaper route to it turns up.
                    open_nodes.add(neighbour)
                    heapq.heappush(
                        frontier,
                        (f_scores[neighbour], next(tie_breaker), neighbour),
                    )
        raise RuntimeError("No path found")
|