diff options
| author | 2024-03-20 21:47:55 +0100 | |
|---|---|---|
| committer | 2024-03-20 21:47:55 +0100 | |
| commit | 9a55068e5de5da19e9c3d77455b8c25f8327f896 (patch) | |
| tree | 10fd45b1a42ca1ac8d0473be4545216e6e41089c | |
| parent | f183536067dc694f37445148c15821f1621f5034 (diff) | |
| download | client-9a55068e5de5da19e9c3d77455b8c25f8327f896.tar.gz client-9a55068e5de5da19e9c3d77455b8c25f8327f896.tar.bz2 client-9a55068e5de5da19e9c3d77455b8c25f8327f896.zip | |
Added websocket support
| -rw-r--r-- | ttun/__main__.py | 11 | ||||
| -rw-r--r-- | ttun/client.py | 216 | ||||
| -rw-r--r-- | ttun/inspect_server.py | 5 | ||||
| -rw-r--r-- | ttun/pubsub.py | 1 | ||||
| -rw-r--r-- | ttun/types.py | 60 |
5 files changed, 234 insertions, 59 deletions
diff --git a/ttun/__main__.py b/ttun/__main__.py index 4e693fb..dd83e53 100644 --- a/ttun/__main__.py +++ b/ttun/__main__.py | |||
| @@ -1,21 +1,23 @@ | |||
| 1 | import asyncio | 1 | import asyncio |
| 2 | import logging | ||
| 3 | import os | ||
| 2 | import re | 4 | import re |
| 3 | import time | ||
| 4 | from argparse import ArgumentDefaultsHelpFormatter | 5 | from argparse import ArgumentDefaultsHelpFormatter |
| 5 | from argparse import ArgumentParser | 6 | from argparse import ArgumentParser |
| 6 | from asyncio import FIRST_EXCEPTION | 7 | from asyncio import FIRST_EXCEPTION |
| 7 | from asyncio.exceptions import CancelledError | 8 | from asyncio.exceptions import CancelledError |
| 8 | from asyncio.exceptions import TimeoutError | 9 | from asyncio.exceptions import TimeoutError |
| 9 | from typing import Dict | ||
| 10 | from typing import Tuple | 10 | from typing import Tuple |
| 11 | 11 | ||
| 12 | from websockets.exceptions import ConnectionClosedError | ||
| 13 | |||
| 14 | from ttun.client import Client | 12 | from ttun.client import Client |
| 15 | from ttun.inspect_server import Server | 13 | from ttun.inspect_server import Server |
| 16 | from ttun.settings import SERVER_HOSTNAME | 14 | from ttun.settings import SERVER_HOSTNAME |
| 17 | from ttun.settings import SERVER_USING_SSL | 15 | from ttun.settings import SERVER_USING_SSL |
| 18 | 16 | ||
| 17 | logging.basicConfig(encoding="utf-8") | ||
| 18 | logging.getLogger("asyncio").setLevel(os.environ.get("LOGGING_LEVEL", "NOTSET")) | ||
| 19 | logging.getLogger("websockets").setLevel(os.environ.get("LOGGING_LEVEL", "NOTSET")) | ||
| 20 | |||
| 19 | inspect_queue = asyncio.Queue() | 21 | inspect_queue = asyncio.Queue() |
| 20 | 22 | ||
| 21 | 23 | ||
| @@ -73,7 +75,6 @@ def main(): | |||
| 73 | headers=args.header, | 75 | headers=args.header, |
| 74 | ) | 76 | ) |
| 75 | 77 | ||
| 76 | |||
| 77 | try: | 78 | try: |
| 78 | loop = asyncio.get_running_loop() | 79 | loop = asyncio.get_running_loop() |
| 79 | except RuntimeError: | 80 | except RuntimeError: |
diff --git a/ttun/client.py b/ttun/client.py index a75c882..b19bb47 100644 --- a/ttun/client.py +++ b/ttun/client.py | |||
| @@ -1,5 +1,8 @@ | |||
| 1 | import asyncio | 1 | import asyncio |
| 2 | import json | 2 | import json |
| 3 | import logging | ||
| 4 | import os | ||
| 5 | import sys | ||
| 3 | from asyncio import get_running_loop | 6 | from asyncio import get_running_loop |
| 4 | from base64 import b64decode | 7 | from base64 import b64decode |
| 5 | from base64 import b64encode | 8 | from base64 import b64encode |
| @@ -24,10 +27,18 @@ from websockets.exceptions import ConnectionClosed | |||
| 24 | from ttun import __version__ | 27 | from ttun import __version__ |
| 25 | from ttun.pubsub import PubSub | 28 | from ttun.pubsub import PubSub |
| 26 | from ttun.types import Config | 29 | from ttun.types import Config |
| 30 | from ttun.types import HttpMessage | ||
| 31 | from ttun.types import HttpMessageType | ||
| 32 | from ttun.types import HttpRequestData | ||
| 33 | from ttun.types import HttpResponseData | ||
| 27 | from ttun.types import Message | 34 | from ttun.types import Message |
| 28 | from ttun.types import MessageType | 35 | from ttun.types import MessageType |
| 29 | from ttun.types import RequestData | 36 | from ttun.types import WebsocketMessage |
| 30 | from ttun.types import ResponseData | 37 | from ttun.types import WebsocketMessageData |
| 38 | from ttun.types import WebsocketMessageType | ||
| 39 | |||
| 40 | logger = logging.getLogger(__name__) | ||
| 41 | logger.setLevel(os.environ.get("LOGGING_LEVEL", "NOTSET")) | ||
| 31 | 42 | ||
| 32 | 43 | ||
| 33 | class Client: | 44 | class Client: |
| @@ -48,9 +59,12 @@ class Client: | |||
| 48 | self.connection: WebSocketClientProtocol = None | 59 | self.connection: WebSocketClientProtocol = None |
| 49 | 60 | ||
| 50 | self.proxy_origin = f'{"https" if https else "http"}://{to}:{port}' | 61 | self.proxy_origin = f'{"https" if https else "http"}://{to}:{port}' |
| 62 | self.ws_proxy_origin = f'{"wss" if https else "ws"}://{to}:{port}' | ||
| 51 | 63 | ||
| 52 | self.headers = [] if headers is None else headers | 64 | self.headers = [] if headers is None else headers |
| 53 | 65 | ||
| 66 | self.websocket_connections = {} | ||
| 67 | |||
| 54 | async def send(self, data: dict): | 68 | async def send(self, data: dict): |
| 55 | await self.connection.send(json.dumps(data)) | 69 | await self.connection.send(json.dumps(data)) |
| 56 | 70 | ||
| @@ -86,63 +100,181 @@ class Client: | |||
| 86 | def session(self): | 100 | def session(self): |
| 87 | return ClientSession(base_url=self.proxy_origin, cookie_jar=DummyCookieJar()) | 101 | return ClientSession(base_url=self.proxy_origin, cookie_jar=DummyCookieJar()) |
| 88 | 102 | ||
| 103 | async def handle_request(self, message: HttpMessage, session: ClientSession = None): | ||
| 104 | if session is None: | ||
| 105 | session = self.session() | ||
| 106 | |||
| 107 | request: HttpRequestData = message["payload"] | ||
| 108 | |||
| 109 | request["headers"] = [ | ||
| 110 | *request["headers"], | ||
| 111 | *self.headers, | ||
| 112 | ] | ||
| 113 | |||
| 114 | async def response_handler( | ||
| 115 | response: HttpResponseData, identifier=message["identifier"] | ||
| 116 | ): | ||
| 117 | await self.send( | ||
| 118 | HttpMessage( | ||
| 119 | type=HttpMessageType.response.value, | ||
| 120 | identifier=identifier, | ||
| 121 | payload=response, | ||
| 122 | ) | ||
| 123 | ) | ||
| 124 | |||
| 125 | await self.proxy_request( | ||
| 126 | session=session, | ||
| 127 | request=request, | ||
| 128 | on_response=response_handler, | ||
| 129 | ) | ||
| 130 | |||
| 131 | async def receive_websocket_message(self, message: str, idenitfier: str): | ||
| 132 | message_data = WebsocketMessage( | ||
| 133 | identifier=idenitfier, | ||
| 134 | type=WebsocketMessageType.message.value, | ||
| 135 | payload=WebsocketMessageData(body=b64encode(message.encode()).decode()), | ||
| 136 | ) | ||
| 137 | await self.send(message_data) | ||
| 138 | |||
| 139 | await PubSub.publish( | ||
| 140 | { | ||
| 141 | "type": "websocket_outbound", | ||
| 142 | "payload": { | ||
| 143 | "id": message_data["identifier"], | ||
| 144 | "timestamp": datetime.now().isoformat(), | ||
| 145 | **message_data["payload"], | ||
| 146 | }, | ||
| 147 | } | ||
| 148 | ) | ||
| 149 | |||
| 150 | async def connect_websocket(self, message: WebsocketMessage): | ||
| 151 | assert not message["identifier"] in self.websocket_connections | ||
| 152 | |||
| 153 | start = perf_counter() | ||
| 154 | await PubSub.publish( | ||
| 155 | { | ||
| 156 | "type": "websocket_connect", | ||
| 157 | "payload": { | ||
| 158 | "id": message["identifier"], | ||
| 159 | "timestamp": datetime.now().isoformat(), | ||
| 160 | **message["payload"], | ||
| 161 | }, | ||
| 162 | } | ||
| 163 | ) | ||
| 164 | |||
| 165 | async with websockets.connect( | ||
| 166 | f'{self.ws_proxy_origin}/{message["payload"]["path"]}' | ||
| 167 | ) as connection: | ||
| 168 | end = perf_counter() | ||
| 169 | self.websocket_connections[message["identifier"]] = connection | ||
| 170 | |||
| 171 | await self.send( | ||
| 172 | WebsocketMessage( | ||
| 173 | identifier=message["identifier"], | ||
| 174 | type=WebsocketMessageType.ack.value, | ||
| 175 | payload=None, | ||
| 176 | ) | ||
| 177 | ) | ||
| 178 | |||
| 179 | await PubSub.publish( | ||
| 180 | { | ||
| 181 | "type": "websocket_connected", | ||
| 182 | "payload": { | ||
| 183 | "id": message["identifier"], | ||
| 184 | "timing": end - start, | ||
| 185 | }, | ||
| 186 | } | ||
| 187 | ) | ||
| 188 | |||
| 189 | async for m in connection: | ||
| 190 | await self.receive_websocket_message(m, message["identifier"]) | ||
| 191 | |||
| 192 | async def send_websocket_message(self, message: WebsocketMessage): | ||
| 193 | assert message["identifier"] in self.websocket_connections | ||
| 194 | await self.websocket_connections[message["identifier"]].send( | ||
| 195 | b64decode(message["payload"]["body"]).decode() | ||
| 196 | ) | ||
| 197 | |||
| 198 | await PubSub.publish( | ||
| 199 | { | ||
| 200 | "type": "websocket_inbound", | ||
| 201 | "payload": { | ||
| 202 | "id": message["identifier"], | ||
| 203 | "timestamp": datetime.now().isoformat(), | ||
| 204 | **message["payload"], | ||
| 205 | }, | ||
| 206 | } | ||
| 207 | ) | ||
| 208 | |||
| 209 | async def disconnect_websocket(self, message: WebsocketMessage): | ||
| 210 | assert message["identifier"] in self.websocket_connections | ||
| 211 | |||
| 212 | await self.websocket_connections[message["identifier"]].close() | ||
| 213 | |||
| 214 | self.websocket_connections[message["identifier"]] = None | ||
| 215 | await PubSub.publish( | ||
| 216 | { | ||
| 217 | "type": "websocket_disconnect", | ||
| 218 | "payload": { | ||
| 219 | "id": message["identifier"], | ||
| 220 | "timestamp": datetime.now().isoformat(), | ||
| 221 | **message["payload"], | ||
| 222 | }, | ||
| 223 | } | ||
| 224 | ) | ||
| 225 | |||
| 89 | async def handle_messages(self): | 226 | async def handle_messages(self): |
| 90 | loop = get_running_loop() | 227 | loop = get_running_loop() |
| 228 | tasks = set() | ||
| 229 | |||
| 91 | async with self.session() as session: | 230 | async with self.session() as session: |
| 92 | while True: | 231 | while True: |
| 93 | try: | 232 | try: |
| 94 | message: Message = await self.receive() | 233 | message: Message = await self.receive() |
| 234 | logger.debug(message) | ||
| 95 | 235 | ||
| 96 | try: | 236 | ma |
