diff options
Diffstat (limited to 'ttun_server')
| -rw-r--r-- | ttun_server/endpoints.py | 52 | ||||
| -rw-r--r-- | ttun_server/proxy_queue.py | 63 | ||||
| -rw-r--r-- | ttun_server/types.py | 17 |
3 files changed, 65 insertions, 67 deletions
diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py index b25ffe4..6728c31 100644 --- a/ttun_server/endpoints.py +++ b/ttun_server/endpoints.py | |||
| @@ -12,7 +12,7 @@ from starlette.types import Scope, Receive, Send | |||
| 12 | from starlette.websockets import WebSocket | 12 | from starlette.websockets import WebSocket |
| 13 | 13 | ||
| 14 | from ttun_server.proxy_queue import ProxyQueue | 14 | from ttun_server.proxy_queue import ProxyQueue |
| 15 | from ttun_server.types import RequestData, Config | 15 | from ttun_server.types import RequestData, Config, Message, MessageType |
| 16 | 16 | ||
| 17 | logger = logging.getLogger(__name__) | 17 | logger = logging.getLogger(__name__) |
| 18 | 18 | ||
| @@ -33,26 +33,39 @@ class Proxy(HTTPEndpoint): | |||
| 33 | [subdomain, *_] = request.headers['host'].split('.') | 33 | [subdomain, *_] = request.headers['host'].split('.') |
| 34 | response = Response(content='Not Found', status_code=404) | 34 | response = Response(content='Not Found', status_code=404) |
| 35 | 35 | ||
| 36 | identifier = str(uuid4()) | ||
| 37 | response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{identifier}') | ||
| 38 | |||
| 36 | try: | 39 | try: |
| 37 | queue = await ProxyQueue.get_for_identifier(subdomain) | ||
| 38 | 40 | ||
| 39 | await queue.send_request(RequestData( | 41 | request_queue = await ProxyQueue.get_for_identifier(subdomain) |
| 40 | method=request.method, | 42 | |
| 41 | path=str(request.url).replace(str(request.base_url), '/'), | 43 | await request_queue.enqueue( |
| 42 | headers=list(request.headers.items()), | 44 | Message( |
| 43 | body=b64encode(await request.body()).decode() | 45 | type=MessageType.request, |
| 44 | )) | 46 | identifier=identifier, |
| 47 | payload= | ||
| 48 | RequestData( | ||
| 49 | method=request.method, | ||
| 50 | path=str(request.url).replace(str(request.base_url), '/'), | ||
| 51 | headers=list(request.headers.items()), | ||
| 52 | body=b64encode(await request.body()).decode() | ||
| 53 | ) | ||
| 54 | ) | ||
| 55 | ) | ||
| 45 | 56 | ||
| 46 | _response = await queue.handle_response() | 57 | _response = await response_queue.dequeue() |
| 58 | payload = _response['payload'] | ||
| 47 | response = Response( | 59 | response = Response( |
| 48 | status_code=_response['status'], | 60 | status_code=payload['status'], |
| 49 | headers=HeaderMapping(_response['headers']), | 61 | headers=HeaderMapping(payload['headers']), |
| 50 | content=b64decode(_response['body'].encode()) | 62 | content=b64decode(payload['body'].encode()) |
| 51 | ) | 63 | ) |
| 52 | except AssertionError: | 64 | except AssertionError: |
| 53 | pass | 65 | pass |
| 54 | 66 | finally: | |
| 55 | await response(self.scope, self.receive, self.send) | 67 | await response(self.scope, self.receive, self.send) |
| 68 | await response_queue.delete() | ||
| 56 | 69 | ||
| 57 | 70 | ||
| 58 | class Health(HTTPEndpoint): | 71 | class Health(HTTPEndpoint): |
| @@ -62,7 +75,6 @@ class Health(HTTPEndpoint): | |||
| 62 | await response(self.scope, self.receive, self.send) | 75 | await response(self.scope, self.receive, self.send) |
| 63 | 76 | ||
| 64 | 77 | ||
| 65 | |||
| 66 | class Tunnel(WebSocketEndpoint): | 78 | class Tunnel(WebSocketEndpoint): |
| 67 | encoding = 'json' | 79 | encoding = 'json' |
| 68 | 80 | ||
| @@ -72,7 +84,7 @@ class Tunnel(WebSocketEndpoint): | |||
| 72 | self.config: Optional[Config] = None | 84 | self.config: Optional[Config] = None |
| 73 | 85 | ||
| 74 | async def handle_requests(self, websocket: WebSocket): | 86 | async def handle_requests(self, websocket: WebSocket): |
| 75 | while request := await self.proxy_queue.handle_request(): | 87 | while request := await self.proxy_queue.dequeue(): |
| 76 | await websocket.send_json(request) | 88 | await websocket.send_json(request) |
| 77 | 89 | ||
| 78 | async def on_connect(self, websocket: WebSocket) -> None: | 90 | async def on_connect(self, websocket: WebSocket) -> None: |
| @@ -94,8 +106,12 @@ class Tunnel(WebSocketEndpoint): | |||
| 94 | 106 | ||
| 95 | self.request_task = asyncio.create_task(self.handle_requests(websocket)) | 107 | self.request_task = asyncio.create_task(self.handle_requests(websocket)) |
| 96 | 108 | ||
| 97 | async def on_receive(self, websocket: WebSocket, data: Any): | 109 | async def on_receive(self, websocket: WebSocket, data: Message): |
| 98 | await self.proxy_queue.send_response(data) | 110 | try: |
| 111 | response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}") | ||
| 112 | await response_queue.enqueue(data) | ||
| 113 | except AssertionError: | ||
| 114 | pass | ||
| 99 | 115 | ||
| 100 | async def on_disconnect(self, websocket: WebSocket, close_code: int): | 116 | async def on_disconnect(self, websocket: WebSocket, close_code: int): |
| 101 | await self.proxy_queue.delete() | 117 | await self.proxy_queue.delete() |
diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py index 07e16e0..e521886 100644 --- a/ttun_server/proxy_queue.py +++ b/ttun_server/proxy_queue.py | |||
| @@ -2,13 +2,14 @@ import asyncio | |||
| 2 | import json | 2 | import json |
| 3 | import logging | 3 | import logging |
| 4 | import os | 4 | import os |
| 5 | from typing import Awaitable, Callable | 5 | from typing import Type |
| 6 | 6 | ||
| 7 | from ttun_server.redis import RedisConnectionPool | 7 | from ttun_server.redis import RedisConnectionPool |
| 8 | from ttun_server.types import RequestData, ResponseData, MemoryConnection | 8 | from ttun_server.types import Message |
| 9 | 9 | ||
| 10 | logger = logging.getLogger(__name__) | 10 | logger = logging.getLogger(__name__) |
| 11 | 11 | ||
| 12 | |||
| 12 | class BaseProxyQueue: | 13 | class BaseProxyQueue: |
| 13 | def __init__(self, identifier: str): | 14 | def __init__(self, identifier: str): |
| 14 | self.identifier = identifier | 15 | self.identifier = identifier |
| @@ -18,7 +19,7 @@ class BaseProxyQueue: | |||
| 18 | raise NotImplementedError(f'Please implement create_for_identifier') | 19 | raise NotImplementedError(f'Please implement create_for_identifier') |
| 19 | 20 | ||
| 20 | @classmethod | 21 | @classmethod |
| 21 | async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': | 22 | async def get_for_identifier(cls, identifier: str) -> Type['self']: |
| 22 | assert await cls.has_connection(identifier) | 23 | assert await cls.has_connection(identifier) |
| 23 | return cls(identifier) | 24 | return cls(identifier) |
| 24 | 25 | ||
| @@ -26,23 +27,18 @@ class BaseProxyQueue: | |||
| 26 | async def has_connection(cls, identifier) -> bool: | 27 | async def has_connection(cls, identifier) -> bool: |
| 27 | raise NotImplementedError(f'Please implement has_connection') | 28 | raise NotImplementedError(f'Please implement has_connection') |
| 28 | 29 | ||
| 29 | async def send_request(self, request_data: RequestData): | 30 | async def enqueue(self, message: Message): |
| 30 | raise NotImplementedError(f'Please implement send_request') | 31 | raise NotImplementedError(f'Please implement send_request') |
| 31 | 32 | ||
| 32 | async def handle_request(self) -> RequestData: | 33 | async def dequeue(self) -> Message: |
| 33 | raise NotImplementedError(f'Please implement handle_requests') | 34 | raise NotImplementedError(f'Please implement handle_requests') |
| 34 | 35 | ||
| 35 | async def send_response(self, response_data: ResponseData): | ||
| 36 | raise NotImplementedError(f'Please implement send_request') | ||
| 37 | |||
| 38 | async def handle_response(self) -> ResponseData: | ||
| 39 | raise NotImplementedError(f'Please implement handle_response') | ||
| 40 | |||
| 41 | async def delete(self): | 36 | async def delete(self): |
| 42 | raise NotImplementedError(f'Please implement delete') | 37 | raise NotImplementedError(f'Please implement delete') |
| 43 | 38 | ||
| 39 | |||
| 44 | class MemoryProxyQueue(BaseProxyQueue): | 40 | class MemoryProxyQueue(BaseProxyQueue): |
| 45 | connections: dict[str, MemoryConnection] = {} | 41 | connections: dict[str, asyncio.Queue] = {} |
| 46 | 42 | ||
| 47 | @classmethod | 43 | @classmethod |
| 48 | async def has_connection(cls, identifier) -> bool: | 44 | async def has_connection(cls, identifier) -> bool: |
| @@ -51,33 +47,15 @@ class MemoryProxyQueue(BaseProxyQueue): | |||
| 51 | @classmethod | 47 | @classmethod |
| 52 | async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue': | 48 | async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue': |
| 53 | instance = cls(identifier) | 49 | instance = cls(identifier) |
| 54 | 50 | cls.connections[identifier] = asyncio.Queue() | |
| 55 | cls.connections[identifier] = { | ||
| 56 | 'requests': asyncio.Queue(), | ||
| 57 | 'responses': asyncio.Queue(), | ||
| 58 | } | ||
| 59 | 51 | ||
| 60 | return instance | 52 | return instance |
| 61 | 53 | ||
| 62 | @property | 54 | async def enqueue(self, message: Message): |
| 63 | def requests(self) -> asyncio.Queue[RequestData]: | 55 | return await self.__class__.connections[self.identifier].put(message) |
| 64 | return self.__class__.connections[self.identifier]['requests'] | ||
| 65 | |||
| 66 | @property | ||
| 67 | def responses(self) -> asyncio.Queue[ResponseData]: | ||
| 68 | return self.__class__.connections[self.identifier]['responses'] | ||
| 69 | |||
| 70 | async def send_request(self, request_data: RequestData): | ||
| 71 | await self.requests.put(request_data) | ||
| 72 | 56 | ||
| 73 | async def handle_request(self) -> RequestData: | 57 | async def dequeue(self) -> Message: |
| 74 | return await self.requests.get() | 58 | return await self.__class__.connections[self.identifier].get() |
| 75 | |||
| 76 | async def send_response(self, response_data: ResponseData): | ||
| 77 | return await self.responses.put(response_data) | ||
| 78 | |||
| 79 | async def handle_response(self) -> ResponseData: | ||
| 80 | return await self.responses.get() | ||
| 81 | 59 | ||
| 82 | async def delete(self): | 60 | async def delete(self): |
| 83 | del self.__class__.connections[self.identifier] | 61 | del self.__class__.connections[self.identifier] |
| @@ -127,21 +105,12 @@ class RedisProxyQueue(BaseProxyQueue): | |||
| 127 | case _: | 105 | case _: |
| 128 | return message['data'] | 106 | return message['data'] |
| 129 | 107 | ||
| 130 | async def send_request(self, request_data: RequestData): | 108 | async def enqueue(self, message: Message): |
| 131 | await RedisConnectionPool \ | ||
| 132 | .get_connection() \ | ||
| 133 | .publish(f'request_{self.identifier}', json.dumps(request_data)) | ||
| 134 | |||
| 135 | async def handle_request(self) -> RequestData: | ||
| 136 | message = await self.wait_for_message() | ||
| 137 | return json.loads(message)< | ||
