summaryrefslogtreecommitdiffstats
path: root/ttun_server/proxy_queue.py
diff options
context:
space:
mode:
Diffstat (limited to 'ttun_server/proxy_queue.py')
-rw-r--r--ttun_server/proxy_queue.py63
1 files changed, 16 insertions, 47 deletions
diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py
index 07e16e0..e521886 100644
--- a/ttun_server/proxy_queue.py
+++ b/ttun_server/proxy_queue.py
@@ -2,13 +2,14 @@ import asyncio
2import json 2import json
3import logging 3import logging
4import os 4import os
5from typing import Awaitable, Callable 5from typing import Type
6 6
7from ttun_server.redis import RedisConnectionPool 7from ttun_server.redis import RedisConnectionPool
8from ttun_server.types import RequestData, ResponseData, MemoryConnection 8from ttun_server.types import Message
9 9
10logger = logging.getLogger(__name__) 10logger = logging.getLogger(__name__)
11 11
12
12class BaseProxyQueue: 13class BaseProxyQueue:
13 def __init__(self, identifier: str): 14 def __init__(self, identifier: str):
14 self.identifier = identifier 15 self.identifier = identifier
@@ -18,7 +19,7 @@ class BaseProxyQueue:
18 raise NotImplementedError(f'Please implement create_for_identifier') 19 raise NotImplementedError(f'Please implement create_for_identifier')
19 20
20 @classmethod 21 @classmethod
21 async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': 22 async def get_for_identifier(cls, identifier: str) -> Type['self']:
22 assert await cls.has_connection(identifier) 23 assert await cls.has_connection(identifier)
23 return cls(identifier) 24 return cls(identifier)
24 25
@@ -26,23 +27,18 @@ class BaseProxyQueue:
26 async def has_connection(cls, identifier) -> bool: 27 async def has_connection(cls, identifier) -> bool:
27 raise NotImplementedError(f'Please implement has_connection') 28 raise NotImplementedError(f'Please implement has_connection')
28 29
29 async def send_request(self, request_data: RequestData): 30 async def enqueue(self, message: Message):
30 raise NotImplementedError(f'Please implement send_request') 31 raise NotImplementedError(f'Please implement send_request')
31 32
32 async def handle_request(self) -> RequestData: 33 async def dequeue(self) -> Message:
33 raise NotImplementedError(f'Please implement handle_requests') 34 raise NotImplementedError(f'Please implement handle_requests')
34 35
35 async def send_response(self, response_data: ResponseData):
36 raise NotImplementedError(f'Please implement send_request')
37
38 async def handle_response(self) -> ResponseData:
39 raise NotImplementedError(f'Please implement handle_response')
40
41 async def delete(self): 36 async def delete(self):
42 raise NotImplementedError(f'Please implement delete') 37 raise NotImplementedError(f'Please implement delete')
43 38
39
44class MemoryProxyQueue(BaseProxyQueue): 40class MemoryProxyQueue(BaseProxyQueue):
45 connections: dict[str, MemoryConnection] = {} 41 connections: dict[str, asyncio.Queue] = {}
46 42
47 @classmethod 43 @classmethod
48 async def has_connection(cls, identifier) -> bool: 44 async def has_connection(cls, identifier) -> bool:
@@ -51,33 +47,15 @@ class MemoryProxyQueue(BaseProxyQueue):
51 @classmethod 47 @classmethod
52 async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue': 48 async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue':
53 instance = cls(identifier) 49 instance = cls(identifier)
54 50 cls.connections[identifier] = asyncio.Queue()
55 cls.connections[identifier] = {
56 'requests': asyncio.Queue(),
57 'responses': asyncio.Queue(),
58 }
59 51
60 return instance 52 return instance
61 53
62 @property 54 async def enqueue(self, message: Message):
63 def requests(self) -> asyncio.Queue[RequestData]: 55 return await self.__class__.connections[self.identifier].put(message)
64 return self.__class__.connections[self.identifier]['requests']
65
66 @property
67 def responses(self) -> asyncio.Queue[ResponseData]:
68 return self.__class__.connections[self.identifier]['responses']
69
70 async def send_request(self, request_data: RequestData):
71 await self.requests.put(request_data)
72 56
73 async def handle_request(self) -> RequestData: 57 async def dequeue(self) -> Message:
74 return await self.requests.get() 58 return await self.__class__.connections[self.identifier].get()
75
76 async def send_response(self, response_data: ResponseData):
77 return await self.responses.put(response_data)
78
79 async def handle_response(self) -> ResponseData:
80 return await self.responses.get()
81 59
82 async def delete(self): 60 async def delete(self):
83 del self.__class__.connections[self.identifier] 61 del self.__class__.connections[self.identifier]
@@ -127,21 +105,12 @@ class RedisProxyQueue(BaseProxyQueue):
127 case _: 105 case _:
128 return message['data'] 106 return message['data']
129 107
130 async def send_request(self, request_data: RequestData): 108 async def enqueue(self, message: Message):
131 await RedisConnectionPool \
132 .get_connection() \
133 .publish(f'request_{self.identifier}', json.dumps(request_data))
134
135 async def handle_request(self) -> RequestData:
136 message = await self.wait_for_message()
137 return json.loads(message)
138
139 async def send_response(self, response_data: ResponseData):
140 await RedisConnectionPool \ 109 await RedisConnectionPool \
141 .get_connection() \ 110 .get_connection() \
142 .publish(f'response_{self.identifier}', json.dumps(response_data)) 111 .publish(self.identifier, json.dumps(message))
143 112
144 async def handle_response(self) -> ResponseData: 113 async def dequeue(self) -> Message:
145 message = await self.wait_for_message() 114 message = await self.wait_for_message()
146 return json.loads(message) 115 return json.loads(message)
147 116