From 7ac28203290a211a6e17ae0b91bc2b609f110514 Mon Sep 17 00:00:00 2001 From: Tom van der Lee Date: Thu, 8 Jun 2023 08:24:43 +0200 Subject: WIP --- .python-version | 2 +- ttun_server/endpoints.py | 52 +++++++++++++++++++++++++------------- ttun_server/proxy_queue.py | 63 ++++++++++++---------------------------------- ttun_server/types.py | 17 +++++++++++-- 4 files changed, 66 insertions(+), 68 deletions(-) diff --git a/.python-version b/.python-version index 30291cb..09dcc78 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.10.0 +3.10.11 diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py index b25ffe4..6728c31 100644 --- a/ttun_server/endpoints.py +++ b/ttun_server/endpoints.py @@ -12,7 +12,7 @@ from starlette.types import Scope, Receive, Send from starlette.websockets import WebSocket from ttun_server.proxy_queue import ProxyQueue -from ttun_server.types import RequestData, Config +from ttun_server.types import RequestData, Config, Message, MessageType logger = logging.getLogger(__name__) @@ -33,26 +33,39 @@ class Proxy(HTTPEndpoint): [subdomain, *_] = request.headers['host'].split('.') response = Response(content='Not Found', status_code=404) + identifier = str(uuid4()) + response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{identifier}') + try: - queue = await ProxyQueue.get_for_identifier(subdomain) - await queue.send_request(RequestData( - method=request.method, - path=str(request.url).replace(str(request.base_url), '/'), - headers=list(request.headers.items()), - body=b64encode(await request.body()).decode() - )) + request_queue = await ProxyQueue.get_for_identifier(subdomain) + + await request_queue.enqueue( + Message( + type=MessageType.request, + identifier=identifier, + payload= + RequestData( + method=request.method, + path=str(request.url).replace(str(request.base_url), '/'), + headers=list(request.headers.items()), + body=b64encode(await request.body()).decode() + ) + ) + ) - _response = await queue.handle_response() + _response = await response_queue.dequeue() + payload = _response['payload'] response = Response( - status_code=_response['status'], - headers=HeaderMapping(_response['headers']), - content=b64decode(_response['body'].encode()) + status_code=payload['status'], + headers=HeaderMapping(payload['headers']), + content=b64decode(payload['body'].encode()) ) except AssertionError: pass - - await response(self.scope, self.receive, self.send) + finally: + await response(self.scope, self.receive, self.send) + await response_queue.delete() class Health(HTTPEndpoint): @@ -62,7 +75,6 @@ class Health(HTTPEndpoint): await response(self.scope, self.receive, self.send) - class Tunnel(WebSocketEndpoint): encoding = 'json' @@ -72,7 +84,7 @@ class Tunnel(WebSocketEndpoint): self.config: Optional[Config] = None async def handle_requests(self, websocket: WebSocket): - while request := await self.proxy_queue.handle_request(): + while request := await self.proxy_queue.dequeue(): await websocket.send_json(request) async def on_connect(self, websocket: WebSocket) -> None: @@ -94,8 +106,12 @@ class Tunnel(WebSocketEndpoint): self.request_task = asyncio.create_task(self.handle_requests(websocket)) - async def on_receive(self, websocket: WebSocket, data: Any): - await self.proxy_queue.send_response(data) + async def on_receive(self, websocket: WebSocket, data: Message): + try: + response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}") + await response_queue.enqueue(data) + except AssertionError: + pass async def on_disconnect(self, websocket: WebSocket, close_code: int): await self.proxy_queue.delete() diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py index 07e16e0..e521886 100644 --- a/ttun_server/proxy_queue.py +++ b/ttun_server/proxy_queue.py @@ -2,13 +2,14 @@ import asyncio import json import logging import os -from typing import Awaitable, Callable +from typing import Type from ttun_server.redis import RedisConnectionPool -from ttun_server.types import RequestData, ResponseData, MemoryConnection +from ttun_server.types import Message logger = logging.getLogger(__name__) + class BaseProxyQueue: def __init__(self, identifier: str): self.identifier = identifier @@ -18,7 +19,7 @@ class BaseProxyQueue: raise NotImplementedError(f'Please implement create_for_identifier') @classmethod - async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': + async def get_for_identifier(cls, identifier: str) -> Type['self']: assert await cls.has_connection(identifier) return cls(identifier) @@ -26,23 +27,18 @@ class BaseProxyQueue: async def has_connection(cls, identifier) -> bool: raise NotImplementedError(f'Please implement has_connection') - async def send_request(self, request_data: RequestData): + async def enqueue(self, message: Message): raise NotImplementedError(f'Please implement send_request') - async def handle_request(self) -> RequestData: + async def dequeue(self) -> Message: raise NotImplementedError(f'Please implement handle_requests') - async def send_response(self, response_data: ResponseData): - raise NotImplementedError(f'Please implement send_request') - - async def handle_response(self) -> ResponseData: - raise NotImplementedError(f'Please implement handle_response') - async def delete(self): raise NotImplementedError(f'Please implement delete') + class MemoryProxyQueue(BaseProxyQueue): - connections: dict[str, MemoryConnection] = {} + connections: dict[str, asyncio.Queue] = {} @classmethod async def has_connection(cls, identifier) -> bool: @@ -51,33 +47,15 @@ class MemoryProxyQueue(BaseProxyQueue): @classmethod async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue': instance = cls(identifier) - - cls.connections[identifier] = { - 'requests': asyncio.Queue(), - 'responses': asyncio.Queue(), - } + cls.connections[identifier] = asyncio.Queue() return instance - @property - def requests(self) -> asyncio.Queue[RequestData]: - return self.__class__.connections[self.identifier]['requests'] - - @property - def responses(self) -> asyncio.Queue[ResponseData]: - return self.__class__.connections[self.identifier]['responses'] - - async def send_request(self, request_data: RequestData): - await self.requests.put(request_data) + async def enqueue(self, message: Message): + return await self.__class__.connections[self.identifier].put(message) - async def handle_request(self) -> RequestData: - return await self.requests.get() - - async def send_response(self, response_data: ResponseData): - return await self.responses.put(response_data) - - async def handle_response(self) -> ResponseData: - return await self.responses.get() + async def dequeue(self) -> Message: + return await self.__class__.connections[self.identifier].get() async def delete(self): del self.__class__.connections[self.identifier] @@ -127,21 +105,12 @@ class RedisProxyQueue(BaseProxyQueue): case _: return message['data'] - async def send_request(self, request_data: RequestData): - await RedisConnectionPool \ - .get_connection() \ - .publish(f'request_{self.identifier}', json.dumps(request_data)) - - async def handle_request(self) -> RequestData: - message = await self.wait_for_message() - return json.loads(message) - - async def send_response(self, response_data: ResponseData): + async def enqueue(self, message: Message): await RedisConnectionPool \ .get_connection() \ - .publish(f'response_{self.identifier}', json.dumps(response_data)) + .publish(self.identifier, json.dumps(message)) - async def handle_response(self) -> ResponseData: + async def dequeue(self) -> Message: message = await self.wait_for_message() return json.loads(message) diff --git a/ttun_server/types.py b/ttun_server/types.py index 643ff21..2f1959f 100644 --- a/ttun_server/types.py +++ b/ttun_server/types.py @@ -1,10 +1,17 @@ from asyncio import Queue +from enum import Enum from typing import TypedDict, Optional +class MessageType(Enum): + request = 'request' + response = 'response' + + class Config(TypedDict): subdomain: str + class RequestData(TypedDict): method: str path: str @@ -18,6 +25,12 @@ class ResponseData(TypedDict): body: Optional[str] +class Message(TypedDict): + type: MessageType + identifier: str + payload: Config | RequestData | ResponseData + + class MemoryConnection(TypedDict): - requests: Queue[RequestData] - responses: Queue[ResponseData] + requests: Queue[Message] + responses: Queue[Message] -- cgit v1.2.3