diff options
Diffstat (limited to 'ttun_server/proxy_queue.py')
| -rw-r--r-- | ttun_server/proxy_queue.py | 163 |
1 files changed, 163 insertions, 0 deletions
diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py new file mode 100644 index 0000000..07e16e0 --- /dev/null +++ b/ttun_server/proxy_queue.py | |||
| @@ -0,0 +1,163 @@ | |||
| 1 | import asyncio | ||
| 2 | import json | ||
| 3 | import logging | ||
| 4 | import os | ||
| 5 | from typing import Awaitable, Callable | ||
| 6 | |||
| 7 | from ttun_server.redis import RedisConnectionPool | ||
| 8 | from ttun_server.types import RequestData, ResponseData, MemoryConnection | ||
| 9 | |||
| 10 | logger = logging.getLogger(__name__) | ||
| 11 | |||
| 12 | class BaseProxyQueue: | ||
| 13 | def __init__(self, identifier: str): | ||
| 14 | self.identifier = identifier | ||
| 15 | |||
| 16 | @classmethod | ||
| 17 | async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': | ||
| 18 | raise NotImplementedError(f'Please implement create_for_identifier') | ||
| 19 | |||
| 20 | @classmethod | ||
| 21 | async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': | ||
| 22 | assert await cls.has_connection(identifier) | ||
| 23 | return cls(identifier) | ||
| 24 | |||
| 25 | @classmethod | ||
| 26 | async def has_connection(cls, identifier) -> bool: | ||
| 27 | raise NotImplementedError(f'Please implement has_connection') | ||
| 28 | |||
| 29 | async def send_request(self, request_data: RequestData): | ||
| 30 | raise NotImplementedError(f'Please implement send_request') | ||
| 31 | |||
| 32 | async def handle_request(self) -> RequestData: | ||
| 33 | raise NotImplementedError(f'Please implement handle_requests') | ||
| 34 | |||
| 35 | async def send_response(self, response_data: ResponseData): | ||
| 36 | raise NotImplementedError(f'Please implement send_request') | ||
| 37 | |||
| 38 | async def handle_response(self) -> ResponseData: | ||
| 39 | raise NotImplementedError(f'Please implement handle_response') | ||
| 40 | |||
| 41 | async def delete(self): | ||
| 42 | raise NotImplementedError(f'Please implement delete') | ||
| 43 | |||
| 44 | class MemoryProxyQueue(BaseProxyQueue): | ||
| 45 | connections: dict[str, MemoryConnection] = {} | ||
| 46 | |||
| 47 | @classmethod | ||
| 48 | async def has_connection(cls, identifier) -> bool: | ||
| 49 | return identifier in cls.connections | ||
| 50 | |||
| 51 | @classmethod | ||
| 52 | async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue': | ||
| 53 | instance = cls(identifier) | ||
| 54 | |||
| 55 | cls.connections[identifier] = { | ||
| 56 | 'requests': asyncio.Queue(), | ||
| 57 | 'responses': asyncio.Queue(), | ||
| 58 | } | ||
| 59 | |||
| 60 | return instance | ||
| 61 | |||
| 62 | @property | ||
| 63 | def requests(self) -> asyncio.Queue[RequestData]: | ||
| 64 | return self.__class__.connections[self.identifier]['requests'] | ||
| 65 | |||
| 66 | @property | ||
| 67 | def responses(self) -> asyncio.Queue[ResponseData]: | ||
| 68 | return self.__class__.connections[self.identifier]['responses'] | ||
| 69 | |||
| 70 | async def send_request(self, request_data: RequestData): | ||
| 71 | await self.requests.put(request_data) | ||
| 72 | |||
| 73 | async def handle_request(self) -> RequestData: | ||
| 74 | return await self.requests.get() | ||
| 75 | |||
| 76 | async def send_response(self, response_data: ResponseData): | ||
| 77 | return await self.responses.put(response_data) | ||
| 78 | |||
| 79 | async def handle_response(self) -> ResponseData: | ||
| 80 | return await self.responses.get() | ||
| 81 | |||
| 82 | async def delete(self): | ||
| 83 | del self.__class__.connections[self.identifier] | ||
| 84 | |||
| 85 | |||
| 86 | class RedisProxyQueue(BaseProxyQueue): | ||
| 87 | def __init__(self, identifier): | ||
| 88 | super().__init__(identifier) | ||
| 89 | |||
| 90 | self.pubsub = RedisConnectionPool()\ | ||
| 91 | .get_connection()\ | ||
| 92 | .pubsub() | ||
| 93 | |||
| 94 | self.subscription_queue = asyncio.Queue() | ||
| 95 | |||
| 96 | @classmethod | ||
| 97 | async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': | ||
| 98 | instance = cls(identifier) | ||
| 99 | |||
| 100 | await instance.pubsub.subscribe(f'request_{identifier}') | ||
| 101 | return instance | ||
| 102 | |||
| 103 | @classmethod | ||
| 104 | async def get_for_identifier(cls, identifier: str) -> 'RedisProxyQueue': | ||
| 105 | instance: 'RedisProxyQueue' = await super().get_for_identifier(identifier) | ||
| 106 | |||
| 107 | await instance.pubsub.subscribe(f'response_{identifier}') | ||
| 108 | |||
| 109 | return instance | ||
| 110 | |||
| 111 | @classmethod | ||
| 112 | async def has_connection(cls, identifier) -> bool: | ||
| 113 | logger.debug(await RedisConnectionPool.get_connection().pubsub_channels()) | ||
| 114 | return f'request_{identifier}' in { | ||
| 115 | channel.decode() | ||
| 116 | for channel | ||
| 117 | in await RedisConnectionPool \ | ||
| 118 | .get_connection() \ | ||
| 119 | .pubsub_channels() | ||
| 120 | } | ||
| 121 | |||
| 122 | async def wait_for_message(self): | ||
| 123 | async for message in self.pubsub.listen(): | ||
| 124 | match message['type']: | ||
| 125 | case 'subscribe': | ||
| 126 | continue | ||
| 127 | case _: | ||
| 128 | return message['data'] | ||
| 129 | |||
| 130 | async def send_request(self, request_data: RequestData): | ||
| 131 | await RedisConnectionPool \ | ||
| 132 | .get_connection() \ | ||
| 133 | .publish(f'request_{self.identifier}', json.dumps(request_data)) | ||
| 134 | |||
| 135 | async def handle_request(self) -> RequestData: | ||
| 136 | message = await self.wait_for_message() | ||
| 137 | return json.loads(message) | ||
| 138 | |||
| 139 | async def send_response(self, response_data: ResponseData): | ||
| 140 | await RedisConnectionPool \ | ||
| 141 | .get_connection() \ | ||
| 142 | .publish(f'response_{self.identifier}', json.dumps(response_data)) | ||
| 143 | |||
| 144 | async def handle_response(self) -> ResponseData: | ||
| 145 | message = await self.wait_for_message() | ||
| 146 | return json.loads(message) | ||
| 147 | |||
| 148 | async def delete(self): | ||
| 149 | await self.pubsub.unsubscribe(f'request_{self.identifier}') | ||
| 150 | |||
| 151 | await RedisConnectionPool.get_connection()\ | ||
| 152 | .srem('connections', self.identifier) | ||
| 153 | |||
| 154 | |||
| 155 | class ProxyQueueMeta(type): | ||
| 156 | def __new__(cls, name, superclasses, attributes): | ||
| 157 | return RedisProxyQueue \ | ||
| 158 | if 'REDIS_URL' in os.environ \ | ||
| 159 | else MemoryProxyQueue | ||
| 160 | |||
| 161 | |||
| 162 | class ProxyQueue(BaseProxyQueue, metaclass=ProxyQueueMeta): | ||
| 163 | pass | ||
