From 5a38ebf365bfa0718dcbd7ab013af5f2da4610f6 Mon Sep 17 00:00:00 2001
From: Tom van der Lee
Date: Mon, 17 Jan 2022 19:41:55 +0100
Subject: Added scaling support via redis

---
 ttun_server/proxy_queue.py | 163 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 163 insertions(+)
 create mode 100644 ttun_server/proxy_queue.py

diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py
new file mode 100644
index 0000000..07e16e0
--- /dev/null
+++ b/ttun_server/proxy_queue.py
@@ -0,0 +1,163 @@
+import asyncio
+import json
+import logging
+import os
+from typing import Awaitable, Callable
+
+from ttun_server.redis import RedisConnectionPool
+from ttun_server.types import RequestData, ResponseData, MemoryConnection
+
+logger = logging.getLogger(__name__)
+
+class BaseProxyQueue:
+    def __init__(self, identifier: str):
+        self.identifier = identifier
+
+    @classmethod
+    async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
+        raise NotImplementedError('Please implement create_for_identifier')
+
+    @classmethod
+    async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
+        assert await cls.has_connection(identifier)
+        return cls(identifier)
+
+    @classmethod
+    async def has_connection(cls, identifier) -> bool:
+        raise NotImplementedError('Please implement has_connection')
+
+    async def send_request(self, request_data: RequestData):
+        raise NotImplementedError('Please implement send_request')
+
+    async def handle_request(self) -> RequestData:
+        raise NotImplementedError('Please implement handle_request')
+
+    async def send_response(self, response_data: ResponseData):
+        raise NotImplementedError('Please implement send_response')
+
+    async def handle_response(self) -> ResponseData:
+        raise NotImplementedError('Please implement handle_response')
+
+    async def delete(self):
+        raise NotImplementedError('Please implement delete')
+
+class MemoryProxyQueue(BaseProxyQueue):
+    connections: dict[str, MemoryConnection] = {}
+
+    @classmethod
+    async def has_connection(cls, identifier) -> bool:
+        return identifier in cls.connections
+
+    @classmethod
+    async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue':
+        instance = cls(identifier)
+
+        cls.connections[identifier] = {
+            'requests': asyncio.Queue(),
+            'responses': asyncio.Queue(),
+        }
+
+        return instance
+
+    @property
+    def requests(self) -> asyncio.Queue[RequestData]:
+        return self.__class__.connections[self.identifier]['requests']
+
+    @property
+    def responses(self) -> asyncio.Queue[ResponseData]:
+        return self.__class__.connections[self.identifier]['responses']
+
+    async def send_request(self, request_data: RequestData):
+        await self.requests.put(request_data)
+
+    async def handle_request(self) -> RequestData:
+        return await self.requests.get()
+
+    async def send_response(self, response_data: ResponseData):
+        return await self.responses.put(response_data)
+
+    async def handle_response(self) -> ResponseData:
+        return await self.responses.get()
+
+    async def delete(self):
+        del self.__class__.connections[self.identifier]
+
+
+class RedisProxyQueue(BaseProxyQueue):
+    def __init__(self, identifier):
+        super().__init__(identifier)
+
+        self.pubsub = RedisConnectionPool()\
+            .get_connection()\
+            .pubsub()
+
+        self.subscription_queue = asyncio.Queue()
+
+    @classmethod
+    async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
+        instance = cls(identifier)
+
+        await instance.pubsub.subscribe(f'request_{identifier}')
+        return instance
+
+    @classmethod
+    async def get_for_identifier(cls, identifier: str) -> 'RedisProxyQueue':
+        instance: 'RedisProxyQueue' = await super().get_for_identifier(identifier)
+
+        await instance.pubsub.subscribe(f'response_{identifier}')
+
+        return instance
+
+    @classmethod
+    async def has_connection(cls, identifier) -> bool:
+        logger.debug(await RedisConnectionPool.get_connection().pubsub_channels())
+        return f'request_{identifier}' in {
+            channel.decode()
+            for channel
+            in await RedisConnectionPool \
+                .get_connection() \
+                .pubsub_channels()
+        }
+
+    async def wait_for_message(self):
+        async for message in self.pubsub.listen():
+            match message['type']:
+                case 'subscribe':
+                    continue
+                case _:
+                    return message['data']
+
+    async def send_request(self, request_data: RequestData):
+        await RedisConnectionPool \
+            .get_connection() \
+            .publish(f'request_{self.identifier}', json.dumps(request_data))
+
+    async def handle_request(self) -> RequestData:
+        message = await self.wait_for_message()
+        return json.loads(message)
+
+    async def send_response(self, response_data: ResponseData):
+        await RedisConnectionPool \
+            .get_connection() \
+            .publish(f'response_{self.identifier}', json.dumps(response_data))
+
+    async def handle_response(self) -> ResponseData:
+        message = await self.wait_for_message()
+        return json.loads(message)
+
+    async def delete(self):
+        await self.pubsub.unsubscribe(f'request_{self.identifier}')
+
+        await RedisConnectionPool.get_connection()\
+            .srem('connections', self.identifier)
+
+
+class ProxyQueueMeta(type):
+    def __new__(cls, name, superclasses, attributes):
+        return RedisProxyQueue \
+            if 'REDIS_URL' in os.environ \
+            else MemoryProxyQueue
+
+
+class ProxyQueue(BaseProxyQueue, metaclass=ProxyQueueMeta):
+    pass
--
cgit v1.2.3
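
Usage note (not part of the patch): a minimal sketch of how the ProxyQueue API introduced above might be exercised end to end. Whether ProxyQueue resolves to MemoryProxyQueue or RedisProxyQueue depends on the REDIS_URL environment variable, per ProxyQueueMeta. The identifier and the request/response payloads below are illustrative placeholders, not the real RequestData/ResponseData schema, and the single-task flow is written with the in-memory backend in mind.

    import asyncio

    from ttun_server.proxy_queue import ProxyQueue


    async def main():
        identifier = 'example-tunnel'  # placeholder tunnel identifier

        # Tunnel side: register a queue for this identifier
        # (subscribes to the request channel when the Redis backend is active).
        tunnel_side = await ProxyQueue.create_for_identifier(identifier)

        # HTTP side: look the queue up and forward an incoming request.
        assert await ProxyQueue.has_connection(identifier)
        http_side = await ProxyQueue.get_for_identifier(identifier)
        await http_side.send_request({'method': 'GET', 'path': '/'})  # placeholder payload

        # Tunnel side: consume the request and answer it.
        request = await tunnel_side.handle_request()
        await tunnel_side.send_response({'status': 204, 'body': ''})  # placeholder payload

        # HTTP side: receive the response destined for the original caller.
        response = await http_side.handle_response()
        print(request, response)

        # Clean up the bookkeeping for this identifier.
        await tunnel_side.delete()


    if __name__ == '__main__':
        asyncio.run(main())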