summaryrefslogtreecommitdiffstats
path: root/ttun_server/websockets.py
diff options
context:
space:
mode:
authorGravatar Tom van der Lee <tomvanderlee@users.noreply.github.com>2024-08-30 15:54:40 +0200
committerGravatar GitHub <noreply@github.com>2024-08-30 15:54:40 +0200
commit0f7c975fe61dab4efb11b49ddc87331c30c26942 (patch)
tree68f4b351b337b9a2269ddb2cb512016c93e7cbbc /ttun_server/websockets.py
parent53a8f300859a50d9f99f1821c35bca999fced6d8 (diff)
parenta72a0485ef8761b95c73cc420723247fafbb6f1c (diff)
downloadserver-0f7c975fe61dab4efb11b49ddc87331c30c26942.tar.gz
server-0f7c975fe61dab4efb11b49ddc87331c30c26942.tar.bz2
server-0f7c975fe61dab4efb11b49ddc87331c30c26942.zip
Merge pull request #6 from tomvanderlee/feature/websocketsv2.1.0main
Added websocket support
Diffstat (limited to 'ttun_server/websockets.py')
-rw-r--r--ttun_server/websockets.py179
1 files changed, 179 insertions, 0 deletions
diff --git a/ttun_server/websockets.py b/ttun_server/websockets.py
new file mode 100644
index 0000000..0800cbc
--- /dev/null
+++ b/ttun_server/websockets.py
@@ -0,0 +1,179 @@
1import asyncio
2import json
3import logging
4import os
5import typing
6from asyncio import create_task
7from base64 import b64encode, b64decode
8from contextlib import asynccontextmanager
9from typing import Optional
10from uuid import uuid4
11
12from starlette.endpoints import WebSocketEndpoint
13from starlette.types import Scope, Receive, Send
14from starlette.websockets import WebSocket
15
16import ttun_server
17from ttun_server.proxy_queue import ProxyQueue
18from ttun_server.types import Config, Message, WebsocketMessageType, \
19 WebsocketConnectData, WebsocketMessage, WebsocketMessageData, WebsocketDisconnectData, MessageType
20
21logger = logging.getLogger(__name__)
22logger.setLevel('DEBUG')
23
24class WebsocketProxy(WebSocketEndpoint):
25 encoding = 'json'
26 websocket_listen_task = None
27
28 def __init__(self, *args, **kwargs):
29 super().__init__(*args, **kwargs)
30 self.id = str(uuid4())
31
32 @asynccontextmanager
33 async def proxy(self, websocket: WebSocket, message: WebsocketMessage):
34 [subdomain, *_] = websocket.url.hostname.split('.')
35
36 expect_ack = WebsocketMessageType(message['type']) == WebsocketMessageType.connect
37
38 try:
39 request_queue = await ProxyQueue.get_for_identifier(subdomain)
40 await request_queue.enqueue(message)
41
42 if expect_ack:
43 response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{message["identifier"]}')
44 yield await response_queue.dequeue()
45 await response_queue.delete()
46 else:
47 yield
48 except AssertionError:
49 pass
50
51 async def listen_for_messages(self, websocket: WebSocket):
52 [subdomain, *_] = websocket.url.hostname.split('.')
53
54 print('listen', self.id)
55 response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{self.id}')
56
57 while True:
58 message: WebsocketMessage = await response_queue.dequeue()
59 logger.debug(message)
60 await websocket.send_text(b64decode(message['payload']['body'].encode()).decode())
61
62 async def on_connect(self, websocket: WebSocket) -> None:
63 message = WebsocketMessage(
64 type=WebsocketMessageType.connect.value,
65 identifier=self.id,
66 payload=WebsocketConnectData(
67 path=websocket.path_params['path'],
68 headers=[
69 (k.decode(), v.decode())
70 for k, v
71 in websocket.scope['headers']
72 ],
73 )
74 )
75
76 async with self.proxy(websocket, message) as m:
77 type = WebsocketMessageType(m['type'])
78
79 if type == WebsocketMessageType.ack:
80 await super().on_connect(websocket)
81
82 self.websocket_listen_task = asyncio.create_task(self.listen_for_messages(websocket))
83
84 def callback(*args, **kwargs):
85 self.websocket_listen_task = None
86
87 self.websocket_listen_task.add_done_callback(callback)
88
89 async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
90 match data:
91 case dict():
92 data_bytes = json.dumps(data).encode()
93 case bytes():
94 data_bytes = data
95 case _:
96 data_bytes = data.encode()
97
98 message = WebsocketMessage(
99 type=WebsocketMessageType.message.value,
100 identifier=self.id,
101 payload=WebsocketMessageData(
102 body=b64encode(data_bytes).decode(),
103 )
104 )
105
106 async with self.proxy(websocket, message):
107 pass
108
109 async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
110 message = WebsocketMessage(
111 type=WebsocketMessageType.disconnect.value,
112 identifier=self.id,
113 payload=WebsocketDisconnectData(
114 close_code=close_code,
115 )
116 )
117
118 async with self.proxy(websocket, message):
119 if self.websocket_listen_task is not None:
120 self.websocket_listen_task.cancel()
121
122class Tunnel(WebSocketEndpoint):
123 encoding = 'json'
124
125 def __init__(self, scope: Scope, receive: Receive, send: Send):
126 super().__init__(scope, receive, send)
127 self.request_task = None
128 self.config: Optional[Config] = None
129
130 async def handle_requests(self, websocket: WebSocket):
131 while request := await self.proxy_queue.dequeue():
132 create_task(websocket.send_json(request))
133
134 async def on_connect(self, websocket: WebSocket) -> None:
135 await websocket.accept()
136 self.config = await websocket.receive_json()
137
138 client_version = self.config.get('version', '1.0.0')
139 logger.debug('client_version %s', client_version)
140
141 if 'git' not in client_version and ttun_server.__version__ != 'development':
142 [client_major, *_] = [int(i) for i in client_version.split('.')[:3]]
143 [server_major, *_] = [int(i) for i in ttun_server.__version__.split('.')]
144
145 if client_major < server_major:
146 await websocket.close(4000, 'Your client is too old')
147
148 if client_major > server_major:
149 await websocket.close(4001, 'Your client is too new')
150
151
152 if self.config['subdomain'] is None \
153 or await ProxyQueue.has_connection(self.config['subdomain']):
154 self.config['subdomain'] = uuid4().hex
155
156
157 self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain'])
158
159 hostname = os.environ.get("TUNNEL_DOMAIN")
160 protocol = "https" if os.environ.get("SECURE", False) else "http"
161
162 await websocket.send_json({
163 'url': f'{protocol}://{self.config["subdomain"]}.{hostname}'
164 })
165
166 self.request_task = asyncio.create_task(self.handle_requests(websocket))
167
168 async def on_receive(self, websocket: WebSocket, data: Message):
169 try:
170 response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}")
171 await response_queue.enqueue(data)
172 except AssertionError:
173 pass
174
175 async def on_disconnect(self, websocket: WebSocket, close_code: int):
176 await self.proxy_queue.delete()
177
178 if self.request_task is not None:
179 self.request_task.cancel()