From 5a38ebf365bfa0718dcbd7ab013af5f2da4610f6 Mon Sep 17 00:00:00 2001 From: Tom van der Lee Date: Mon, 17 Jan 2022 19:41:55 +0100 Subject: Added scaling support via redis --- ttun_server/endpoints.py | 43 +++++++++++++++++-------------------------- 1 file changed, 17 insertions(+), 26 deletions(-) (limited to 'ttun_server/endpoints.py') diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py index d59cb7c..5b9e57f 100644 --- a/ttun_server/endpoints.py +++ b/ttun_server/endpoints.py @@ -1,4 +1,5 @@ import asyncio +import logging import os from asyncio import Queue from base64 import b64decode, b64encode @@ -11,9 +12,10 @@ from starlette.responses import Response from starlette.types import Scope, Receive, Send from starlette.websockets import WebSocket -from ttun_server.types import Connection, RequestData, Config, ResponseData +from ttun_server.proxy_queue import ProxyQueue +from ttun_server.types import RequestData, Config, ResponseData -from ttun_server.connections import connections +logger = logging.getLogger(__name__) class Proxy(HTTPEndpoint): @@ -23,10 +25,10 @@ class Proxy(HTTPEndpoint): [subdomain, *_] = request.headers['host'].split('.') response = Response(content='Not Found', status_code=404) - if subdomain in connections: - connection = connections[subdomain] + try: + queue = await ProxyQueue.get_for_identifier(subdomain) - await connection['requests'].put(RequestData( + await queue.send_request(RequestData( method=request.method, path=str(request.url).replace(str(request.base_url), '/'), headers=dict(request.headers), @@ -34,12 +36,14 @@ class Proxy(HTTPEndpoint): body=b64encode(await request.body()).decode() )) - _response = await connection['responses'].get() + _response = await queue.handle_response() response = Response( status_code=_response['status'], headers=_response['headers'], content=b64decode(_response['body'].encode()) ) + except AssertionError: + pass await response(self.scope, self.receive, self.send) @@ -52,16 +56,8 @@ class Tunnel(WebSocketEndpoint): self.request_task = None self.config: Optional[Config] = None - @property - def requests(self) -> Queue[RequestData]: - return connections[self.config['subdomain']]['requests'] - - @property - def responses(self) -> Queue[ResponseData]: - return connections[self.config['subdomain']]['responses'] - async def handle_requests(self, websocket: WebSocket): - while request := await self.requests.get(): + while request := await self.proxy_queue.handle_request(): await websocket.send_json(request) async def on_connect(self, websocket: WebSocket) -> None: @@ -69,14 +65,10 @@ class Tunnel(WebSocketEndpoint): self.config = await websocket.receive_json() if self.config['subdomain'] is None \ - or self.config['subdomain'] in connections: + or await ProxyQueue.has_connection(self.config['subdomain']): self.config['subdomain'] = uuid4().hex - - connections[self.config['subdomain']] = Connection( - requests=Queue(), - responses=Queue(), - ) + self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain']) hostname = os.environ.get("TUNNEL_DOMAIN") protocol = "https" if os.environ.get("SECURE", False) else "http" @@ -87,12 +79,11 @@ class Tunnel(WebSocketEndpoint): self.request_task = asyncio.create_task(self.handle_requests(websocket)) - async def on_receive(self, websocket: WebSocket, data: Any) -> None: - await self.responses.put(data) + async def on_receive(self, websocket: WebSocket, data: Any): + await self.proxy_queue.send_response(data) - async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: - if self.config is not None and self.config['subdomain'] in connections: - del connections[self.config['subdomain']] + async def on_disconnect(self, websocket: WebSocket, close_code: int): + await self.proxy_queue.delete() if self.request_task is not None: self.request_task.cancel() -- cgit v1.2.3