From c4f33b3576e3a4a7f70b3d681fadae45f73ae31e Mon Sep 17 00:00:00 2001 From: Tom van der Lee Date: Wed, 10 Jun 2026 08:31:30 +0200 Subject: Added support for multiple domains within one session --- ttun_server/endpoints.py | 4 ++-- ttun_server/websockets.py | 52 +++++++++++++++++++++++++++++------------------ 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py index eae0ebe..fa5e7e7 100644 --- a/ttun_server/endpoints.py +++ b/ttun_server/endpoints.py @@ -7,7 +7,7 @@ from starlette.requests import Request from starlette.responses import Response from ttun_server.proxy_queue import ProxyQueue -from ttun_server.types import HttpRequestData, Message, HttpMessageType, HttpMessage +from ttun_server.types import HttpRequestData, HttpMessageType, HttpMessage logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ class Proxy(HTTPEndpoint): response = Response(content='Not Found', status_code=404) identifier = str(uuid4()) - response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{identifier}') + response_queue = await ProxyQueue.create_for_identifier(identifier) try: diff --git a/ttun_server/websockets.py b/ttun_server/websockets.py index f80359f..e828b8d 100644 --- a/ttun_server/websockets.py +++ b/ttun_server/websockets.py @@ -40,7 +40,7 @@ class WebsocketProxy(WebSocketEndpoint): await request_queue.enqueue(message) if expect_ack: - response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{message["identifier"]}') + response_queue = await ProxyQueue.create_for_identifier(message["identifier"]) yield await response_queue.dequeue() await response_queue.delete() else: @@ -49,10 +49,7 @@ class WebsocketProxy(WebSocketEndpoint): yield None async def listen_for_messages(self, websocket: WebSocket): - [subdomain, *_] = websocket.url.hostname.split('.') - - print('listen', self.id) - response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{self.id}') + response_queue = await ProxyQueue.create_for_identifier(self.id) while True: message: WebsocketMessage = await response_queue.dequeue() @@ -122,12 +119,14 @@ class Tunnel(WebSocketEndpoint): def __init__(self, scope: Scope, receive: Receive, send: Send): super().__init__(scope, receive, send) - self.request_task = None + self.request_tasks: dict[str, asyncio.Task] = {} self.config: Optional[Config] = None + self.proxy_queues: dict[str, ProxyQueue] = {} + + async def handle_requests(self, websocket: WebSocket, subdomain: str): + while request := await self.proxy_queues[subdomain].dequeue(): + task = asyncio.create_task(websocket.send_json(request), name=request['identifier']) - async def handle_requests(self, websocket: WebSocket): - while request := await self.proxy_queue.dequeue(): - create_task(websocket.send_json(request)) async def on_connect(self, websocket: WebSocket) -> None: await websocket.accept() @@ -146,32 +145,45 @@ class Tunnel(WebSocketEndpoint): if client_major > server_major: await websocket.close(4001, 'Your client is too new') + if 'subdomains' not in self.config: + self.config['subdomains'] = [self.config['subdomain']] + elif self.config['subdomains'] is None: + self.config['subdomains'] = [None] - if self.config['subdomain'] is None \ - or await ProxyQueue.has_connection(self.config['subdomain']): - self.config['subdomain'] = uuid4().hex + for i, subdomain in enumerate(self.config['subdomains']): + if subdomain is None or await ProxyQueue.has_connection(subdomain): + self.config['subdomains'][i] = uuid4().hex - - self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain']) + for subdomain in self.config['subdomains']: + self.proxy_queues[subdomain] = await ProxyQueue.create_for_identifier(subdomain) hostname = os.environ.get("TUNNEL_DOMAIN") protocol = "https" if os.environ.get("SECURE", False) else "http" + urls = [ + f'{protocol}://{subdomain}.{hostname}' + for subdomain in self.config['subdomains'] + ] + await websocket.send_json({ - 'url': f'{protocol}://{self.config["subdomain"]}.{hostname}' + 'url': urls[0], + 'urls': urls, }) - self.request_task = asyncio.create_task(self.handle_requests(websocket)) + for subdomain in self.config['subdomains']: + self.request_tasks[subdomain] = asyncio.create_task(self.handle_requests(websocket, subdomain), name=subdomain) async def on_receive(self, websocket: WebSocket, data: Message): try: - response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}") + data['type'] = MessageType(data['type']).value + response_queue = await ProxyQueue.get_for_identifier(data['identifier']) await response_queue.enqueue(data) except AssertionError: pass async def on_disconnect(self, websocket: WebSocket, close_code: int): - await self.proxy_queue.delete() + for proxy_queue in self.proxy_queues.values(): + await proxy_queue.delete() - if self.request_task is not None: - self.request_task.cancel() + for request_task in self.request_tasks.values(): + request_task.cancel() -- cgit v1.2.3