diff options
Diffstat (limited to 'aoc')
| -rw-r--r-- | aoc/mixins.py | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/aoc/mixins.py b/aoc/mixins.py new file mode 100644 index 0000000..faeb3eb --- /dev/null +++ b/aoc/mixins.py | |||
| @@ -0,0 +1,57 @@ | |||
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | from math import floor, inf | ||
| 3 | from typing import Tuple, TypeVar, Generic, Callable, Iterator, List, Dict | ||
| 4 | |||
| 5 | Node = TypeVar("Node") | ||
| 6 | |||
| 7 | |||
| 8 | class AStarMixin(Generic[Node]): | ||
| 9 | def _gen_path(self, current: Node, came_from: Dict[Node, Node]) -> List[Node]: | ||
| 10 | print("GENPATH") | ||
| 11 | path = [current] | ||
| 12 | |||
| 13 | while current in came_from: | ||
| 14 | current = came_from[current] | ||
| 15 | path = [current, *path] | ||
| 16 | |||
| 17 | return path | ||
| 18 | |||
| 19 | def a_star( | ||
| 20 | self, | ||
| 21 | start: Node, | ||
| 22 | end: Callable[[Node, Callable[[], List[Node]]], bool], | ||
| 23 | neighbours: Callable[[Node], Iterator[Node]], | ||
| 24 | distance: Callable[[Node, Node], int] = lambda a, b: 1, | ||
| 25 | heuristic: Callable[[Node], int] = lambda a: 0, | ||
| 26 | ) -> List[Node]: | ||
| 27 | open_nodes = {start} | ||
| 28 | came_from = {} | ||
| 29 | |||
| 30 | g_scores = {start: 0} | ||
| 31 | |||
| 32 | f_scores = {start: heuristic(start)} | ||
| 33 | |||
| 34 | while True: | ||
| 35 | try: | ||
| 36 | current = sorted( | ||
| 37 | [item for item in open_nodes], | ||
| 38 | key=lambda item: f_scores.get(item, inf), | ||
| 39 | )[0] | ||
| 40 | except IndexError: | ||
| 41 | raise RuntimeError("No path found") | ||
| 42 | |||
| 43 | if end(current, lambda: self._gen_path(current, came_from)): | ||
| 44 | return self._gen_path(current, came_from) | ||
| 45 | |||
| 46 | open_nodes.remove(current) | ||
| 47 | |||
| 48 | for neighbour in neighbours(current): | ||
| 49 | g_score = g_scores.get(current, inf) + distance(current, neighbour) | ||
| 50 | |||
| 51 | if g_score < g_scores.get(neighbour, inf): | ||
| 52 | came_from[neighbour] = current | ||
| 53 | g_scores[neighbour] = g_score | ||
| 54 | f_scores[neighbour] = g_score + heuristic(neighbour) | ||
| 55 | |||
| 56 | if neighbour not in open_nodes: | ||
| 57 | open_nodes.add(neighbour) | ||
