From 7ac28203290a211a6e17ae0b91bc2b609f110514 Mon Sep 17 00:00:00 2001 From: Tom van der Lee Date: Thu, 8 Jun 2023 08:24:43 +0200 Subject: WIP --- ttun_server/proxy_queue.py | 63 ++++++++++++---------------------------------- 1 file changed, 16 insertions(+), 47 deletions(-) (limited to 'ttun_server/proxy_queue.py') diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py index 07e16e0..e521886 100644 --- a/ttun_server/proxy_queue.py +++ b/ttun_server/proxy_queue.py @@ -2,13 +2,14 @@ import asyncio import json import logging import os -from typing import Awaitable, Callable +from typing import Type from ttun_server.redis import RedisConnectionPool -from ttun_server.types import RequestData, ResponseData, MemoryConnection +from ttun_server.types import Message logger = logging.getLogger(__name__) + class BaseProxyQueue: def __init__(self, identifier: str): self.identifier = identifier @@ -18,7 +19,7 @@ class BaseProxyQueue: raise NotImplementedError(f'Please implement create_for_identifier') @classmethod - async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': + async def get_for_identifier(cls, identifier: str) -> Type['self']: assert await cls.has_connection(identifier) return cls(identifier) @@ -26,23 +27,18 @@ class BaseProxyQueue: async def has_connection(cls, identifier) -> bool: raise NotImplementedError(f'Please implement has_connection') - async def send_request(self, request_data: RequestData): + async def enqueue(self, message: Message): raise NotImplementedError(f'Please implement send_request') - async def handle_request(self) -> RequestData: + async def dequeue(self) -> Message: raise NotImplementedError(f'Please implement handle_requests') - async def send_response(self, response_data: ResponseData): - raise NotImplementedError(f'Please implement send_request') - - async def handle_response(self) -> ResponseData: - raise NotImplementedError(f'Please implement handle_response') - async def delete(self): raise NotImplementedError(f'Please implement delete') + class MemoryProxyQueue(BaseProxyQueue): - connections: dict[str, MemoryConnection] = {} + connections: dict[str, asyncio.Queue] = {} @classmethod async def has_connection(cls, identifier) -> bool: @@ -51,33 +47,15 @@ class MemoryProxyQueue(BaseProxyQueue): @classmethod async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue': instance = cls(identifier) - - cls.connections[identifier] = { - 'requests': asyncio.Queue(), - 'responses': asyncio.Queue(), - } + cls.connections[identifier] = asyncio.Queue() return instance - @property - def requests(self) -> asyncio.Queue[RequestData]: - return self.__class__.connections[self.identifier]['requests'] - - @property - def responses(self) -> asyncio.Queue[ResponseData]: - return self.__class__.connections[self.identifier]['responses'] - - async def send_request(self, request_data: RequestData): - await self.requests.put(request_data) + async def enqueue(self, message: Message): + return await self.__class__.connections[self.identifier].put(message) - async def handle_request(self) -> RequestData: - return await self.requests.get() - - async def send_response(self, response_data: ResponseData): - return await self.responses.put(response_data) - - async def handle_response(self) -> ResponseData: - return await self.responses.get() + async def dequeue(self) -> Message: + return await self.__class__.connections[self.identifier].get() async def delete(self): del self.__class__.connections[self.identifier] @@ -127,21 +105,12 @@ class RedisProxyQueue(BaseProxyQueue): case _: return message['data'] - async def send_request(self, request_data: RequestData): - await RedisConnectionPool \ - .get_connection() \ - .publish(f'request_{self.identifier}', json.dumps(request_data)) - - async def handle_request(self) -> RequestData: - message = await self.wait_for_message() - return json.loads(message) - - async def send_response(self, response_data: ResponseData): + async def enqueue(self, message: Message): await RedisConnectionPool \ .get_connection() \ - .publish(f'response_{self.identifier}', json.dumps(response_data)) + .publish(self.identifier, json.dumps(message)) - async def handle_response(self) -> ResponseData: + async def dequeue(self) -> Message: message = await self.wait_for_message() return json.loads(message) -- cgit v1.2.3