summaryrefslogtreecommitdiffstats
path: root/aoc/mixins.py
blob: 5986b6e695b93ad96c9b3308347583c10291e9af (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# -*- coding: utf-8 -*-
import heapq
import itertools
from collections import deque
from math import floor, inf
from typing import Tuple, TypeVar, Generic, Callable, Iterator, List, Dict

# Graph vertex type used by the search mixins; values must be hashable because
# they are stored in sets and used as dict keys.
Node = TypeVar("Node")


class AStarMixin(Generic[Node]):
    """Mixin providing a generic A* path search over an implicit graph."""

    def _gen_path(self, current: Node, came_from: Dict[Node, Node]) -> List[Node]:
        """Rebuild the path that ends at ``current``.

        Follows the ``came_from`` back-pointers from ``current`` until a node
        without a predecessor (the search start) is reached.

        Args:
            current: last node of the path.
            came_from: maps each discovered node to the node it was reached from.

        Returns:
            The nodes from the start up to and including ``current``.
        """
        path = [current]

        while current in came_from:
            current = came_from[current]
            path.append(current)

        # Collected end-to-start; appending and reversing once is linear,
        # whereas prepending on every step would be quadratic.
        path.reverse()
        return path

    def a_star(
        self,
        start: Node,
        end: Callable[[Node, Callable[[], List[Node]]], bool],
        neighbours: Callable[[Node], Iterator[Node]],
        distance: Callable[[Node, Node], int] = lambda a, b: 1,
        heuristic: Callable[[Node], int] = lambda a: 0,
    ) -> List[Node]:
        """Find a path from ``start`` to a node accepted by ``end`` using A*.

        Args:
            start: node the search begins from.
            end: goal predicate. Called with the node about to be expanded and
                a zero-argument callable that builds the path to that node
                (the path is only built if the predicate calls it).
            neighbours: returns the nodes adjacent to a given node.
            distance: cost of moving between two adjacent nodes; 1 by default.
            heuristic: estimated remaining cost from a node; 0 by default, in
                which case the search reduces to Dijkstra's algorithm
                (assuming non-negative distances).

        Returns:
            The nodes from ``start`` to the first node accepted by ``end``.

        Raises:
            RuntimeError: if all reachable nodes are exhausted without ``end``
                accepting one.
        """
        came_from: Dict[Node, Node] = {}
        g_scores = {start: 0}
        f_scores = {start: heuristic(start)}

        # Nodes awaiting expansion. The heap orders them by f-score, and the
        # counter breaks ties so Node values never have to be comparable.
        # When a node's score improves, a new entry is pushed and the old one
        # is left behind. Stale entries are skipped when they are popped.
        open_nodes = {start}
        tie_breaker = itertools.count()
        open_heap = [(f_scores[start], next(tie_breaker), start)]

        while open_heap:
            f_score, _, current = heapq.heappop(open_heap)

            if current not in open_nodes or f_score > f_scores[current]:
                continue  # stale entry: already expanded or since improved

            # ``current`` is bound as a default so the thunk stays tied to this
            # node even if ``end`` keeps it and calls it later.
            if end(current, lambda node=current: self._gen_path(node, came_from)):
                return self._gen_path(current, came_from)

            open_nodes.remove(current)

            for neighbour in neighbours(current):
                g_score = g_scores[current] + distance(current, neighbour)

                if g_score < g_scores.get(neighbour, inf):
                    came_from[neighbour] = current
                    g_scores[neighbour] = g_score
                    f_scores[neighbour] = g_score + heuristic(neighbour)
                    # Re-opens the node if it was already expanded, so an
                    # inconsistent heuristic can still correct earlier paths.
                    open_nodes.add(neighbour)
                    heapq.heappush(
                        open_heap,
                        (f_scores[neighbour], next(tie_breaker), neighbour),
                    )

        raise RuntimeError("No path found")


class BreathFirstSearchMixin(Generic[Node]):
    """Mixin providing a lazy breadth-first traversal.

    NOTE: "Breath" is a misspelling of "Breadth"; the name is kept unchanged so
    existing imports and subclasses keep working.
    """

    @staticmethod
    def bfs(
        start: Node, neighbours: Callable[[Node], Iterator[Node]]
    ) -> Iterator[Node]:
        """Yield every node reachable from ``start`` in breadth-first order.

        Each reachable node is yielded exactly once, ``start`` first. A node's
        neighbours are requested before the node itself is yielded.

        Args:
            start: node to begin the traversal from.
            neighbours: returns the nodes adjacent to a given node. Nodes must
                be hashable because they are tracked in a set.

        Returns:
            A lazy iterator over the visited nodes.
        """
        # deque gives O(1) pops from the left, unlike list.pop(0).
        queue = deque([start])
        # Every node that has ever been enqueued. Marking nodes when they are
        # enqueued, not when they are visited, means no node is queued twice,
        # including a node that lists itself as its own neighbour.
        seen = {start}

        while queue:
            item = queue.popleft()

            for neighbour in neighbours(item):
                if neighbour not in seen:
                    seen.add(neighbour)
                    queue.append(neighbour)

            yield item