summaryrefslogtreecommitdiffstats
path: root/day12/__init__.py
blob: 9d7a0364d73a212612c77847d8e63657773110a7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from abc import ABC
from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, List, Sequence, Set, Tuple

from aoc import BaseAssignment

# Grid type aliases. NOTE(review): neither alias is referenced anywhere in this
# module; they look like leftovers from a shared template — confirm nothing
# imports them before removing.
Coordinate = Tuple[int, int]
Field = List[List[int]]


@dataclass
class Node:
    """A cave (vertex) in the cave-system graph.

    Equality, hashing and ordering are based solely on ``id`` so nodes can be
    stored in sets/Counters and sorted, even though ``nodes`` makes the graph
    cyclic (and mutable).
    """

    id: str
    # True when the cave's id is all upper-case (set by Assignment.get_or_create_node).
    big: bool = False
    # Adjacent caves; the parser records each edge on both endpoints.
    nodes: Set['Node'] = field(default_factory=set)

    def __eq__(self, other: object) -> bool:
        # Defer to Python's fallback for foreign types instead of crashing on `.id`.
        if not isinstance(other, Node):
            return NotImplemented
        return self.id == other.id

    def __hash__(self) -> int:
        # Must agree with __eq__: identity is the id only.
        return hash(self.id)

    def __repr__(self) -> str:
        # Only the id: the generated dataclass repr would also print every
        # neighbour (and theirs) through the `nodes` field.
        return f'Node(id={self.id})'

    def __lt__(self, other: object) -> bool:
        # Lets lists of nodes (paths) be sorted; foreign types raise TypeError.
        if not isinstance(other, Node):
            return NotImplemented
        return self.id < other.id

class Assignment(BaseAssignment, ABC):
    """Shared parsing and driver logic for both parts of the day 12 puzzle.

    Each input line is an undirected edge ``a-b``. The lines are turned into a
    graph of :class:`Node` objects keyed by id. Subclasses implement
    :meth:`calculate_all_paths`, and :meth:`run` counts the paths it returns.
    """

    def parse_item(self, item: str) -> Tuple[str, str]:
        """Split one ``a-b`` edge line into its two node ids.

        Raises:
            ValueError: if the line does not contain exactly one ``-``.
        """
        parts = item.split('-')
        if len(parts) != 2:
            raise ValueError(f'expected an edge of the form "a-b", got {item!r}')
        return parts[0], parts[1]

    @classmethod
    def get_or_create_node(cls, nodes: Dict[str, Node], id) -> Node:
        """Return ``nodes[id]``, creating and registering the node on first use.

        A new node is marked ``big`` when its id is unchanged by ``upper()``
        (i.e. it is written in upper case).
        """
        if id not in nodes:
            nodes[id] = Node(id=id, big=id.upper() == id)
        return nodes[id]

    def read_input(self, example = False) -> Dict[str, Node]:
        """Build the cave graph and return it as a mapping of id -> Node.

        Every edge is added to both endpoints' ``nodes`` sets, so the graph is
        undirected. NOTE(review): assumes ``BaseAssignment.read_input`` yields
        each line already passed through :meth:`parse_item`, and that
        ``example`` selects the example input — confirm in ``aoc``.
        """
        nodes = {}
        for a, b in super().read_input(example):
            node_a = self.get_or_create_node(nodes, a)
            node_b = self.get_or_create_node(nodes, b)

            node_a.nodes.add(node_b)
            node_b.nodes.add(node_a)

        return nodes

    @classmethod
    def calculate_all_paths(cls, start: Node, end: Node, visited: Sequence[Node] = ()) -> List[List[Node]]:
        """Return every valid path from ``start`` to ``end`` as lists of nodes.

        ``visited`` holds the nodes already on the path before ``start``. It
        defaults to an immutable empty tuple rather than a shared list.
        Subclasses must override this.
        """
        raise NotImplementedError

    def run(self, input: Dict[str, Node]) -> int:
        """Count the paths from the ``'start'`` node to the ``'end'`` node.

        Raises:
            KeyError: if the graph has no ``'start'`` or ``'end'`` node.
        """
        return len(self.calculate_all_paths(input['start'], input['end']))


class AssignmentOne(Assignment):
    """Part one: small caves may appear at most once on a path."""

    example_result = 19

    @classmethod
    def calculate_all_paths(cls, start: Node, end: Node, visited: Sequence[Node] = ()) -> List[List[Node]]:
        """Return every path from ``start`` to ``end``.

        A neighbour is entered when it is big or not yet in ``visited``.
        Recursion stops as soon as ``end`` is reached.

        Args:
            start: the node the partial path currently stands on.
            end: the target node.
            visited: nodes on the path before ``start``; never mutated.

        Returns:
            One list per path, each running from ``start`` to ``end`` inclusive.
        """
        if start == end:
            return [[end]]

        # The path prefix handed to every recursive call; built once, not per neighbour.
        next_visited = [*visited, start]
        return [
            [start, *path]
            for node in start.nodes
            if node.big or node not in visited
            for path in cls.calculate_all_paths(node, end, visited=next_visited)
        ]


class AssignmentTwo(Assignment):
    """Part two: a single small cave may be visited twice on a path.

    ``start`` and ``end`` may still appear only once. The inherited
    :meth:`Assignment.run` counts the paths (sorting them first, as the old
    override did, is unnecessary for counting).
    """

    example_result = 103

    @classmethod
    def can_be_visited(cls, node: Node, visited: List[Node]) -> bool:
        """Return whether ``node`` may be entered next.

        Args:
            node: candidate next node.
            visited: every node on the path so far, including the current one.
        """
        if node.big:
            return True

        small_visits = Counter(seen for seen in visited if not seen.big)

        if node.id in ('start', 'end'):
            return small_visits[node] < 1

        # Once some small cave has been used twice, all others are limited to one visit.
        small_node_visited_twice = any(count >= 2 for count in small_visits.values())
        limit = 1 if small_node_visited_twice else 2
        return small_visits[node] < limit

    @classmethod
    def calculate_all_paths(cls, start: Node, end: Node, visited: Sequence[Node] = ()) -> List[List[Node]]:
        """Return every path from ``start`` to ``end`` allowed by :meth:`can_be_visited`.

        Args:
            start: the node the partial path currently stands on.
            end: the target node.
            visited: nodes on the path before ``start``; never mutated.
        """
        if start == end:
            return [[end]]

        # The path including the current node; shared by the check and the recursion.
        next_visited = [*visited, start]
        return [
            [start, *path]
            for node in start.nodes
            if cls.can_be_visited(node, next_visited)
            for path in cls.calculate_all_paths(node, end, visited=next_visited)
        ]