summaryrefslogtreecommitdiffstats
path: root/ttun_server/proxy_queue.py
blob: e521886f344def078520d7233344aed2cdc16a31 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import asyncio
import json
import logging
import os
from typing import Type

from ttun_server.redis import RedisConnectionPool
from ttun_server.types import Message

logger = logging.getLogger(__name__)


class BaseProxyQueue:
    """Interface for a per-identifier message queue used by the proxy.

    Concrete subclasses supply the storage/transport; this base class only
    stores the identifier and defines the async API they must implement.
    """

    def __init__(self, identifier: str):
        # Key selecting which queue/channel this instance operates on.
        self.identifier = identifier

    @classmethod
    async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
        """Create the backing queue for *identifier* and return an instance bound to it.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError('Please implement create_for_identifier')

    @classmethod
    async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
        """Return an instance bound to an already-existing queue for *identifier*.

        Raises:
            AssertionError: if ``has_connection(identifier)`` is false. Raised
                explicitly rather than via ``assert`` so the check still runs
                when Python is started with ``-O``.
        """
        if not await cls.has_connection(identifier):
            raise AssertionError(f'No connection for identifier {identifier!r}')
        return cls(identifier)

    @classmethod
    async def has_connection(cls, identifier) -> bool:
        """Return True if a queue currently exists for *identifier*.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError('Please implement has_connection')

    async def enqueue(self, message: Message):
        """Add *message* to this identifier's queue.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError('Please implement enqueue')

    async def dequeue(self) -> Message:
        """Wait for and return the next message from this identifier's queue.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError('Please implement dequeue')

    async def delete(self):
        """Tear down this identifier's queue.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError('Please implement delete')


class MemoryProxyQueue(BaseProxyQueue):
    """In-process proxy queue backed by one ``asyncio.Queue`` per identifier.

    Queues are stored in a class-level dict, so every instance that uses the
    same identifier (within this process) shares a single queue.
    """

    # Class-level registry shared by all instances: identifier -> queue.
    connections: dict[str, asyncio.Queue] = {}

    @classmethod
    async def has_connection(cls, identifier) -> bool:
        """Return True if a queue is registered for *identifier*."""
        return identifier in cls.connections

    @classmethod
    async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue':
        """Register a fresh queue for *identifier* and return an instance bound to it.

        Any queue already registered under the same identifier is replaced.
        """
        instance = cls(identifier)
        cls.connections[identifier] = asyncio.Queue()

        return instance

    async def enqueue(self, message: Message):
        """Put *message* on this identifier's queue.

        Raises:
            KeyError: if no queue is registered for the identifier.
        """
        return await type(self).connections[self.identifier].put(message)

    async def dequeue(self) -> Message:
        """Wait for and return the next message on this identifier's queue.

        Raises:
            KeyError: if no queue is registered for the identifier.
        """
        return await type(self).connections[self.identifier].get()

    async def delete(self):
        """Unregister this identifier's queue.

        Idempotent: deleting a queue that is already gone is a no-op instead
        of raising KeyError, so cleanup paths may call it more than once.
        """
        type(self).connections.pop(self.identifier, None)


class RedisProxyQueue(BaseProxyQueue):
    def __init__(self, identifier):
        super().__init__(identifier)

        self.pubsub = RedisConnectionPool()\
            .get_connection()\
            .pubsub()

        self.subscription_queue = asyncio.Queue()

    @classmethod
    async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
        instance = cls(identifier)

        await instance.pubsub.subscribe(f'request_{identifier}')
        return instance

    @classmethod
    async def get_for_identifier(cls, identifier: str) -> 'RedisProxyQueue':
        instance: 'RedisProxyQueue' = await super().get_for_identifier(identifier)

        await instance.pubsub.subscribe(f'response_{identifier}')

        return instance

    @classmethod
    async def has_connection(cls, identifier) -> bool:
        logger.debug(await RedisConnectionPool.get_connection().pubsub_channels())
        return f'request_{identifier}' in {
            channel.decode()
            for channel
            in await RedisConnectionPool \
                .get_connection() \
                .pubsub_channels()
        }

    async def wait_for_message(self):
        async for message in self.pubsub.listen():
            match message['type']:
                case 'subscribe':
                    continue
                case _:
                    return message['data']

    async def enqueue(self, message: Message):
        await RedisConnectionPool \
            .get_connection() \
            .publish(self.identifier, json.dumps(message))

    async def dequeue(self) -> Message:
        message = await self.wait_for_message()
        return json.loads(message)

    async def delete(self):
        await self.pubsub.unsubscribe(f'request_{self.identifier}')

        await RedisConnectionPool.get_connection()\
            .srem('connections', self.identifier)


class ProxyQueueMeta(type):
    """Metaclass that swaps a class definition for a concrete queue backend.

    Rather than building a new class, ``__new__`` hands back an existing one:
    ``RedisProxyQueue`` when ``REDIS_URL`` is set in the environment,
    otherwise ``MemoryProxyQueue``. The decision is made once, when the class
    using this metaclass is defined.
    """

    def __new__(cls, name, superclasses, attributes):
        # The requested name, bases and namespace are intentionally ignored.
        redis_configured = 'REDIS_URL' in os.environ
        if redis_configured:
            return RedisProxyQueue
        return MemoryProxyQueue


class ProxyQueue(BaseProxyQueue, metaclass=ProxyQueueMeta):
    """Backend-selected proxy queue.

    ``ProxyQueueMeta.__new__`` returns an existing class, so this name ends
    up bound to ``RedisProxyQueue`` or ``MemoryProxyQueue`` depending on
    whether ``REDIS_URL`` is set when this module is imported. The body
    here is never used.
    """