From 486087cdb349dbc07b479d2286a02bdca310ea38 Mon Sep 17 00:00:00 2001 From: Tom van der Lee Date: Wed, 20 Mar 2024 21:48:45 +0100 Subject: Added websocket support --- ttun_server/__init__.py | 4 +- ttun_server/endpoints.py | 75 ++----------------- ttun_server/proxy_queue.py | 1 + ttun_server/redis.py | 7 +- ttun_server/types.py | 51 +++++++++++-- ttun_server/websockets.py | 179 +++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 236 insertions(+), 81 deletions(-) create mode 100644 ttun_server/websockets.py (limited to 'ttun_server') diff --git a/ttun_server/__init__.py b/ttun_server/__init__.py index 81f8cd4..2f8fed0 100644 --- a/ttun_server/__init__.py +++ b/ttun_server/__init__.py @@ -4,7 +4,8 @@ import os from starlette.applications import Starlette from starlette.routing import Route, WebSocketRoute, Host, Router -from ttun_server.endpoints import Proxy, Tunnel, Health +from ttun_server.endpoints import Proxy, Health +from .websockets import WebsocketProxy, Tunnel logging.basicConfig(level=getattr(logging, os.environ.get('LOG_LEVEL', 'INFO'))) @@ -18,6 +19,7 @@ server = Starlette( routes=[ Host(os.environ['TUNNEL_DOMAIN'], base_router, 'base'), Route('/{path:path}', Proxy), + WebSocketRoute('/{path:path}', WebsocketProxy) ] ) diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py index 3e263da..eae0ebe 100644 --- a/ttun_server/endpoints.py +++ b/ttun_server/endpoints.py @@ -1,20 +1,13 @@ -import asyncio import logging -import os -from asyncio import create_task from base64 import b64decode, b64encode -from typing import Optional, Any from uuid import uuid4 -from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint +from starlette.endpoints import HTTPEndpoint from starlette.requests import Request from starlette.responses import Response -from starlette.types import Scope, Receive, Send -from starlette.websockets import WebSocket -import ttun_server from ttun_server.proxy_queue import ProxyQueue -from ttun_server.types import RequestData, Config, Message, MessageType +from ttun_server.types import HttpRequestData, Message, HttpMessageType, HttpMessage logger = logging.getLogger(__name__) @@ -44,11 +37,11 @@ class Proxy(HTTPEndpoint): logger.debug('PROXY %s%s ', subdomain, request.url) await request_queue.enqueue( - Message( - type=MessageType.request.value, + HttpMessage( + type=HttpMessageType.request.value, identifier=identifier, payload= - RequestData( + HttpRequestData( method=request.method, path=str(request.url).replace(str(request.base_url), '/'), headers=list(request.headers.items()), @@ -78,61 +71,3 @@ class Health(HTTPEndpoint): await response(self.scope, self.receive, self.send) -class Tunnel(WebSocketEndpoint): - encoding = 'json' - - def __init__(self, scope: Scope, receive: Receive, send: Send): - super().__init__(scope, receive, send) - self.request_task = None - self.config: Optional[Config] = None - - async def handle_requests(self, websocket: WebSocket): - while request := await self.proxy_queue.dequeue(): - create_task(websocket.send_json(request)) - - async def on_connect(self, websocket: WebSocket) -> None: - await websocket.accept() - self.config = await websocket.receive_json() - - client_version = self.config.get('version', '1.0.0') - logger.debug('client_version %s', client_version) - - if 'git' not in client_version and ttun_server.__version__ != 'development': - [client_major, *_] = [int(i) for i in client_version.split('.')[:3]] - [server_major, *_] = [int(i) for i in ttun_server.__version__.split('.')] - - if client_major < server_major: - await websocket.close(4000, 'Your client is too old') - - if client_major > server_major: - await websocket.close(4001, 'Your client is too new') - - - if self.config['subdomain'] is None \ - or await ProxyQueue.has_connection(self.config['subdomain']): - self.config['subdomain'] = uuid4().hex - - - self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain']) - - hostname = os.environ.get("TUNNEL_DOMAIN") - protocol = "https" if os.environ.get("SECURE", False) else "http" - - await websocket.send_json({ - 'url': f'{protocol}://{self.config["subdomain"]}.{hostname}' - }) - - self.request_task = asyncio.create_task(self.handle_requests(websocket)) - - async def on_receive(self, websocket: WebSocket, data: Message): - try: - response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}") - await response_queue.enqueue(data) - except AssertionError: - pass - - async def on_disconnect(self, websocket: WebSocket, close_code: int): - await self.proxy_queue.delete() - - if self.request_task is not None: - self.request_task.cancel() diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py index e521886..c6c8067 100644 --- a/ttun_server/proxy_queue.py +++ b/ttun_server/proxy_queue.py @@ -2,6 +2,7 @@ import asyncio import json import logging import os +import traceback from typing import Type from ttun_server.redis import RedisConnectionPool diff --git a/ttun_server/redis.py b/ttun_server/redis.py index 3065dec..18fbca2 100644 --- a/ttun_server/redis.py +++ b/ttun_server/redis.py @@ -1,6 +1,8 @@ +import asyncio import os +from asyncio import get_running_loop -from aioredis import ConnectionPool, Redis +from redis.asyncio import ConnectionPool, Redis class RedisConnectionPool: @@ -9,9 +11,6 @@ class RedisConnectionPool: def __init__(self): self.pool = ConnectionPool.from_url(os.environ.get('REDIS_URL')) - def __del__(self): - self.pool.disconnect() - @classmethod def get_connection(cls) -> Redis: if cls.instance is None: diff --git a/ttun_server/types.py b/ttun_server/types.py index 8a4d929..8591e7d 100644 --- a/ttun_server/types.py +++ b/ttun_server/types.py @@ -3,7 +3,7 @@ from enum import Enum from typing import TypedDict, Optional -class MessageType(Enum): +class HttpMessageType(Enum): request = 'request' response = 'response' @@ -13,23 +13,62 @@ class Config(TypedDict): client_version: str -class RequestData(TypedDict): +class HttpRequestData(TypedDict): method: str path: str headers: list[tuple[str, str]] body: Optional[str] -class ResponseData(TypedDict): +class HttpResponseData(TypedDict): status: int headers: list[tuple[str, str]] body: Optional[str] -class Message(TypedDict): - type: MessageType +class HttpMessage(TypedDict): + type: HttpMessageType identifier: str - payload: Config | RequestData | ResponseData + payload: Config | HttpRequestData | HttpResponseData + + +class WebsocketMessageType(Enum): + connect = 'connect' + disconnect = 'disconnect' + message = 'message' + ack = 'ack' + + +class WebsocketConnectData(TypedDict): + path: str + headers: list[tuple[str, str]] + + +class WebsocketDisconnectData(TypedDict): + close_code: int + + +class WebsocketMessageData(TypedDict): + body: Optional[str] + + +class WebsocketMessage(TypedDict): + type: WebsocketMessageType + identifier: str + payload: WebsocketConnectData | WebsocketDisconnectData | WebsocketMessageData + + +class MessageType(Enum): + request = 'request' + response = 'response' + + ws_connect = 'connect' + ws_disconnect = 'disconnect' + ws_message = 'message' + ws_ack = 'ack' + + +Message = HttpMessage | WebsocketMessage class MemoryConnection(TypedDict): diff --git a/ttun_server/websockets.py b/ttun_server/websockets.py new file mode 100644 index 0000000..0800cbc --- /dev/null +++ b/ttun_server/websockets.py @@ -0,0 +1,179 @@ +import asyncio +import json +import logging +import os +import typing +from asyncio import create_task +from base64 import b64encode, b64decode +from contextlib import asynccontextmanager +from typing import Optional +from uuid import uuid4 + +from starlette.endpoints import WebSocketEndpoint +from starlette.types import Scope, Receive, Send +from starlette.websockets import WebSocket + +import ttun_server +from ttun_server.proxy_queue import ProxyQueue +from ttun_server.types import Config, Message, WebsocketMessageType, \ + WebsocketConnectData, WebsocketMessage, WebsocketMessageData, WebsocketDisconnectData, MessageType + +logger = logging.getLogger(__name__) +logger.setLevel('DEBUG') + +class WebsocketProxy(WebSocketEndpoint): + encoding = 'json' + websocket_listen_task = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.id = str(uuid4()) + + @asynccontextmanager + async def proxy(self, websocket: WebSocket, message: WebsocketMessage): + [subdomain, *_] = websocket.url.hostname.split('.') + + expect_ack = WebsocketMessageType(message['type']) == WebsocketMessageType.connect + + try: + request_queue = await ProxyQueue.get_for_identifier(subdomain) + await request_queue.enqueue(message) + + if expect_ack: + response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{message["identifier"]}') + yield await response_queue.dequeue() + await response_queue.delete() + else: + yield + except AssertionError: + pass + + async def listen_for_messages(self, websocket: WebSocket): + [subdomain, *_] = websocket.url.hostname.split('.') + + print('listen', self.id) + response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{self.id}') + + while True: + message: WebsocketMessage = await response_queue.dequeue() + logger.debug(message) + await websocket.send_text(b64decode(message['payload']['body'].encode()).decode()) + + async def on_connect(self, websocket: WebSocket) -> None: + message = WebsocketMessage( + type=WebsocketMessageType.connect.value, + identifier=self.id, + payload=WebsocketConnectData( + path=websocket.path_params['path'], + headers=[ + (k.decode(), v.decode()) + for k, v + in websocket.scope['headers'] + ], + ) + ) + + async with self.proxy(websocket, message) as m: + type = WebsocketMessageType(m['type']) + + if type == WebsocketMessageType.ack: + await super().on_connect(websocket) + + self.websocket_listen_task = asyncio.create_task(self.listen_for_messages(websocket)) + + def callback(*args, **kwargs): + self.websocket_listen_task = None + + self.websocket_listen_task.add_done_callback(callback) + + async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None: + match data: + case dict(): + data_bytes = json.dumps(data).encode() + case bytes(): + data_bytes = data + case _: + data_bytes = data.encode() + + message = WebsocketMessage( + type=WebsocketMessageType.message.value, + identifier=self.id, + payload=WebsocketMessageData( + body=b64encode(data_bytes).decode(), + ) + ) + + async with self.proxy(websocket, message): + pass + + async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: + message = WebsocketMessage( + type=WebsocketMessageType.disconnect.value, + identifier=self.id, + payload=WebsocketDisconnectData( + close_code=close_code, + ) + ) + + async with self.proxy(websocket, message): + if self.websocket_listen_task is not None: + self.websocket_listen_task.cancel() + +class Tunnel(WebSocketEndpoint): + encoding = 'json' + + def __init__(self, scope: Scope, receive: Receive, send: Send): + super().__init__(scope, receive, send) + self.request_task = None + self.config: Optional[Config] = None + + async def handle_requests(self, websocket: WebSocket): + while request := await self.proxy_queue.dequeue(): + create_task(websocket.send_json(request)) + + async def on_connect(self, websocket: WebSocket) -> None: + await websocket.accept() + self.config = await websocket.receive_json() + + client_version = self.config.get('version', '1.0.0') + logger.debug('client_version %s', client_version) + + if 'git' not in client_version and ttun_server.__version__ != 'development': + [client_major, *_] = [int(i) for i in client_version.split('.')[:3]] + [server_major, *_] = [int(i) for i in ttun_server.__version__.split('.')] + + if client_major < server_major: + await websocket.close(4000, 'Your client is too old') + + if client_major > server_major: + await websocket.close(4001, 'Your client is too new') + + + if self.config['subdomain'] is None \ + or await ProxyQueue.has_connection(self.config['subdomain']): + self.config['subdomain'] = uuid4().hex + + + self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain']) + + hostname = os.environ.get("TUNNEL_DOMAIN") + protocol = "https" if os.environ.get("SECURE", False) else "http" + + await websocket.send_json({ + 'url': f'{protocol}://{self.config["subdomain"]}.{hostname}' + }) + + self.request_task = asyncio.create_task(self.handle_requests(websocket)) + + async def on_receive(self, websocket: WebSocket, data: Message): + try: + response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}") + await response_queue.enqueue(data) + except AssertionError: + pass + + async def on_disconnect(self, websocket: WebSocket, close_code: int): + await self.proxy_queue.delete() + + if self.request_task is not None: + self.request_task.cancel() -- cgit v1.2.3