From 4147da1317c19fa61d6aa265e8370e63231f9207 Mon Sep 17 00:00:00 2001 From: Tom van der Lee Date: Sun, 19 Nov 2023 16:55:03 +0100 Subject: Initial commit --- aoc/mixins.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 aoc/mixins.py (limited to 'aoc/mixins.py') diff --git a/aoc/mixins.py b/aoc/mixins.py new file mode 100644 index 0000000..5986b6e --- /dev/null +++ b/aoc/mixins.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +from math import floor, inf +from typing import Tuple, TypeVar, Generic, Callable, Iterator, List, Dict + +Node = TypeVar("Node") + + +class AStarMixin(Generic[Node]): + def _gen_path(self, current: Node, came_from: Dict[Node, Node]) -> List[Node]: + print("GENPATH") + path = [current] + + while current in came_from: + current = came_from[current] + path = [current, *path] + + return path + + def a_star( + self, + start: Node, + end: Callable[[Node, Callable[[], List[Node]]], bool], + neighbours: Callable[[Node], Iterator[Node]], + distance: Callable[[Node, Node], int] = lambda a, b: 1, + heuristic: Callable[[Node], int] = lambda a: 0, + ) -> List[Node]: + open_nodes = {start} + came_from = {} + + g_scores = {start: 0} + + f_scores = {start: heuristic(start)} + + while True: + try: + current = sorted( + [item for item in open_nodes], + key=lambda item: f_scores.get(item, inf), + )[0] + except IndexError: + raise RuntimeError("No path found") + + if end(current, lambda: self._gen_path(current, came_from)): + return self._gen_path(current, came_from) + + open_nodes.remove(current) + + for neighbour in neighbours(current): + g_score = g_scores.get(current, inf) + distance(current, neighbour) + + if g_score < g_scores.get(neighbour, inf): + came_from[neighbour] = current + g_scores[neighbour] = g_score + f_scores[neighbour] = g_score + heuristic(neighbour) + + if neighbour not in open_nodes: + open_nodes.add(neighbour) + + +class BreathFirstSearchMixin(Generic[Node]): + @staticmethod + def bfs( + start: Node, neighbours: Callable[[Node], Iterator[Node]] + ) -> Iterator[Node]: + queue = [start] + searched = set() + + while len(queue): + item = queue.pop(0) + + for neighbour in neighbours(item): + if neighbour not in searched and neighbour not in set(queue): + queue.append(neighbour) + + searched.add(item) + + yield item -- cgit v1.2.3