From 2d1263453d26cd2c6a64a3b6141ee0a12ddc0b11 Mon Sep 17 00:00:00 2001 From: Tom van der Lee Date: Sat, 17 Dec 2022 16:35:17 +0100 Subject: Day 16 [WIP] --- aoc/mixins.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 aoc/mixins.py (limited to 'aoc/mixins.py') diff --git a/aoc/mixins.py b/aoc/mixins.py new file mode 100644 index 0000000..faeb3eb --- /dev/null +++ b/aoc/mixins.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +from math import floor, inf +from typing import Tuple, TypeVar, Generic, Callable, Iterator, List, Dict + +Node = TypeVar("Node") + + +class AStarMixin(Generic[Node]): + def _gen_path(self, current: Node, came_from: Dict[Node, Node]) -> List[Node]: + print("GENPATH") + path = [current] + + while current in came_from: + current = came_from[current] + path = [current, *path] + + return path + + def a_star( + self, + start: Node, + end: Callable[[Node, Callable[[], List[Node]]], bool], + neighbours: Callable[[Node], Iterator[Node]], + distance: Callable[[Node, Node], int] = lambda a, b: 1, + heuristic: Callable[[Node], int] = lambda a: 0, + ) -> List[Node]: + open_nodes = {start} + came_from = {} + + g_scores = {start: 0} + + f_scores = {start: heuristic(start)} + + while True: + try: + current = sorted( + [item for item in open_nodes], + key=lambda item: f_scores.get(item, inf), + )[0] + except IndexError: + raise RuntimeError("No path found") + + if end(current, lambda: self._gen_path(current, came_from)): + return self._gen_path(current, came_from) + + open_nodes.remove(current) + + for neighbour in neighbours(current): + g_score = g_scores.get(current, inf) + distance(current, neighbour) + + if g_score < g_scores.get(neighbour, inf): + came_from[neighbour] = current + g_scores[neighbour] = g_score + f_scores[neighbour] = g_score + heuristic(neighbour) + + if neighbour not in open_nodes: + open_nodes.add(neighbour) -- cgit v1.2.3