From 12f2e24e2154113a6329d74aa556ae23506c34e1 Mon Sep 17 00:00:00 2001 From: Tom van der Lee Date: Tue, 30 Jun 2026 22:46:55 +0200 Subject: WIP --- ttun_server/__init__.py | 15 +++++--- ttun_server/endpoints.py | 97 +++++++++++++++++++----------------------------- 2 files changed, 48 insertions(+), 64 deletions(-) (limited to 'ttun_server') diff --git a/ttun_server/__init__.py b/ttun_server/__init__.py index 2f8fed0..6c77858 100644 --- a/ttun_server/__init__.py +++ b/ttun_server/__init__.py @@ -1,28 +1,31 @@ import logging import os -from starlette.applications import Starlette -from starlette.routing import Route, WebSocketRoute, Host, Router +from fastapi import FastAPI +from starlette.routing import Host, Route, Router, WebSocketRoute -from ttun_server.endpoints import Proxy, Health +from ttun_server.endpoints import health, proxy from .websockets import WebsocketProxy, Tunnel logging.basicConfig(level=getattr(logging, os.environ.get('LOG_LEVEL', 'INFO'))) base_router = Router(routes=[ - Route('/health/', Health), + Route('/health/', health), WebSocketRoute('/tunnel/', Tunnel) ]) -server = Starlette( +server = FastAPI( debug=True, routes=[ Host(os.environ['TUNNEL_DOMAIN'], base_router, 'base'), - Route('/{path:path}', Proxy), + Route('/{path:path}', proxy), WebSocketRoute('/{path:path}', WebsocketProxy) ] ) +server.post() + + try: from ._version import version __version__ = version diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py index fa5e7e7..22dcb6d 100644 --- a/ttun_server/endpoints.py +++ b/ttun_server/endpoints.py @@ -2,7 +2,7 @@ import logging from base64 import b64decode, b64encode from uuid import uuid4 -from starlette.endpoints import HTTPEndpoint +from starlette.background import BackgroundTask from starlette.requests import Request from starlette.responses import Response @@ -12,62 +12,43 @@ from ttun_server.types import HttpRequestData, HttpMessageType, HttpMessage logger = logging.getLogger(__name__) -class HeaderMapping: - def __init__(self, headers: list[tuple[str, str]]): - self._headers = headers - - def items(self): - for header in self._headers: - yield header - - -class Proxy(HTTPEndpoint): - async def dispatch(self) -> None: - request = Request(self.scope, self.receive) - - [subdomain, *_] = request.headers['host'].split('.') - response = Response(content='Not Found', status_code=404) - - identifier = str(uuid4()) - response_queue = await ProxyQueue.create_for_identifier(identifier) - - try: - - request_queue = await ProxyQueue.get_for_identifier(subdomain) - - logger.debug('PROXY %s%s ', subdomain, request.url) - await request_queue.enqueue( - HttpMessage( - type=HttpMessageType.request.value, - identifier=identifier, - payload= - HttpRequestData( - method=request.method, - path=str(request.url).replace(str(request.base_url), '/'), - headers=list(request.headers.items()), - body=b64encode(await request.body()).decode() - ) +async def proxy(request: Request) -> Response: + [subdomain, *_] = request.headers['host'].split('.') + identifier = str(uuid4()) + response_queue = await ProxyQueue.create_for_identifier(identifier) + + try: + request_queue = await ProxyQueue.get_for_identifier(subdomain) + + logger.debug('PROXY %s%s ', subdomain, request.url) + await request_queue.enqueue( + HttpMessage( + type=HttpMessageType.request.value, + identifier=identifier, + payload=HttpRequestData( + method=request.method, + path=str(request.url).replace(str(request.base_url), '/'), + headers=list(request.headers.items()), + body=b64encode(await request.body()).decode() ) ) - - _response = await response_queue.dequeue() - payload = _response['payload'] - response = Response( - status_code=payload['status'], - headers=HeaderMapping(payload['headers']), - content=b64decode(payload['body'].encode()) - ) - except AssertionError: - pass - finally: - await response(self.scope, self.receive, self.send) - await response_queue.delete() - - -class Health(HTTPEndpoint): - async def get(self, _) -> None: - response = Response(content='OK', status_code=200) - - await response(self.scope, self.receive, self.send) - - + ) + + _response = await response_queue.dequeue() + payload = _response['payload'] + return Response( + status_code=payload['status'], + headers=dict(payload['headers']), + content=b64decode(payload['body'].encode()), + background=BackgroundTask(response_queue.delete) + ) + except AssertionError: + return Response( + content='Not Found', + status_code=404, + background=BackgroundTask(response_queue.delete) + ) + + +async def health(_: Request) -> Response: + return Response(content='OK', status_code=200) -- cgit v1.2.3