summaryrefslogtreecommitdiffstats
path: root/day15/__init__.py
blob: 38126fcf861ea34e48d9fc8d997a5051e52756a3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import heapq
import math
from abc import ABC
from dataclasses import dataclass
from typing import List, Tuple, Iterator, TypedDict, Dict

from aoc import BaseAssignment
from aoc.utils import bold

# An (x, y) grid position; the grid itself is indexed as field[y][x].
Coordinate = Tuple[int, int]
# Row-major grid of integer cell values: Map[y][x] is the value at column x of row y.
Map = List[List[int]]

@dataclass
class Score:
    """A* bookkeeping for a single grid coordinate."""

    g_score: float  # cost of the cheapest path found so far from the start to this node
    h_score: float  # heuristic estimate of the remaining cost from this node to the goal

    @property
    def f_score(self) -> float:
        """Estimated total cost of a path through this node (g + h)."""
        total = self.g_score
        total += self.h_score
        return total

class Assignment(BaseAssignment, ABC):
    """Shared day-15 solver: cheapest path through a grid of digit costs.

    Each input item is parsed into a row of single-digit ints. The cost of a
    path is the sum of the values of every cell entered (the start cell is not
    counted). The path is found with A*, using straight-line distance as the
    heuristic.
    """

    def parse_item(self, item: str) -> List[int]:
        """Turn one input item (a string of digits) into a row of ints."""
        return [int(i) for i in item]

    def read_input(self, example = False) -> Map:
        """Materialise the parsed rows from the base class into a grid."""
        return list(super().read_input(example))

    @classmethod
    def neighbours(cls, field: Map, x: int, y: int) -> Iterator[Coordinate]:
        """Yield the orthogonal neighbours of (x, y) that lie inside *field*.

        Neighbours are yielded in the order up, down, left, right; diagonal
        cells are never yielded.
        """
        height = len(field)
        width = len(field[0]) if height else 0
        for nx, ny in ((x, y - 1), (x, y + 1), (x - 1, y), (x + 1, y)):
            if 0 <= nx < width and 0 <= ny < height:
                yield (nx, ny)

    @classmethod
    def distance(cls, start: Coordinate, end: Coordinate) -> float:
        """Return the Euclidean distance between two coordinates.

        Used as the A* heuristic. It never exceeds the number of orthogonal
        steps between the points, so it does not overestimate the remaining
        cost as long as every cell costs at least 1.
        """
        start_x, start_y = start
        end_x, end_y = end

        dx = end_x - start_x
        dy = end_y - start_y

        return math.sqrt((dx ** 2) + (dy ** 2))

    @classmethod
    def gen_path(cls, current: Coordinate, came_from: Dict[Coordinate, Coordinate]) -> List[Coordinate]:
        """Rebuild the path ending at *current* by following *came_from* links.

        Returns the coordinates ordered from the first node (the one with no
        predecessor) to *current*, both inclusive.
        """
        path = [current]
        while current in came_from:
            current = came_from[current]
            path.append(current)
        return list(reversed(path))

    @classmethod
    def a_star(cls, field: Map, start: Coordinate, end: Coordinate) -> List[Coordinate]:
        """Find a cheapest path from *start* to *end* through *field* with A*.

        Entering a cell costs that cell's value. Returns the path as a list of
        coordinates from *start* to *end* inclusive.

        Raises:
            Exception: ``'No Path'`` if *end* cannot be reached.
        """
        came_from: Dict[Coordinate, Coordinate] = {}
        scores: Dict[Coordinate, Score] = {
            start: Score(
                g_score=0,
                h_score=cls.distance(start, end),
            )
        }

        # Min-heap of (f_score, g_score, coordinate). A node is pushed again
        # whenever a cheaper route to it is found; outdated entries are skipped
        # on pop instead of being removed (lazy deletion).
        frontier: List[Tuple[float, float, Coordinate]] = [
            (scores[start].f_score, scores[start].g_score, start)
        ]

        while frontier:
            _, current_g, current = heapq.heappop(frontier)
            if current_g > scores[current].g_score:
                # Stale entry: a cheaper route to `current` was pushed later.
                continue
            if current == end:
                return cls.gen_path(current, came_from)

            for neighbour in cls.neighbours(field, current[0], current[1]):
                nx, ny = neighbour
                g_score = current_g + field[ny][nx]
                neighbour_scores = scores.get(neighbour)

                if neighbour_scores is None or g_score < neighbour_scores.g_score:
                    came_from[neighbour] = current
                    neighbour_scores = Score(
                        g_score=g_score,
                        h_score=cls.distance(neighbour, end),
                    )
                    scores[neighbour] = neighbour_scores
                    heapq.heappush(frontier, (neighbour_scores.f_score, g_score, neighbour))

        raise Exception('No Path')

    @classmethod
    def print_field(cls, field: Map, path: List[Coordinate]):
        """Print the grid, passing each cell through ``bold`` with an on-path predicate."""
        # Set lookup keeps the per-cell membership test O(1) instead of O(len(path)).
        on_path = set(path)
        print('', flush=False)
        for y, row in enumerate(field):
            # NOTE(review): `bold` receives a zero-arg predicate closing over the
            # loop variables x/y; this is only correct if `bold` calls it
            # immediately rather than deferring the call — confirm in aoc.utils.
            print(''.join([bold(str(i), lambda: (x, y) in on_path) for x, i in enumerate(row)]), flush=False)
        print('', flush=True)

    def run(self, input: Map) -> int:
        """Return the total cost of the cheapest top-left to bottom-right path."""
        start = (0, 0)
        end = (len(input[0]) - 1, len(input) - 1)

        path = self.a_star(input, start, end)

        self.print_field(input, path)

        # path[0] is the start cell, whose value is never counted.
        return sum(input[y][x] for x, y in path[1:])

class AssignmentOne(Assignment):
    """Part one: cheapest path through the grid exactly as read from the input."""

    example_result = 40

class AssignmentTwo(Assignment):
    """Part two: the input grid is tiled (5x5 by default) before path-finding.

    Each tile is a copy of the original grid with every value increased by the
    tile's column index plus row index. Values above 9 wrap back into 1..9.
    """

    example_result = 315

    @classmethod
    def overflow(cls, item):
        """Wrap a cell value above 9 back into the range 1..9; smaller values pass through.

        A modular wrap is used rather than a single ``- 9`` so that values of 18
        and above (larger tile factors) also land in 1..9.
        """
        return (item - 1) % 9 + 1 if item > 9 else item

    @classmethod
    def increase_map(cls, map, factor: int = 5) -> Map:
        """Tile *map* ``factor`` x ``factor`` times, incrementing each tile's values.

        Args:
            map: the original row-major grid (indexed ``map[y][x]``).
            factor: tiles per axis; defaults to 5.

        Returns:
            A new grid ``factor`` times wider and taller. The tile at tile-column
            dx and tile-row dy holds ``overflow(value + dx + dy)``.
        """
        return [
            [
                cls.overflow(item + dx + dy)
                for dx in range(factor)
                for item in row
            ]
            for dy in range(factor)
            for row in map
        ]

    def read_input(self, example = False) -> Map:
        """Read the base grid and expand it into the tiled part-two grid."""
        return self.increase_map(super().read_input(example))