summaryrefslogtreecommitdiffstats
path: root/aoc
diff options
context:
space:
mode:
Diffstat (limited to 'aoc')
-rw-r--r--aoc/mixins.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/aoc/mixins.py b/aoc/mixins.py
new file mode 100644
index 0000000..faeb3eb
--- /dev/null
+++ b/aoc/mixins.py
@@ -0,0 +1,57 @@
1# -*- coding: utf-8 -*-
2from math import floor, inf
3from typing import Tuple, TypeVar, Generic, Callable, Iterator, List, Dict
4
5Node = TypeVar("Node")
6
7
8class AStarMixin(Generic[Node]):
9 def _gen_path(self, current: Node, came_from: Dict[Node, Node]) -> List[Node]:
10 print("GENPATH")
11 path = [current]
12
13 while current in came_from:
14 current = came_from[current]
15 path = [current, *path]
16
17 return path
18
19 def a_star(
20 self,
21 start: Node,
22 end: Callable[[Node, Callable[[], List[Node]]], bool],
23 neighbours: Callable[[Node], Iterator[Node]],
24 distance: Callable[[Node, Node], int] = lambda a, b: 1,
25 heuristic: Callable[[Node], int] = lambda a: 0,
26 ) -> List[Node]:
27 open_nodes = {start}
28 came_from = {}
29
30 g_scores = {start: 0}
31
32 f_scores = {start: heuristic(start)}
33
34 while True:
35 try:
36 current = sorted(
37 [item for item in open_nodes],
38 key=lambda item: f_scores.get(item, inf),
39 )[0]
40 except IndexError:
41 raise RuntimeError("No path found")
42
43 if end(current, lambda: self._gen_path(current, came_from)):
44 return self._gen_path(current, came_from)
45
46 open_nodes.remove(current)
47
48 for neighbour in neighbours(current):
49 g_score = g_scores.get(current, inf) + distance(current, neighbour)
50
51 if g_score < g_scores.get(neighbour, inf):
52 came_from[neighbour] = current
53 g_scores[neighbour] = g_score
54 f_scores[neighbour] = g_score + heuristic(neighbour)
55
56 if neighbour not in open_nodes:
57 open_nodes.add(neighbour)