summaryrefslogtreecommitdiffstats
path: root/ttun_server
diff options
context:
space:
mode:
Diffstat (limited to 'ttun_server')
-rw-r--r--ttun_server/endpoints.py52
-rw-r--r--ttun_server/proxy_queue.py63
-rw-r--r--ttun_server/types.py17
3 files changed, 65 insertions, 67 deletions
diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py
index b25ffe4..6728c31 100644
--- a/ttun_server/endpoints.py
+++ b/ttun_server/endpoints.py
@@ -12,7 +12,7 @@ from starlette.types import Scope, Receive, Send
12from starlette.websockets import WebSocket 12from starlette.websockets import WebSocket
13 13
14from ttun_server.proxy_queue import ProxyQueue 14from ttun_server.proxy_queue import ProxyQueue
15from ttun_server.types import RequestData, Config 15from ttun_server.types import RequestData, Config, Message, MessageType
16 16
17logger = logging.getLogger(__name__) 17logger = logging.getLogger(__name__)
18 18
@@ -33,26 +33,39 @@ class Proxy(HTTPEndpoint):
33 [subdomain, *_] = request.headers['host'].split('.') 33 [subdomain, *_] = request.headers['host'].split('.')
34 response = Response(content='Not Found', status_code=404) 34 response = Response(content='Not Found', status_code=404)
35 35
36 identifier = str(uuid4())
37 response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{identifier}')
38
36 try: 39 try:
37 queue = await ProxyQueue.get_for_identifier(subdomain)
38 40
39 await queue.send_request(RequestData( 41 request_queue = await ProxyQueue.get_for_identifier(subdomain)
40 method=request.method, 42
41 path=str(request.url).replace(str(request.base_url), '/'), 43 await request_queue.enqueue(
42 headers=list(request.headers.items()), 44 Message(
43 body=b64encode(await request.body()).decode() 45 type=MessageType.request,
44 )) 46 identifier=identifier,
47 payload=
48 RequestData(
49 method=request.method,
50 path=str(request.url).replace(str(request.base_url), '/'),
51 headers=list(request.headers.items()),
52 body=b64encode(await request.body()).decode()
53 )
54 )
55 )
45 56
46 _response = await queue.handle_response() 57 _response = await response_queue.dequeue()
58 payload = _response['payload']
47 response = Response( 59 response = Response(
48 status_code=_response['status'], 60 status_code=payload['status'],
49 headers=HeaderMapping(_response['headers']), 61 headers=HeaderMapping(payload['headers']),
50 content=b64decode(_response['body'].encode()) 62 content=b64decode(payload['body'].encode())
51 ) 63 )
52 except AssertionError: 64 except AssertionError:
53 pass 65 pass
54 66 finally:
55 await response(self.scope, self.receive, self.send) 67 await response(self.scope, self.receive, self.send)
68 await response_queue.delete()
56 69
57 70
58class Health(HTTPEndpoint): 71class Health(HTTPEndpoint):
@@ -62,7 +75,6 @@ class Health(HTTPEndpoint):
62 await response(self.scope, self.receive, self.send) 75 await response(self.scope, self.receive, self.send)
63 76
64 77
65
66class Tunnel(WebSocketEndpoint): 78class Tunnel(WebSocketEndpoint):
67 encoding = 'json' 79 encoding = 'json'
68 80
@@ -72,7 +84,7 @@ class Tunnel(WebSocketEndpoint):
72 self.config: Optional[Config] = None 84 self.config: Optional[Config] = None
73 85
74 async def handle_requests(self, websocket: WebSocket): 86 async def handle_requests(self, websocket: WebSocket):
75 while request := await self.proxy_queue.handle_request(): 87 while request := await self.proxy_queue.dequeue():
76 await websocket.send_json(request) 88 await websocket.send_json(request)
77 89
78 async def on_connect(self, websocket: WebSocket) -> None: 90 async def on_connect(self, websocket: WebSocket) -> None:
@@ -94,8 +106,12 @@ class Tunnel(WebSocketEndpoint):
94 106
95 self.request_task = asyncio.create_task(self.handle_requests(websocket)) 107 self.request_task = asyncio.create_task(self.handle_requests(websocket))
96 108
97 async def on_receive(self, websocket: WebSocket, data: Any): 109 async def on_receive(self, websocket: WebSocket, data: Message):
98 await self.proxy_queue.send_response(data) 110 try:
111 response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}")
112 await response_queue.enqueue(data)
113 except AssertionError:
114 pass
99 115
100 async def on_disconnect(self, websocket: WebSocket, close_code: int): 116 async def on_disconnect(self, websocket: WebSocket, close_code: int):
101 await self.proxy_queue.delete() 117 await self.proxy_queue.delete()
diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py
index 07e16e0..e521886 100644
--- a/ttun_server/proxy_queue.py
+++ b/ttun_server/proxy_queue.py
@@ -2,13 +2,14 @@ import asyncio
2import json 2import json
3import logging 3import logging
4import os 4import os
5from typing import Awaitable, Callable 5from typing import Type
6 6
7from ttun_server.redis import RedisConnectionPool 7from ttun_server.redis import RedisConnectionPool
8from ttun_server.types import RequestData, ResponseData, MemoryConnection 8from ttun_server.types import Message
9 9
10logger = logging.getLogger(__name__) 10logger = logging.getLogger(__name__)
11 11
12
12class BaseProxyQueue: 13class BaseProxyQueue:
13 def __init__(self, identifier: str): 14 def __init__(self, identifier: str):
14 self.identifier = identifier 15 self.identifier = identifier
@@ -18,7 +19,7 @@ class BaseProxyQueue:
18 raise NotImplementedError(f'Please implement create_for_identifier') 19 raise NotImplementedError(f'Please implement create_for_identifier')
19 20
20 @classmethod 21 @classmethod
21 async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': 22 async def get_for_identifier(cls, identifier: str) -> Type['self']:
22 assert await cls.has_connection(identifier) 23 assert await cls.has_connection(identifier)
23 return cls(identifier) 24 return cls(identifier)
24 25
@@ -26,23 +27,18 @@ class BaseProxyQueue:
26 async def has_connection(cls, identifier) -> bool: 27 async def has_connection(cls, identifier) -> bool:
27 raise NotImplementedError(f'Please implement has_connection') 28 raise NotImplementedError(f'Please implement has_connection')
28 29
29 async def send_request(self, request_data: RequestData): 30 async def enqueue(self, message: Message):
30 raise NotImplementedError(f'Please implement send_request') 31 raise NotImplementedError(f'Please implement send_request')
31 32
32 async def handle_request(self) -> RequestData: 33 async def dequeue(self) -> Message:
33 raise NotImplementedError(f'Please implement handle_requests') 34 raise NotImplementedError(f'Please implement handle_requests')
34 35
35 async def send_response(self, response_data: ResponseData):
36 raise NotImplementedError(f'Please implement send_request')
37
38 async def handle_response(self) -> ResponseData:
39 raise NotImplementedError(f'Please implement handle_response')
40
41 async def delete(self): 36 async def delete(self):
42 raise NotImplementedError(f'Please implement delete') 37 raise NotImplementedError(f'Please implement delete')
43 38
39
44class MemoryProxyQueue(BaseProxyQueue): 40class MemoryProxyQueue(BaseProxyQueue):
45 connections: dict[str, MemoryConnection] = {} 41 connections: dict[str, asyncio.Queue] = {}
46 42
47 @classmethod 43 @classmethod
48 async def has_connection(cls, identifier) -> bool: 44 async def has_connection(cls, identifier) -> bool:
@@ -51,33 +47,15 @@ class MemoryProxyQueue(BaseProxyQueue):
51 @classmethod 47 @classmethod
52 async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue': 48 async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue':
53 instance = cls(identifier) 49 instance = cls(identifier)
54 50 cls.connections[identifier] = asyncio.Queue()
55 cls.connections[identifier] = {
56 'requests': asyncio.Queue(),
57 'responses': asyncio.Queue(),
58 }
59 51
60 return instance 52 return instance
61 53
62 @property 54 async def enqueue(self, message: Message):
63 def requests(self) -> asyncio.Queue[RequestData]: 55 return await self.__class__.connections[self.identifier].put(message)
64 return self.__class__.connections[self.identifier]['requests']
65
66 @property
67 def responses(self) -> asyncio.Queue[ResponseData]:
68 return self.__class__.connections[self.identifier]['responses']
69
70 async def send_request(self, request_data: RequestData):
71 await self.requests.put(request_data)
72 56
73 async def handle_request(self) -> RequestData: 57 async def dequeue(self) -> Message:
74 return await self.requests.get() 58 return await self.__class__.connections[self.identifier].get()
75
76 async def send_response(self, response_data: ResponseData):
77 return await self.responses.put(response_data)
78
79 async def handle_response(self) -> ResponseData:
80 return await self.responses.get()
81 59
82 async def delete(self): 60 async def delete(self):
83 del self.__class__.connections[self.identifier] 61 del self.__class__.connections[self.identifier]
@@ -127,21 +105,12 @@ class RedisProxyQueue(BaseProxyQueue):
127 case _: 105 case _:
128 return message['data'] 106 return message['data']
129 107
130 async def send_request(self, request_data: RequestData): 108 async def enqueue(self, message: Message):
131 await RedisConnectionPool \
132 .get_connection() \
133 .publish(f'request_{self.identifier}', json.dumps(request_data))
134
135 async def handle_request(self) -> RequestData:
136 message = await self.wait_for_message()
137 return json.loads(message)<