From 5a38ebf365bfa0718dcbd7ab013af5f2da4610f6 Mon Sep 17 00:00:00 2001 From: Tom van der Lee Date: Mon, 17 Jan 2022 19:41:55 +0100 Subject: Added scaling support via redis --- requirements.txt | 1 + ttun_server/__init__.py | 5 ++ ttun_server/connections.py | 3 - ttun_server/endpoints.py | 43 +++++------- ttun_server/proxy_queue.py | 163 +++++++++++++++++++++++++++++++++++++++++++++ ttun_server/redis.py | 20 ++++++ ttun_server/types.py | 2 +- 7 files changed, 207 insertions(+), 30 deletions(-) delete mode 100644 ttun_server/connections.py create mode 100644 ttun_server/proxy_queue.py create mode 100644 ttun_server/redis.py diff --git a/requirements.txt b/requirements.txt index 95f7ad2..34c860e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ starlette ~= 0.17 uvicorn[standard] ~= 0.16 +aioredis ~= 2.0 diff --git a/ttun_server/__init__.py b/ttun_server/__init__.py index b8fd114..cf589cc 100644 --- a/ttun_server/__init__.py +++ b/ttun_server/__init__.py @@ -1,8 +1,13 @@ +import logging +import os + from starlette.applications import Starlette from starlette.routing import Route, WebSocketRoute from ttun_server.endpoints import Proxy, Tunnel +logging.basicConfig(level=getattr(logging, os.environ.get('LOG_LEVEL', 'INFO'))) + server = Starlette( debug=True, routes=[ diff --git a/ttun_server/connections.py b/ttun_server/connections.py deleted file mode 100644 index a8dabcf..0000000 --- a/ttun_server/connections.py +++ /dev/null @@ -1,3 +0,0 @@ -from ttun_server.types import Connection - -connections: dict[str, Connection] = {} diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py index d59cb7c..5b9e57f 100644 --- a/ttun_server/endpoints.py +++ b/ttun_server/endpoints.py @@ -1,4 +1,5 @@ import asyncio +import logging import os from asyncio import Queue from base64 import b64decode, b64encode @@ -11,9 +12,10 @@ from starlette.responses import Response from starlette.types import Scope, Receive, Send from starlette.websockets import WebSocket -from ttun_server.types import Connection, RequestData, Config, ResponseData +from ttun_server.proxy_queue import ProxyQueue +from ttun_server.types import RequestData, Config, ResponseData -from ttun_server.connections import connections +logger = logging.getLogger(__name__) class Proxy(HTTPEndpoint): @@ -23,10 +25,10 @@ class Proxy(HTTPEndpoint): [subdomain, *_] = request.headers['host'].split('.') response = Response(content='Not Found', status_code=404) - if subdomain in connections: - connection = connections[subdomain] + try: + queue = await ProxyQueue.get_for_identifier(subdomain) - await connection['requests'].put(RequestData( + await queue.send_request(RequestData( method=request.method, path=str(request.url).replace(str(request.base_url), '/'), headers=dict(request.headers), @@ -34,12 +36,14 @@ class Proxy(HTTPEndpoint): body=b64encode(await request.body()).decode() )) - _response = await connection['responses'].get() + _response = await queue.handle_response() response = Response( status_code=_response['status'], headers=_response['headers'], content=b64decode(_response['body'].encode()) ) + except AssertionError: + pass await response(self.scope, self.receive, self.send) @@ -52,16 +56,8 @@ class Tunnel(WebSocketEndpoint): self.request_task = None self.config: Optional[Config] = None - @property - def requests(self) -> Queue[RequestData]: - return connections[self.config['subdomain']]['requests'] - - @property - def responses(self) -> Queue[ResponseData]: - return connections[self.config['subdomain']]['responses'] - async def handle_requests(self, websocket: WebSocket): - while request := await self.requests.get(): + while request := await self.proxy_queue.handle_request(): await websocket.send_json(request) async def on_connect(self, websocket: WebSocket) -> None: @@ -69,14 +65,10 @@ class Tunnel(WebSocketEndpoint): self.config = await websocket.receive_json() if self.config['subdomain'] is None \ - or self.config['subdomain'] in connections: + or await ProxyQueue.has_connection(self.config['subdomain']): self.config['subdomain'] = uuid4().hex - - connections[self.config['subdomain']] = Connection( - requests=Queue(), - responses=Queue(), - ) + self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain']) hostname = os.environ.get("TUNNEL_DOMAIN") protocol = "https" if os.environ.get("SECURE", False) else "http" @@ -87,12 +79,11 @@ class Tunnel(WebSocketEndpoint): self.request_task = asyncio.create_task(self.handle_requests(websocket)) - async def on_receive(self, websocket: WebSocket, data: Any) -> None: - await self.responses.put(data) + async def on_receive(self, websocket: WebSocket, data: Any): + await self.proxy_queue.send_response(data) - async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: - if self.config is not None and self.config['subdomain'] in connections: - del connections[self.config['subdomain']] + async def on_disconnect(self, websocket: WebSocket, close_code: int): + await self.proxy_queue.delete() if self.request_task is not None: self.request_task.cancel() diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py new file mode 100644 index 0000000..07e16e0 --- /dev/null +++ b/ttun_server/proxy_queue.py @@ -0,0 +1,163 @@ +import asyncio +import json +import logging +import os +from typing import Awaitable, Callable + +from ttun_server.redis import RedisConnectionPool +from ttun_server.types import RequestData, ResponseData, MemoryConnection + +logger = logging.getLogger(__name__) + +class BaseProxyQueue: + def __init__(self, identifier: str): + self.identifier = identifier + + @classmethod + async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': + raise NotImplementedError(f'Please implement create_for_identifier') + + @classmethod + async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': + assert await cls.has_connection(identifier) + return cls(identifier) + + @classmethod + async def has_connection(cls, identifier) -> bool: + raise NotImplementedError(f'Please implement has_connection') + + async def send_request(self, request_data: RequestData): + raise NotImplementedError(f'Please implement send_request') + + async def handle_request(self) -> RequestData: + raise NotImplementedError(f'Please implement handle_requests') + + async def send_response(self, response_data: ResponseData): + raise NotImplementedError(f'Please implement send_request') + + async def handle_response(self) -> ResponseData: + raise NotImplementedError(f'Please implement handle_response') + + async def delete(self): + raise NotImplementedError(f'Please implement delete') + +class MemoryProxyQueue(BaseProxyQueue): + connections: dict[str, MemoryConnection] = {} + + @classmethod + async def has_connection(cls, identifier) -> bool: + return identifier in cls.connections + + @classmethod + async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue': + instance = cls(identifier) + + cls.connections[identifier] = { + 'requests': asyncio.Queue(), + 'responses': asyncio.Queue(), + } + + return instance + + @property + def requests(self) -> asyncio.Queue[RequestData]: + return self.__class__.connections[self.identifier]['requests'] + + @property + def responses(self) -> asyncio.Queue[ResponseData]: + return self.__class__.connections[self.identifier]['responses'] + + async def send_request(self, request_data: RequestData): + await self.requests.put(request_data) + + async def handle_request(self) -> RequestData: + return await self.requests.get() + + async def send_response(self, response_data: ResponseData): + return await self.responses.put(response_data) + + async def handle_response(self) -> ResponseData: + return await self.responses.get() + + async def delete(self): + del self.__class__.connections[self.identifier] + + +class RedisProxyQueue(BaseProxyQueue): + def __init__(self, identifier): + super().__init__(identifier) + + self.pubsub = RedisConnectionPool()\ + .get_connection()\ + .pubsub() + + self.subscription_queue = asyncio.Queue() + + @classmethod + async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': + instance = cls(identifier) + + await instance.pubsub.subscribe(f'request_{identifier}') + return instance + + @classmethod + async def get_for_identifier(cls, identifier: str) -> 'RedisProxyQueue': + instance: 'RedisProxyQueue' = await super().get_for_identifier(identifier) + + await instance.pubsub.subscribe(f'response_{identifier}') + + return instance + + @classmethod + async def has_connection(cls, identifier) -> bool: + logger.debug(await RedisConnectionPool.get_connection().pubsub_channels()) + return f'request_{identifier}' in { + channel.decode() + for channel + in await RedisConnectionPool \ + .get_connection() \ + .pubsub_channels() + } + + async def wait_for_message(self): + async for message in self.pubsub.listen(): + match message['type']: + case 'subscribe': + continue + case _: + return message['data'] + + async def send_request(self, request_data: RequestData): + await RedisConnectionPool \ + .get_connection() \ + .publish(f'request_{self.identifier}', json.dumps(request_data)) + + async def handle_request(self) -> RequestData: + message = await self.wait_for_message() + return json.loads(message) + + async def send_response(self, response_data: ResponseData): + await RedisConnectionPool \ + .get_connection() \ + .publish(f'response_{self.identifier}', json.dumps(response_data)) + + async def handle_response(self) -> ResponseData: + message = await self.wait_for_message() + return json.loads(message) + + async def delete(self): + await self.pubsub.unsubscribe(f'request_{self.identifier}') + + await RedisConnectionPool.get_connection()\ + .srem('connections', self.identifier) + + +class ProxyQueueMeta(type): + def __new__(cls, name, superclasses, attributes): + return RedisProxyQueue \ + if 'REDIS_URL' in os.environ \ + else MemoryProxyQueue + + +class ProxyQueue(BaseProxyQueue, metaclass=ProxyQueueMeta): + pass diff --git a/ttun_server/redis.py b/ttun_server/redis.py new file mode 100644 index 0000000..344c107 --- /dev/null +++ b/ttun_server/redis.py @@ -0,0 +1,20 @@ +import os + +from aioredis import ConnectionPool, Redis + + +class RedisConnectionPool(): + instance: 'RedisConnectionPool' = None + + def __init__(self): + self.pool = ConnectionPool.from_url(os.environ.get('REDIS_URL')) + + def __del__(self): + self.pool.disconnect() + + @classmethod + def get_connection(cls) -> Redis: + if cls.instance is None: + cls.instance = RedisConnectionPool() + + return Redis(connection_pool=cls.instance.pool) diff --git a/ttun_server/types.py b/ttun_server/types.py index 0b2fb87..9052a0e 100644 --- a/ttun_server/types.py +++ b/ttun_server/types.py @@ -20,6 +20,6 @@ class ResponseData(TypedDict): body: Optional[str] -class Connection(TypedDict): +class MemoryConnection(TypedDict): requests: Queue[RequestData] responses: Queue[ResponseData] -- cgit v1.2.3