diff options
| author | 2023-11-19 16:55:03 +0100 | |
|---|---|---|
| committer | 2025-12-01 09:53:01 +0100 | |
| commit | d7e30321ae6ae4c82a8ab7455f6ce33afd719c67 (patch) | |
| tree | e873d640f909ae3e247adc7661b7d954c8af3a26 /aoc/mixins.py | |
| download | 2025-d7e30321ae6ae4c82a8ab7455f6ce33afd719c67.tar.gz 2025-d7e30321ae6ae4c82a8ab7455f6ce33afd719c67.tar.bz2 2025-d7e30321ae6ae4c82a8ab7455f6ce33afd719c67.zip | |
Initial commit
Diffstat (limited to 'aoc/mixins.py')
| -rw-r--r-- | aoc/mixins.py | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/aoc/mixins.py b/aoc/mixins.py new file mode 100644 index 0000000..5986b6e --- /dev/null +++ b/aoc/mixins.py | |||
| @@ -0,0 +1,77 @@ | |||
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | from math import floor, inf | ||
| 3 | from typing import Tuple, TypeVar, Generic, Callable, Iterator, List, Dict | ||
| 4 | |||
| 5 | Node = TypeVar("Node") | ||
| 6 | |||
| 7 | |||
| 8 | class AStarMixin(Generic[Node]): | ||
| 9 | def _gen_path(self, current: Node, came_from: Dict[Node, Node]) -> List[Node]: | ||
| 10 | print("GENPATH") | ||
| 11 | path = [current] | ||
| 12 | |||
| 13 | while current in came_from: | ||
| 14 | current = came_from[current] | ||
| 15 | path = [current, *path] | ||
| 16 | |||
| 17 | return path | ||
| 18 | |||
| 19 | def a_star( | ||
| 20 | self, | ||
| 21 | start: Node, | ||
| 22 | end: Callable[[Node, Callable[[], List[Node]]], bool], | ||
| 23 | neighbours: Callable[[Node], Iterator[Node]], | ||
| 24 | distance: Callable[[Node, Node], int] = lambda a, b: 1, | ||
| 25 | heuristic: Callable[[Node], int] = lambda a: 0, | ||
| 26 | ) -> List[Node]: | ||
| 27 | open_nodes = {start} | ||
| 28 | came_from = {} | ||
| 29 | |||
| 30 | g_scores = {start: 0} | ||
| 31 | |||
| 32 | f_scores = {start: heuristic(start)} | ||
| 33 | |||
| 34 | while True: | ||
| 35 | try: | ||
| 36 | current = sorted( | ||
| 37 | [item for item in open_nodes], | ||
| 38 | key=lambda item: f_scores.get(item, inf), | ||
| 39 | )[0] | ||
| 40 | except IndexError: | ||
| 41 | raise RuntimeError("No path found") | ||
| 42 | |||
| 43 | if end(current, lambda: self._gen_path(current, came_from)): | ||
| 44 | return self._gen_path(current, came_from) | ||
| 45 | |||
| 46 | open_nodes.remove(current) | ||
| 47 | |||
| 48 | for neighbour in neighbours(current): | ||
| 49 | g_score = g_scores.get(current, inf) + distance(current, neighbour) | ||
| 50 | |||
| 51 | if g_score < g_scores.get(neighbour, inf): | ||
| 52 | came_from[neighbour] = current | ||
| 53 | g_scores[neighbour] = g_score | ||
| 54 | f_scores[neighbour] = g_score + heuristic(neighbour) | ||
| 55 | |||
| 56 | if neighbour not in open_nodes: | ||
| 57 | open_nodes.add(neighbour) | ||
| 58 | |||
| 59 | |||
| 60 | class BreathFirstSearchMixin(Generic[Node]): | ||
| 61 | @staticmethod | ||
| 62 | def bfs( | ||
| 63 | start: Node, neighbours: Callable[[Node], Iterator[Node]] | ||
| 64 | ) -> Iterator[Node]: | ||
| 65 | queue = [start] | ||
| 66 | searched = set() | ||
| 67 | |||
| 68 | while len(queue): | ||
| 69 | item = queue.pop(0) | ||
| 70 | |||
| 71 | for neighbour in neighbours(item): | ||
| 72 | if neighbour not in searched and neighbour not in set(queue): | ||
| 73 | queue.append(neighbour) | ||
| 74 | |||
| 75 | searched.add(item) | ||
| 76 | |||
| 77 | yield item | ||
