summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Tom van der Lee <tom@vanderlee.io>2024-03-20 21:47:55 +0100
committerGravatar Tom van der Lee <tom@vanderlee.io>2024-03-20 21:47:55 +0100
commit9a55068e5de5da19e9c3d77455b8c25f8327f896 (patch)
tree10fd45b1a42ca1ac8d0473be4545216e6e41089c
parentf183536067dc694f37445148c15821f1621f5034 (diff)
downloadclient-9a55068e5de5da19e9c3d77455b8c25f8327f896.tar.gz
client-9a55068e5de5da19e9c3d77455b8c25f8327f896.tar.bz2
client-9a55068e5de5da19e9c3d77455b8c25f8327f896.zip
Added websocket support
-rw-r--r--ttun/__main__.py11
-rw-r--r--ttun/client.py216
-rw-r--r--ttun/inspect_server.py5
-rw-r--r--ttun/pubsub.py1
-rw-r--r--ttun/types.py60
5 files changed, 234 insertions, 59 deletions
diff --git a/ttun/__main__.py b/ttun/__main__.py
index 4e693fb..dd83e53 100644
--- a/ttun/__main__.py
+++ b/ttun/__main__.py
@@ -1,21 +1,23 @@
1import asyncio 1import asyncio
2import logging
3import os
2import re 4import re
3import time
4from argparse import ArgumentDefaultsHelpFormatter 5from argparse import ArgumentDefaultsHelpFormatter
5from argparse import ArgumentParser 6from argparse import ArgumentParser
6from asyncio import FIRST_EXCEPTION 7from asyncio import FIRST_EXCEPTION
7from asyncio.exceptions import CancelledError 8from asyncio.exceptions import CancelledError
8from asyncio.exceptions import TimeoutError 9from asyncio.exceptions import TimeoutError
9from typing import Dict
10from typing import Tuple 10from typing import Tuple
11 11
12from websockets.exceptions import ConnectionClosedError
13
14from ttun.client import Client 12from ttun.client import Client
15from ttun.inspect_server import Server 13from ttun.inspect_server import Server
16from ttun.settings import SERVER_HOSTNAME 14from ttun.settings import SERVER_HOSTNAME
17from ttun.settings import SERVER_USING_SSL 15from ttun.settings import SERVER_USING_SSL
18 16
17logging.basicConfig(encoding="utf-8")
18logging.getLogger("asyncio").setLevel(os.environ.get("LOGGING_LEVEL", "NOTSET"))
19logging.getLogger("websockets").setLevel(os.environ.get("LOGGING_LEVEL", "NOTSET"))
20
19inspect_queue = asyncio.Queue() 21inspect_queue = asyncio.Queue()
20 22
21 23
@@ -73,7 +75,6 @@ def main():
73 headers=args.header, 75 headers=args.header,
74 ) 76 )
75 77
76
77 try: 78 try:
78 loop = asyncio.get_running_loop() 79 loop = asyncio.get_running_loop()
79 except RuntimeError: 80 except RuntimeError:
diff --git a/ttun/client.py b/ttun/client.py
index a75c882..b19bb47 100644
--- a/ttun/client.py
+++ b/ttun/client.py
@@ -1,5 +1,8 @@
1import asyncio 1import asyncio
2import json 2import json
3import logging
4import os
5import sys
3from asyncio import get_running_loop 6from asyncio import get_running_loop
4from base64 import b64decode 7from base64 import b64decode
5from base64 import b64encode 8from base64 import b64encode
@@ -24,10 +27,18 @@ from websockets.exceptions import ConnectionClosed
24from ttun import __version__ 27from ttun import __version__
25from ttun.pubsub import PubSub 28from ttun.pubsub import PubSub
26from ttun.types import Config 29from ttun.types import Config
30from ttun.types import HttpMessage
31from ttun.types import HttpMessageType
32from ttun.types import HttpRequestData
33from ttun.types import HttpResponseData
27from ttun.types import Message 34from ttun.types import Message
28from ttun.types import MessageType 35from ttun.types import MessageType
29from ttun.types import RequestData 36from ttun.types import WebsocketMessage
30from ttun.types import ResponseData 37from ttun.types import WebsocketMessageData
38from ttun.types import WebsocketMessageType
39
40logger = logging.getLogger(__name__)
41logger.setLevel(os.environ.get("LOGGING_LEVEL", "NOTSET"))
31 42
32 43
33class Client: 44class Client:
@@ -48,9 +59,12 @@ class Client:
48 self.connection: WebSocketClientProtocol = None 59 self.connection: WebSocketClientProtocol = None
49 60
50 self.proxy_origin = f'{"https" if https else "http"}://{to}:{port}' 61 self.proxy_origin = f'{"https" if https else "http"}://{to}:{port}'
62 self.ws_proxy_origin = f'{"wss" if https else "ws"}://{to}:{port}'
51 63
52 self.headers = [] if headers is None else headers 64 self.headers = [] if headers is None else headers
53 65
66 self.websocket_connections = {}
67
54 async def send(self, data: dict): 68 async def send(self, data: dict):
55 await self.connection.send(json.dumps(data)) 69 await self.connection.send(json.dumps(data))
56 70
@@ -86,63 +100,181 @@ class Client:
86 def session(self): 100 def session(self):
87 return ClientSession(base_url=self.proxy_origin, cookie_jar=DummyCookieJar()) 101 return ClientSession(base_url=self.proxy_origin, cookie_jar=DummyCookieJar())
88 102
103 async def handle_request(self, message: HttpMessage, session: ClientSession = None):
104 if session is None:
105 session = self.session()
106
107 request: HttpRequestData = message["payload"]
108
109 request["headers"] = [
110 *request["headers"],
111 *self.headers,
112 ]
113
114 async def response_handler(
115 response: HttpResponseData, identifier=message["identifier"]
116 ):
117 await self.send(
118 HttpMessage(
119 type=HttpMessageType.response.value,
120 identifier=identifier,
121 payload=response,
122 )
123 )
124
125 await self.proxy_request(
126 session=session,
127 request=request,
128 on_response=response_handler,
129 )
130
131 async def receive_websocket_message(self, message: str, idenitfier: str):
132 message_data = WebsocketMessage(
133 identifier=idenitfier,
134 type=WebsocketMessageType.message.value,
135 payload=WebsocketMessageData(body=b64encode(message.encode()).decode()),
136 )
137 await self.send(message_data)
138
139 await PubSub.publish(
140 {
141 "type": "websocket_outbound",
142 "payload": {
143 "id": message_data["identifier"],
144 "timestamp": datetime.now().isoformat(),
145 **message_data["payload"],
146 },
147 }
148 )
149
150 async def connect_websocket(self, message: WebsocketMessage):
151 assert not message["identifier"] in self.websocket_connections
152
153 start = perf_counter()
154 await PubSub.publish(
155 {
156 "type": "websocket_connect",
157 "payload": {
158 "id": message["identifier"],
159 "timestamp": datetime.now().isoformat(),
160 **message["payload"],
161 },
162 }
163 )
164
165 async with websockets.connect(
166 f'{self.ws_proxy_origin}/{message["payload"]["path"]}'
167 ) as connection:
168 end = perf_counter()
169 self.websocket_connections[message["identifier"]] = connection
170
171 await self.send(
172 WebsocketMessage(
173 identifier=message["identifier"],
174 type=WebsocketMessageType.ack.value,
175 payload=None,
176 )
177 )
178
179 await PubSub.publish(
180 {
181 "type": "websocket_connected",
182 "payload": {
183 "id": message["identifier"],
184 "timing": end - start,
185 },
186 }
187 )
188
189 async for m in connection:
190 await self.receive_websocket_message(m, message["identifier"])
191
192 async def send_websocket_message(self, message: WebsocketMessage):
193 assert message["identifier"] in self.websocket_connections
194 await self.websocket_connections[message["identifier"]].send(
195 b64decode(message["payload"]["body"]).decode()
196 )
197
198 await PubSub.publish(
199 {
200 "type": "websocket_inbound",
201 "payload": {
202 "id": message["identifier"],
203 "timestamp": datetime.now().isoformat(),
204 **message["payload"],
205 },
206 }
207 )
208
209 async def disconnect_websocket(self, message: WebsocketMessage):
210 assert message["identifier"] in self.websocket_connections
211
212 await self.websocket_connections[message["identifier"]].close()
213
214 self.websocket_connections[message["identifier"]] = None
215 await PubSub.publish(
216 {
217 "type": "websocket_disconnect",
218 "payload": {
219 "id": message["identifier"],
220 "timestamp": datetime.now().isoformat(),
221 **message["payload"],
222 },
223 }
224 )
225
89 async def handle_messages(self): 226 async def handle_messages(self):
90 loop = get_running_loop() 227 loop = get_running_loop()
228 tasks = set()
229
91 async with self.session() as session: 230 async with self.session() as session:
92 while True: 231 while True:
93 try: 232 try:
94 message: Message = await self.receive() 233 message: Message = await self.receive()
234 logger.debug(message)
95 235
96 try: 236 match MessageType(message["type"