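from abc import ABC
from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, List, Sequence, Set, Tuple

from aoc import BaseAssignment

Coordinate = Tuple[int, int]
Field = List[List[int]]


@dataclass
class Node:
    """A cave; big caves (all upper-case ids) may be visited any number of times."""
    id: str
    big: bool = False
    # Neighbouring caves (connections are undirected).
    nodes: Set['Node'] = field(default_factory=set)

    # Identity is based on the id alone, so nodes can be stored in sets even
    # though `nodes` references other nodes. The explicit __eq__, __hash__ and
    # __repr__ take precedence over the ones @dataclass would generate; the
    # generated __repr__ would recurse forever through `nodes`.
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Node):
            return NotImplemented
        return self.id == other.id

    def __hash__(self) -> int:
        return hash(self.id)

    def __repr__(self) -> str:
        return f'Node(id={self.id})'

    def __lt__(self, other: 'Node') -> bool:
        return self.id < other.id


class Assignment(BaseAssignment, ABC):
    def parse_item(self, item: str) -> Tuple[str, str]:
        # Each input line has the form "a-b": an undirected edge between two caves.
        a, b = item.split('-')
        return a, b

    @classmethod
    def get_or_create_node(cls, nodes: Dict[str, Node], id: str) -> Node:
        # Caves whose name is entirely upper case are "big".
        if id not in nodes:
            nodes[id] = Node(id=id, big=id.upper() == id)
        return nodes[id]

    def read_input(self, example: bool = False) -> Dict[str, Node]:
        # Build the undirected cave graph, keyed by cave id.
        nodes: Dict[str, Node] = {}
        for a, b in super().read_input(example):
            node_a = self.get_or_create_node(nodes, a)
            node_b = self.get_or_create_node(nodes, b)
            node_a.nodes.add(node_b)
            node_b.nodes.add(node_a)
        return nodes

    @classmethod
    def calculate_all_paths(cls, start: Node, end: Node,
                            visited: Sequence[Node] = ()) -> List[List[Node]]:
        # Return every path from `start` to `end`; `visited` holds the caves
        # already on the current path. The immutable default avoids sharing
        # one mutable list between calls.
        raise NotImplementedError

    def run(self, input: Dict[str, Node]) -> int:
        return len(self.calculate_all_paths(input['start'], input['end']))


class AssignmentOne(Assignment):
    example_result = 19

    @classmethod
    def calculate_all_paths(cls, start: Node, end: Node,
                            visited: Sequence[Node] = ()) -> List[List[Node]]:
        if start == end:
            return [[end]]
        # The current cave is now part of the path.
        visited = (*visited, start)
        # Big caves may be revisited; small caves at most once.
        return [
            [start, *path]
            for node in start.nodes
            if node.big or node not in visited
            for path in cls.calculate_all_paths(node, end, visited=visited)
        ]


class AssignmentTwo(Assignment):
    example_result = 103

    @classmethod
    def can_be_visited(cls, node: Node, visited: Sequence[Node]) -> bool:
        # Big caves can always be (re)visited.
        if node.big:
            return True
        visited_counter = Counter(n for n in visited if not n.big)
        # 'start' and 'end' may appear at most once on any path.
        if node.id in ('start', 'end'):
            return visited_counter[node] < 1
        # One small cave may be visited twice; once any small cave has been
        # visited twice, every other small cave is limited to a single visit.
        small_node_visited_twice = any(v >= 2 for v in visited_counter.values())
        limit = 1 if small_node_visited_twice else 2
        return visited_counter[node] < limit

    @classmethod
    def calculate_all_paths(cls, start: Node, end: Node,
                            visited: Sequence[Node] = ()) -> List[List[Node]]:
        if start == end:
            return [[end]]
        # The current cave is now part of the path.
        visited = (*visited, start)
        return [
            [start, *path]
            for node in start.nodes
            if cls.can_be_visited(node, visited)
            for path in cls.calculate_all_paths(node, end, visited=visited)
        ]

    def run(self, input: Dict[str, Node]) -> int:
        paths = sorted(self.calculate_all_paths(input['start'], input['end']))
        # Debug output: print each path as a comma-separated list of cave ids.
        print()
        for path in paths:
            print(','.join(n.id for n in path))
        return len(paths)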