summaryrefslogtreecommitdiffstats
path: root/ttun_server/proxy_queue.py
blob: 07e16e05b748de2abdbb7cbaa84f1aa004063700 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import asyncio
import json
import logging
import os
from typing import Awaitable, Callable

from ttun_server.redis import RedisConnectionPool
from ttun_server.types import RequestData, ResponseData, MemoryConnection

logger = logging.getLogger(__name__)

class BaseProxyQueue:
    """Interface for the request/response channel of a single tunnel.

    A queue is keyed by ``identifier``. Judging by the concrete backends in this
    module, one end (created via ``create_for_identifier``) consumes requests with
    ``handle_request`` and replies with ``send_response``, while the other end
    (obtained via ``get_for_identifier``) uses ``send_request`` and
    ``handle_response``. Every method that raises ``NotImplementedError`` here
    must be overridden by a backend.
    """

    def __init__(self, identifier: str):
        # Key shared by both ends of the tunnel.
        self.identifier = identifier

    @classmethod
    async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
        """Register a new queue for ``identifier`` and return a handle to it."""
        raise NotImplementedError('Please implement create_for_identifier')

    @classmethod
    async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
        """Return a handle to an already-registered queue for ``identifier``.

        Raises:
            AssertionError: if ``has_connection(identifier)`` is false.
        """
        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``; the exception type is unchanged for callers.
        if not await cls.has_connection(identifier):
            raise AssertionError(f'No connection for identifier {identifier!r}')
        return cls(identifier)

    @classmethod
    async def has_connection(cls, identifier) -> bool:
        """Return True if a queue is currently registered for ``identifier``."""
        raise NotImplementedError('Please implement has_connection')

    async def send_request(self, request_data: RequestData):
        """Deliver ``request_data`` to the end that handles requests."""
        raise NotImplementedError('Please implement send_request')

    async def handle_request(self) -> RequestData:
        """Wait for and return the next incoming request."""
        raise NotImplementedError('Please implement handle_request')

    async def send_response(self, response_data: ResponseData):
        """Deliver ``response_data`` to the end that awaits responses."""
        raise NotImplementedError('Please implement send_response')

    async def handle_response(self) -> ResponseData:
        """Wait for and return the next incoming response."""
        raise NotImplementedError('Please implement handle_response')

    async def delete(self):
        """Tear down the queue for this identifier."""
        raise NotImplementedError('Please implement delete')

class MemoryProxyQueue(BaseProxyQueue):
    """In-process proxy queue backed by a pair of ``asyncio.Queue`` objects.

    All state lives in the class-level ``connections`` registry, so every
    instance built for the same identifier shares the same two queues. Because
    the registry is a plain class attribute, both ends must run in the same
    process.
    """

    # identifier -> {'requests': Queue, 'responses': Queue}. Intentionally a
    # shared, class-level dict: it is the registry all instances look up.
    connections: dict[str, MemoryConnection] = {}

    @classmethod
    async def has_connection(cls, identifier) -> bool:
        """Return True if queues are registered for ``identifier``."""
        return identifier in cls.connections

    @classmethod
    async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue':
        """Register fresh request/response queues for ``identifier``.

        NOTE(review): an existing entry for the same identifier is replaced,
        and anything still awaiting the old queues will not see new items.
        """
        instance = cls(identifier)

        cls.connections[identifier] = {
            'requests': asyncio.Queue(),
            'responses': asyncio.Queue(),
        }

        return instance

    @property
    def requests(self) -> asyncio.Queue[RequestData]:
        """Queue of pending requests for this identifier (KeyError if deleted)."""
        return self.__class__.connections[self.identifier]['requests']

    @property
    def responses(self) -> asyncio.Queue[ResponseData]:
        """Queue of pending responses for this identifier (KeyError if deleted)."""
        return self.__class__.connections[self.identifier]['responses']

    async def send_request(self, request_data: RequestData):
        """Enqueue ``request_data`` for the request-handling end."""
        await self.requests.put(request_data)

    async def handle_request(self) -> RequestData:
        """Wait for and return the next queued request."""
        return await self.requests.get()

    async def send_response(self, response_data: ResponseData):
        """Enqueue ``response_data`` for the response-awaiting end."""
        await self.responses.put(response_data)

    async def handle_response(self) -> ResponseData:
        """Wait for and return the next queued response."""
        return await self.responses.get()

    async def delete(self):
        """Remove this identifier's queues from the registry.

        Idempotent: a missing entry is ignored instead of raising ``KeyError``,
        so repeated cleanup does not fail.
        """
        self.__class__.connections.pop(self.identifier, None)


class RedisProxyQueue(BaseProxyQueue):
    def __init__(self, identifier):
        super().__init__(identifier)

        self.pubsub = RedisConnectionPool()\
            .get_connection()\
            .pubsub()

        self.subscription_queue = asyncio.Queue()

    @classmethod
    async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
        instance = cls(identifier)

        await instance.pubsub.subscribe(f'request_{identifier}')
        return instance

    @classmethod
    async def get_for_identifier(cls, identifier: str) -> 'RedisProxyQueue':
        instance: 'RedisProxyQueue' = await super().get_for_identifier(identifier)

        await instance.pubsub.subscribe(f'response_{identifier}')

        return instance

    @classmethod
    async def has_connection(cls, identifier) -> bool:
        logger.debug(await RedisConnectionPool.get_connection().pubsub_channels())
        return f'request_{identifier}' in {
            channel.decode()
            for channel
            in await RedisConnectionPool \
                .get_connection() \
                .pubsub_channels()
        }

    async def wait_for_message(self):
        async for message in self.pubsub.listen():
            match message['type']:
                case 'subscribe':
                    continue
                case _:
                    return message['data']

    async def send_request(self, request_data: RequestData):
        await RedisConnectionPool \
            .get_connection() \
            .publish(f'request_{self.identifier}', json.dumps(request_data))

    async def handle_request(self) -> RequestData:
        message = await self.wait_for_message()
        return json.loads(message)

    async def send_response(self, response_data: ResponseData):
        await RedisConnectionPool \
            .get_connection() \
            .publish(f'response_{self.identifier}', json.dumps(response_data))

    async def handle_response(self) -> ResponseData:
        message = await self.wait_for_message()
        return json.loads(message)

    async def delete(self):
        await self.pubsub.unsubscribe(f'request_{self.identifier}')

        await RedisConnectionPool.get_connection()\
            .srem('connections', self.identifier)


class ProxyQueueMeta(type):
    """Metaclass that replaces the class being defined with a queue backend.

    Instead of building a new class, ``__new__`` hands back an existing one:
    ``RedisProxyQueue`` when ``REDIS_URL`` is present in the environment,
    otherwise ``MemoryProxyQueue``. The choice is made once, when the ``class``
    statement using this metaclass executes.
    """

    def __new__(cls, name, superclasses, attributes):
        if 'REDIS_URL' not in os.environ:
            return MemoryProxyQueue
        return RedisProxyQueue


class ProxyQueue(BaseProxyQueue, metaclass=ProxyQueueMeta):
    """Backend-selected proxy queue; ProxyQueueMeta binds this name to a concrete class."""