From 9a55068e5de5da19e9c3d77455b8c25f8327f896 Mon Sep 17 00:00:00 2001 From: Tom van der Lee Date: Wed, 20 Mar 2024 21:47:55 +0100 Subject: Added websocket support --- ttun/__main__.py | 11 +-- ttun/client.py | 216 +++++++++++++++++++++++++++++++++++++++---------- ttun/inspect_server.py | 5 +- ttun/pubsub.py | 1 - ttun/types.py | 60 ++++++++++++-- 5 files changed, 234 insertions(+), 59 deletions(-) diff --git a/ttun/__main__.py b/ttun/__main__.py index 4e693fb..dd83e53 100644 --- a/ttun/__main__.py +++ b/ttun/__main__.py @@ -1,21 +1,23 @@ import asyncio +import logging +import os import re -import time from argparse import ArgumentDefaultsHelpFormatter from argparse import ArgumentParser from asyncio import FIRST_EXCEPTION from asyncio.exceptions import CancelledError from asyncio.exceptions import TimeoutError -from typing import Dict from typing import Tuple -from websockets.exceptions import ConnectionClosedError - from ttun.client import Client from ttun.inspect_server import Server from ttun.settings import SERVER_HOSTNAME from ttun.settings import SERVER_USING_SSL +logging.basicConfig(encoding="utf-8") +logging.getLogger("asyncio").setLevel(os.environ.get("LOGGING_LEVEL", "NOTSET")) +logging.getLogger("websockets").setLevel(os.environ.get("LOGGING_LEVEL", "NOTSET")) + inspect_queue = asyncio.Queue() @@ -73,7 +75,6 @@ def main(): headers=args.header, ) - try: loop = asyncio.get_running_loop() except RuntimeError: diff --git a/ttun/client.py b/ttun/client.py index a75c882..b19bb47 100644 --- a/ttun/client.py +++ b/ttun/client.py @@ -1,5 +1,8 @@ import asyncio import json +import logging +import os +import sys from asyncio import get_running_loop from base64 import b64decode from base64 import b64encode @@ -24,10 +27,18 @@ from websockets.exceptions import ConnectionClosed from ttun import __version__ from ttun.pubsub import PubSub from ttun.types import Config +from ttun.types import HttpMessage +from ttun.types import HttpMessageType +from ttun.types import HttpRequestData +from ttun.types import HttpResponseData from ttun.types import Message from ttun.types import MessageType -from ttun.types import RequestData -from ttun.types import ResponseData +from ttun.types import WebsocketMessage +from ttun.types import WebsocketMessageData +from ttun.types import WebsocketMessageType + +logger = logging.getLogger(__name__) +logger.setLevel(os.environ.get("LOGGING_LEVEL", "NOTSET")) class Client: @@ -48,9 +59,12 @@ class Client: self.connection: WebSocketClientProtocol = None self.proxy_origin = f'{"https" if https else "http"}://{to}:{port}' + self.ws_proxy_origin = f'{"wss" if https else "ws"}://{to}:{port}' self.headers = [] if headers is None else headers + self.websocket_connections = {} + async def send(self, data: dict): await self.connection.send(json.dumps(data)) @@ -86,63 +100,181 @@ class Client: def session(self): return ClientSession(base_url=self.proxy_origin, cookie_jar=DummyCookieJar()) + async def handle_request(self, message: HttpMessage, session: ClientSession = None): + if session is None: + session = self.session() + + request: HttpRequestData = message["payload"] + + request["headers"] = [ + *request["headers"], + *self.headers, + ] + + async def response_handler( + response: HttpResponseData, identifier=message["identifier"] + ): + await self.send( + HttpMessage( + type=HttpMessageType.response.value, + identifier=identifier, + payload=response, + ) + ) + + await self.proxy_request( + session=session, + request=request, + on_response=response_handler, + ) + + async def receive_websocket_message(self, message: str, idenitfier: str): + message_data = WebsocketMessage( + identifier=idenitfier, + type=WebsocketMessageType.message.value, + payload=WebsocketMessageData(body=b64encode(message.encode()).decode()), + ) + await self.send(message_data) + + await PubSub.publish( + { + "type": "websocket_outbound", + "payload": { + "id": message_data["identifier"], + "timestamp": datetime.now().isoformat(), + **message_data["payload"], + }, + } + ) + + async def connect_websocket(self, message: WebsocketMessage): + assert not message["identifier"] in self.websocket_connections + + start = perf_counter() + await PubSub.publish( + { + "type": "websocket_connect", + "payload": { + "id": message["identifier"], + "timestamp": datetime.now().isoformat(), + **message["payload"], + }, + } + ) + + async with websockets.connect( + f'{self.ws_proxy_origin}/{message["payload"]["path"]}' + ) as connection: + end = perf_counter() + self.websocket_connections[message["identifier"]] = connection + + await self.send( + WebsocketMessage( + identifier=message["identifier"], + type=WebsocketMessageType.ack.value, + payload=None, + ) + ) + + await PubSub.publish( + { + "type": "websocket_connected", + "payload": { + "id": message["identifier"], + "timing": end - start, + }, + } + ) + + async for m in connection: + await self.receive_websocket_message(m, message["identifier"]) + + async def send_websocket_message(self, message: WebsocketMessage): + assert message["identifier"] in self.websocket_connections + await self.websocket_connections[message["identifier"]].send( + b64decode(message["payload"]["body"]).decode() + ) + + await PubSub.publish( + { + "type": "websocket_inbound", + "payload": { + "id": message["identifier"], + "timestamp": datetime.now().isoformat(), + **message["payload"], + }, + } + ) + + async def disconnect_websocket(self, message: WebsocketMessage): + assert message["identifier"] in self.websocket_connections + + await self.websocket_connections[message["identifier"]].close() + + self.websocket_connections[message["identifier"]] = None + await PubSub.publish( + { + "type": "websocket_disconnect", + "payload": { + "id": message["identifier"], + "timestamp": datetime.now().isoformat(), + **message["payload"], + }, + } + ) + async def handle_messages(self): loop = get_running_loop() + tasks = set() + async with self.session() as session: while True: try: message: Message = await self.receive() + logger.debug(message) - try: - if MessageType(message["type"]) != MessageType.request: - continue - except ValueError: - continue - - request: RequestData = message["payload"] - - request["headers"] = [ - *request["headers"], - *self.headers, - ] - - async def response_handler( - response: ResponseData, identifier=message["identifier"] - ): - await self.send( - Message( - type=MessageType.response.value, - identifier=identifier, - payload=response, + match MessageType(message["type"]): + case MessageType.request: + task = loop.create_task( + self.handle_request(message, session) ) - ) - - await loop.create_task( - self.proxy_request( - session=session, - request=request, - on_response=response_handler, - ) - ) - except ConnectionClosed: - break + case MessageType.ws_connect: + task = loop.create_task(self.connect_websocket(message)) + case MessageType.ws_message: + task = loop.create_task( + self.send_websocket_message(message) + ) + case MessageType.ws_disconnect: + task = loop.create_task(self.disconnect_websocket(message)) + case _: + logger.debug(message) + + tasks.add(task) + task.add_done_callback(tasks.discard) + except ValueError: + continue + except ConnectionClosed as e: + raise e + + for task in tasks: + task.cancel() - async def resend(self, data: RequestData): + async def resend(self, data: HttpRequestData): async with self.session() as session: await self.proxy_request(session, data) async def proxy_request( self, session: ClientSession, - request: RequestData, - on_response: Callable[[ResponseData], Awaitable] = None, + request: HttpRequestData, + on_response: Callable[[HttpResponseData], Awaitable] = None, ): - request_id = uuid4() + request_id = str(uuid4()) await PubSub.publish( { "type": "request", "payload": { - "id": request_id.hex, + "id": request_id, "timestamp": datetime.now().isoformat(), **request, }, @@ -160,7 +292,7 @@ class Client: ) end = perf_counter() - response_data = ResponseData( + response_data = HttpResponseData( status=response.status, headers=[ (key, value) @@ -173,7 +305,7 @@ class Client: except ClientError as e: end = perf_counter() - response_data = ResponseData( + response_data = HttpResponseData( status=(504 if isinstance(e, ClientConnectionError) else 502), headers=[("content-type", "text/plain")], body=b64encode(str(e).encode()).decode(), @@ -186,7 +318,7 @@ class Client: { "type": "response", "payload": { - "id": request_id.hex, + "id": request_id, "timing": end - start, **response_data, }, diff --git a/ttun/inspect_server.py b/ttun/inspect_server.py index 1cd8809..df274a2 100644 --- a/ttun/inspect_server.py +++ b/ttun/inspect_server.py @@ -2,13 +2,12 @@ from importlib import resources from pathlib import Path from typing import Awaitable from typing import Callable -from typing import Optional from aiohttp import web from ttun.pubsub import PubSub from ttun.types import Config -from ttun.types import RequestData +from ttun.types import HttpRequestData BASE_DIR = Path(__file__).resolve().parent @@ -17,7 +16,7 @@ class Server: def __init__( self, config: Config, - on_resend: Callable[[RequestData], Awaitable], + on_resend: Callable[[HttpRequestData], Awaitable], on_started: Callable[["Server"], None], ): self.port = 4040 diff --git a/ttun/pubsub.py b/ttun/pubsub.py index 0ba964e..85a1988 100644 --- a/ttun/pubsub.py +++ b/ttun/pubsub.py @@ -1,7 +1,6 @@ import asyncio from contextlib import contextmanager from typing import Any -from typing import Generator from typing import Iterator diff --git a/ttun/types.py b/ttun/types.py index 59bfa79..cdda635 100644 --- a/ttun/types.py +++ b/ttun/types.py @@ -1,28 +1,72 @@ from enum import Enum +from itertools import chain from typing import Optional from typing import TypedDict -class MessageType(Enum): - request = 'request' - response = 'response' -class RequestData(TypedDict): +class HttpMessageType(Enum): + request = "request" + response = "response" + + +class HttpRequestData(TypedDict): method: str path: str headers: list[tuple[str, str]] body: Optional[str] -class ResponseData(TypedDict): +class HttpResponseData(TypedDict): status: int headers: list[tuple[str, str]] body: Optional[str] -class Message(TypedDict): - type: MessageType +class HttpMessage(TypedDict): + type: HttpMessageType identifier: str - payload: RequestData | ResponseData + payload: HttpRequestData | HttpResponseData + + +class WebsocketMessageType(Enum): + connect = "connect" + disconnect = "disconnect" + message = "message" + ack = "ack" + + +class WebsocketConnectData(TypedDict): + path: str + headers: list[tuple[str, str]] + + +class WebsocketDisconnectData(TypedDict): + pass + + +class WebsocketMessageData(TypedDict): + body: Optional[str] + + +class WebsocketMessage(TypedDict): + type: WebsocketMessageType + identifier: str + payload: Optional[ + WebsocketConnectData | WebsocketDisconnectData | WebsocketMessageData + ] + + +class MessageType(Enum): + request = "request" + response = "response" + + ws_connect = "connect" + ws_disconnect = "disconnect" + ws_message = "message" + ws_ack = "ack" + + +Message = HttpMessage | WebsocketMessage class Config(TypedDict): -- cgit v1.2.3