1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
|
import asyncio
import json
import logging
import os
from typing import Awaitable, Callable
from ttun_server.redis import RedisConnectionPool
from ttun_server.types import RequestData, ResponseData, MemoryConnection
# Module-level logger named after this module, so its level and handlers can
# be configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
class BaseProxyQueue:
    """Interface for the transport that relays tunnel requests and responses.

    Every queue is addressed by an ``identifier``. Concrete subclasses must
    override each coroutine below that raises ``NotImplementedError``;
    ``get_for_identifier`` is shared and relies on ``has_connection``.
    """

    def __init__(self, identifier: str):
        # Key selecting which tunnel's requests/responses this object handles.
        self.identifier = identifier

    @classmethod
    async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
        """Create a new queue for ``identifier`` and return a handle to it."""
        raise NotImplementedError('Please implement create_for_identifier')

    @classmethod
    async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
        """Return a handle to an already-existing queue for ``identifier``.

        Raises:
            AssertionError: if ``has_connection(identifier)`` is false. The
                check is an explicit ``raise`` rather than an ``assert``
                statement so it is not stripped under ``python -O``.
        """
        if not await cls.has_connection(identifier):
            raise AssertionError(
                f'No proxy queue connection for identifier {identifier!r}'
            )
        return cls(identifier)

    @classmethod
    async def has_connection(cls, identifier: str) -> bool:
        """Return True if a queue currently exists for ``identifier``."""
        raise NotImplementedError('Please implement has_connection')

    async def send_request(self, request_data: 'RequestData'):
        """Deliver ``request_data`` to whoever handles requests for this identifier."""
        raise NotImplementedError('Please implement send_request')

    async def handle_request(self) -> 'RequestData':
        """Wait for and return the next request for this identifier."""
        raise NotImplementedError('Please implement handle_request')

    async def send_response(self, response_data: 'ResponseData'):
        """Deliver ``response_data`` to whoever handles responses for this identifier."""
        raise NotImplementedError('Please implement send_response')

    async def handle_response(self) -> 'ResponseData':
        """Wait for and return the next response for this identifier."""
        raise NotImplementedError('Please implement handle_response')

    async def delete(self):
        """Tear down the queue for this identifier."""
        raise NotImplementedError('Please implement delete')
class MemoryProxyQueue(BaseProxyQueue):
    """In-process proxy queue backed by a pair of ``asyncio.Queue`` objects.

    All instances share the class-level ``connections`` registry, so every
    instance created for the same identifier sees the same request and
    response queues. Because the registry is a plain in-memory dict, both
    ends of a tunnel must live in the same process.
    """

    # Shared registry: identifier -> {'requests': Queue, 'responses': Queue}.
    connections: dict[str, MemoryConnection] = {}

    @classmethod
    async def has_connection(cls, identifier) -> bool:
        """Return True if a queue pair is registered for ``identifier``."""
        return identifier in cls.connections

    @classmethod
    async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue':
        """Register a fresh request/response queue pair and return a handle.

        Any pair already registered under ``identifier`` is replaced.
        """
        instance = cls(identifier)
        cls.connections[identifier] = {
            'requests': asyncio.Queue(),
            'responses': asyncio.Queue(),
        }
        return instance

    @property
    def requests(self) -> asyncio.Queue[RequestData]:
        """Request queue for this identifier (raises KeyError once deleted)."""
        return self.__class__.connections[self.identifier]['requests']

    @property
    def responses(self) -> asyncio.Queue[ResponseData]:
        """Response queue for this identifier (raises KeyError once deleted)."""
        return self.__class__.connections[self.identifier]['responses']

    async def send_request(self, request_data: RequestData):
        """Put ``request_data`` on the shared request queue."""
        await self.requests.put(request_data)

    async def handle_request(self) -> RequestData:
        """Wait for and return the next queued request."""
        return await self.requests.get()

    async def send_response(self, response_data: ResponseData):
        """Put ``response_data`` on the shared response queue."""
        await self.responses.put(response_data)

    async def handle_response(self) -> ResponseData:
        """Wait for and return the next queued response."""
        return await self.responses.get()

    async def delete(self):
        """Unregister the queue pair for this identifier.

        Idempotent: deleting an identifier that is already gone is a no-op
        instead of raising KeyError, so cleanup paths may call it more
        than once.
        """
        self.__class__.connections.pop(self.identifier, None)
class RedisProxyQueue(BaseProxyQueue):
def __init__(self, identifier):
super().__init__(identifier)
self.pubsub = RedisConnectionPool()\
.get_connection()\
.pubsub()
self.subscription_queue = asyncio.Queue()
@classmethod
async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
instance = cls(identifier)
await instance.pubsub.subscribe(f'request_{identifier}')
return instance
@classmethod
async def get_for_identifier(cls, identifier: str) -> 'RedisProxyQueue':
instance: 'RedisProxyQueue' = await super().get_for_identifier(identifier)
await instance.pubsub.subscribe(f'response_{identifier}')
return instance
@classmethod
async def has_connection(cls, identifier) -> bool:
logger.debug(await RedisConnectionPool.get_connection().pubsub_channels())
return f'request_{identifier}' in {
channel.decode()
for channel
in await RedisConnectionPool \
.get_connection() \
.pubsub_channels()
}
async def wait_for_message(self):
async for message in self.pubsub.listen():
match message['type']:
case 'subscribe':
continue
case _:
return message['data']
async def send_request(self, request_data: RequestData):
await RedisConnectionPool \
.get_connection() \
.publish(f'request_{self.identifier}', json.dumps(request_data))
async def handle_request(self) -> RequestData:
message = await self.wait_for_message()
return json.loads(message)
async def send_response(self, response_data: ResponseData):
await RedisConnectionPool \
.get_connection() \
.publish(f'response_{self.identifier}', json.dumps(response_data))
async def handle_response(self) -> ResponseData:
message = await self.wait_for_message()
return json.loads(message)
async def delete(self):
await self.pubsub.unsubscribe(f'request_{self.identifier}')
await RedisConnectionPool.get_connection()\
.srem('connections', self.identifier)
class ProxyQueueMeta(type):
    """Metaclass that substitutes a concrete backend for the class being defined.

    ``__new__`` does not build a new class. It returns an already-defined
    queue class: ``RedisProxyQueue`` when ``REDIS_URL`` is present in the
    environment (even if set to an empty string), otherwise
    ``MemoryProxyQueue``. The choice is made once, when a ``class`` statement
    using this metaclass executes.
    """

    def __new__(cls, name, superclasses, attributes):
        # The class statement's name, bases and namespace are deliberately ignored.
        if os.environ.get('REDIS_URL') is None:
            return MemoryProxyQueue
        return RedisProxyQueue
class ProxyQueue(BaseProxyQueue, metaclass=ProxyQueueMeta):
    """Backend-agnostic name for the configured proxy queue class.

    ``ProxyQueueMeta.__new__`` returns an existing class, so ``ProxyQueue``
    ends up bound directly to ``RedisProxyQueue`` or ``MemoryProxyQueue``.
    Nothing in this body survives into the resulting class.
    """
|