From 486087cdb349dbc07b479d2286a02bdca310ea38 Mon Sep 17 00:00:00 2001 From: Tom van der Lee Date: Wed, 20 Mar 2024 21:48:45 +0100 Subject: Added websocket support --- ttun_server/endpoints.py | 75 ++++-------------------------------------------- 1 file changed, 5 insertions(+), 70 deletions(-) (limited to 'ttun_server/endpoints.py') diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py index 3e263da..eae0ebe 100644 --- a/ttun_server/endpoints.py +++ b/ttun_server/endpoints.py @@ -1,20 +1,13 @@ -import asyncio import logging -import os -from asyncio import create_task from base64 import b64decode, b64encode -from typing import Optional, Any from uuid import uuid4 -from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint +from starlette.endpoints import HTTPEndpoint from starlette.requests import Request from starlette.responses import Response -from starlette.types import Scope, Receive, Send -from starlette.websockets import WebSocket -import ttun_server from ttun_server.proxy_queue import ProxyQueue -from ttun_server.types import RequestData, Config, Message, MessageType +from ttun_server.types import HttpRequestData, Message, HttpMessageType, HttpMessage logger = logging.getLogger(__name__) @@ -44,11 +37,11 @@ class Proxy(HTTPEndpoint): logger.debug('PROXY %s%s ', subdomain, request.url) await request_queue.enqueue( - Message( - type=MessageType.request.value, + HttpMessage( + type=HttpMessageType.request.value, identifier=identifier, payload= - RequestData( + HttpRequestData( method=request.method, path=str(request.url).replace(str(request.base_url), '/'), headers=list(request.headers.items()), @@ -78,61 +71,3 @@ class Health(HTTPEndpoint): await response(self.scope, self.receive, self.send) -class Tunnel(WebSocketEndpoint): - encoding = 'json' - - def __init__(self, scope: Scope, receive: Receive, send: Send): - super().__init__(scope, receive, send) - self.request_task = None - self.config: Optional[Config] = None - - async def handle_requests(self, websocket: WebSocket): - while request := await self.proxy_queue.dequeue(): - create_task(websocket.send_json(request)) - - async def on_connect(self, websocket: WebSocket) -> None: - await websocket.accept() - self.config = await websocket.receive_json() - - client_version = self.config.get('version', '1.0.0') - logger.debug('client_version %s', client_version) - - if 'git' not in client_version and ttun_server.__version__ != 'development': - [client_major, *_] = [int(i) for i in client_version.split('.')[:3]] - [server_major, *_] = [int(i) for i in ttun_server.__version__.split('.')] - - if client_major < server_major: - await websocket.close(4000, 'Your client is too old') - - if client_major > server_major: - await websocket.close(4001, 'Your client is too new') - - - if self.config['subdomain'] is None \ - or await ProxyQueue.has_connection(self.config['subdomain']): - self.config['subdomain'] = uuid4().hex - - - self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain']) - - hostname = os.environ.get("TUNNEL_DOMAIN") - protocol = "https" if os.environ.get("SECURE", False) else "http" - - await websocket.send_json({ - 'url': f'{protocol}://{self.config["subdomain"]}.{hostname}' - }) - - self.request_task = asyncio.create_task(self.handle_requests(websocket)) - - async def on_receive(self, websocket: WebSocket, data: Message): - try: - response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}") - await response_queue.enqueue(data) - except AssertionError: - pass - - async def on_disconnect(self, websocket: WebSocket, close_code: int): - await self.proxy_queue.delete() - - if self.request_task is not None: - self.request_task.cancel() -- cgit v1.2.3