# -*- coding: utf-8 -*-
from math import floor, inf
from typing import Tuple, TypeVar, Generic, Callable, Iterator, List, Dict

Node = TypeVar("Node")


class AStarMixin(Generic[Node]):
    """Mixin providing a generic A* shortest-path search over hashable nodes."""

    def _gen_path(self, current: Node, came_from: Dict[Node, Node]) -> List[Node]:
        """Rebuild the path ending at ``current`` by following ``came_from``.

        Args:
            current: The node the path should end at.
            came_from: Maps each reached node to its predecessor.

        Returns:
            The nodes from the first node without a predecessor up to and
            including ``current``, in travel order.
        """
        # Collect the nodes backwards and reverse once; prepending to the
        # list on every step would copy it each time.
        path = [current]
        while current in came_from:
            current = came_from[current]
            path.append(current)
        path.reverse()
        return path

    def a_star(
        self,
        start: Node,
        end: Callable[[Node, Callable[[], List[Node]]], bool],
        neighbours: Callable[[Node], Iterator[Node]],
        distance: Callable[[Node, Node], int] = lambda a, b: 1,
        heuristic: Callable[[Node], int] = lambda a: 0,
    ) -> List[Node]:
        """Search for a cheapest path from ``start`` to a node accepted by ``end``.

        Args:
            start: The node the search begins at.
            end: Goal predicate. Called with the node about to be expanded and
                a zero-argument callable that returns the path from ``start``
                to that node; returns True when the search should stop there.
            neighbours: Returns the nodes reachable from a given node.
            distance: Cost of moving between two adjacent nodes (default 1).
            heuristic: Estimated remaining cost from a node (default 0).

        Returns:
            The list of nodes from ``start`` to the accepted node, inclusive.

        Raises:
            RuntimeError: If the open set is exhausted before ``end`` accepts
                a node.
        """
        open_nodes = {start}
        came_from: Dict[Node, Node] = {}
        g_scores = {start: 0}
        f_scores = {start: heuristic(start)}

        while True:
            if not open_nodes:
                raise RuntimeError("No path found")

            # Every open node was given an f-score when it was added, so a
            # direct lookup is safe. ``min`` returns the first minimal element,
            # the same choice ``sorted(...)[0]`` would make.
            current = min(open_nodes, key=lambda node: f_scores[node])

            # Bind ``current`` now so the callback keeps describing this node
            # even if the predicate holds on to it.
            if end(current, lambda node=current: self._gen_path(node, came_from)):
                return self._gen_path(current, came_from)

            open_nodes.remove(current)
            current_g = g_scores.get(current, inf)

            for neighbour in neighbours(current):
                g_score = current_g + distance(current, neighbour)
                if g_score < g_scores.get(neighbour, inf):
                    # Found a cheaper route: record it and (re)open the node
                    # so its own neighbours are reconsidered.
                    came_from[neighbour] = current
                    g_scores[neighbour] = g_score
                    f_scores[neighbour] = g_score + heuristic(neighbour)
                    open_nodes.add(neighbour)