summaryrefslogtreecommitdiffstats
path: root/aoc/mixins.py
blob: faeb3ebfc273bee43c8245b3db8f0b846fd3584b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# -*- coding: utf-8 -*-
from math import floor, inf
from typing import Tuple, TypeVar, Generic, Callable, Iterator, List, Dict

Node = TypeVar("Node")


class AStarMixin(Generic[Node]):
    """Mixin that adds a generic A* shortest-path search over hashable nodes.

    The graph is never materialised: callers describe it through callbacks
    (goal test, neighbour generator, edge cost, heuristic), so the same
    search can be reused for grids, state machines, and so on.
    """

    def _gen_path(self, current: Node, came_from: Dict[Node, Node]) -> List[Node]:
        """Rebuild the path that ends at ``current``.

        Args:
            current: Last node of the path.
            came_from: Maps each reached node to the node it was reached from.
                The start node has no entry, which terminates the walk.

        Returns:
            The nodes from the start node to ``current``, inclusive.
        """
        path = [current]

        # Walk the predecessor links backwards, then flip once at the end;
        # prepending at every step would copy the list each time.
        while current in came_from:
            current = came_from[current]
            path.append(current)

        path.reverse()
        return path

    def a_star(
        self,
        start: Node,
        end: Callable[[Node, Callable[[], List[Node]]], bool],
        neighbours: Callable[[Node], Iterator[Node]],
        distance: Callable[[Node, Node], int] = lambda a, b: 1,
        heuristic: Callable[[Node], int] = lambda a: 0,
    ) -> List[Node]:
        """Search for a path from ``start`` to a node accepted by ``end``.

        Args:
            start: Node the search begins at.
            end: Goal test. Called with the node about to be expanded and a
                zero-argument callable returning the path from ``start`` to
                that node; return ``True`` to stop the search there.
            neighbours: Yields the nodes reachable from a given node.
            distance: Cost of moving between two adjacent nodes. Defaults to
                a uniform cost of 1.
            heuristic: Estimated remaining cost from a node to the goal.
                Defaults to 0, which makes the search behave like Dijkstra's
                algorithm. For the returned path to be a cheapest one the
                estimate must never exceed the true remaining cost.

        Returns:
            The nodes from ``start`` to the accepted node, inclusive.

        Raises:
            RuntimeError: If every reachable node was expanded without
                ``end`` accepting one.
        """
        open_nodes = {start}
        came_from: Dict[Node, Node] = {}

        # g: cheapest known cost from start; f: g plus the heuristic estimate.
        g_scores = {start: 0}
        f_scores = {start: heuristic(start)}

        while open_nodes:
            # Expand the open node with the lowest estimated total cost.
            current = min(open_nodes, key=lambda item: f_scores.get(item, inf))

            if end(current, lambda: self._gen_path(current, came_from)):
                return self._gen_path(current, came_from)

            open_nodes.remove(current)

            for neighbour in neighbours(current):
                g_score = g_scores.get(current, inf) + distance(current, neighbour)

                # Only record strictly better routes; a node that was already
                # expanded is re-opened when a cheaper route to it turns up.
                if g_score < g_scores.get(neighbour, inf):
                    came_from[neighbour] = current
                    g_scores[neighbour] = g_score
                    f_scores[neighbour] = g_score + heuristic(neighbour)
                    open_nodes.add(neighbour)

        raise RuntimeError("No path found")