diff options
| author | 2024-08-30 15:54:40 +0200 | |
|---|---|---|
| committer | 2024-08-30 15:54:40 +0200 | |
| commit | 0f7c975fe61dab4efb11b49ddc87331c30c26942 (patch) | |
| tree | 68f4b351b337b9a2269ddb2cb512016c93e7cbbc /ttun_server/websockets.py | |
| parent | 53a8f300859a50d9f99f1821c35bca999fced6d8 (diff) | |
| parent | a72a0485ef8761b95c73cc420723247fafbb6f1c (diff) | |
| download | server-main.tar.gz server-main.tar.bz2 server-main.zip | |
Added websocket support
Diffstat (limited to 'ttun_server/websockets.py')
| -rw-r--r-- | ttun_server/websockets.py | 179 |
1 files changed, 179 insertions, 0 deletions
diff --git a/ttun_server/websockets.py b/ttun_server/websockets.py new file mode 100644 index 0000000..0800cbc --- /dev/null +++ b/ttun_server/websockets.py | |||
| @@ -0,0 +1,179 @@ | |||
| 1 | import asyncio | ||
| 2 | import json | ||
| 3 | import logging | ||
| 4 | import os | ||
| 5 | import typing | ||
| 6 | from asyncio import create_task | ||
| 7 | from base64 import b64encode, b64decode | ||
| 8 | from contextlib import asynccontextmanager | ||
| 9 | from typing import Optional | ||
| 10 | from uuid import uuid4 | ||
| 11 | |||
| 12 | from starlette.endpoints import WebSocketEndpoint | ||
| 13 | from starlette.types import Scope, Receive, Send | ||
| 14 | from starlette.websockets import WebSocket | ||
| 15 | |||
| 16 | import ttun_server | ||
| 17 | from ttun_server.proxy_queue import ProxyQueue | ||
| 18 | from ttun_server.types import Config, Message, WebsocketMessageType, \ | ||
| 19 | WebsocketConnectData, WebsocketMessage, WebsocketMessageData, WebsocketDisconnectData, MessageType | ||
| 20 | |||
| 21 | logger = logging.getLogger(__name__) | ||
| 22 | logger.setLevel('DEBUG') | ||
| 23 | |||
| 24 | class WebsocketProxy(WebSocketEndpoint): | ||
| 25 | encoding = 'json' | ||
| 26 | websocket_listen_task = None | ||
| 27 | |||
| 28 | def __init__(self, *args, **kwargs): | ||
| 29 | super().__init__(*args, **kwargs) | ||
| 30 | self.id = str(uuid4()) | ||
| 31 | |||
| 32 | @asynccontextmanager | ||
| 33 | async def proxy(self, websocket: WebSocket, message: WebsocketMessage): | ||
| 34 | [subdomain, *_] = websocket.url.hostname.split('.') | ||
| 35 | |||
| 36 | expect_ack = WebsocketMessageType(message['type']) == WebsocketMessageType.connect | ||
| 37 | |||
| 38 | try: | ||
| 39 | request_queue = await ProxyQueue.get_for_identifier(subdomain) | ||
| 40 | await request_queue.enqueue(message) | ||
| 41 | |||
| 42 | if expect_ack: | ||
| 43 | response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{message["identifier"]}') | ||
| 44 | yield await response_queue.dequeue() | ||
| 45 | await response_queue.delete() | ||
| 46 | else: | ||
| 47 | yield | ||
| 48 | except AssertionError: | ||
| 49 | pass | ||
| 50 | |||
| 51 | async def listen_for_messages(self, websocket: WebSocket): | ||
| 52 | [subdomain, *_] = websocket.url.hostname.split('.') | ||
| 53 | |||
| 54 | print('listen', self.id) | ||
| 55 | response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{self.id}') | ||
| 56 | |||
| 57 | while True: | ||
| 58 | message: WebsocketMessage = await response_queue.dequeue() | ||
| 59 | logger.debug(message) | ||
| 60 | await websocket.send_text(b64decode(message['payload']['body'].encode()).decode()) | ||
| 61 | |||
| 62 | async def on_connect(self, websocket: WebSocket) -> None: | ||
| 63 | message = WebsocketMessage( | ||
| 64 | type=WebsocketMessageType.connect.value, | ||
| 65 | identifier=self.id, | ||
| 66 | payload=WebsocketConnectData( | ||
| 67 | path=websocket.path_params['path'], | ||
| 68 | headers=[ | ||
| 69 | (k.decode(), v.decode()) | ||
| 70 | for k, v | ||
| 71 | in websocket.scope['headers'] | ||
| 72 | ], | ||
| 73 | ) | ||
| 74 | ) | ||
| 75 | |||
| 76 | async with self.proxy(websocket, message) as m: | ||
| 77 | type = WebsocketMessageType(m['type']) | ||
| 78 | |||
| 79 | if type == WebsocketMessageType.ack: | ||
| 80 | await super().on_connect(websocket) | ||
| 81 | |||
| 82 | self.websocket_listen_task = asyncio.create_task(self.listen_for_messages(websocket)) | ||
| 83 | |||
| 84 | def callback(*args, **kwargs): | ||
| 85 | self.websocket_listen_task = None | ||
| 86 | |||
| 87 | self.websocket_listen_task.add_done_callback(callback) | ||
| 88 | |||
| 89 | async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None: | ||
| 90 | match data: | ||
| 91 | case dict(): | ||
| 92 | data_bytes = json.dumps(data).encode() | ||
| 93 | case bytes(): | ||
| 94 | data_bytes = data | ||
| 95 | case _: | ||
| 96 | data_bytes = data.encode() | ||
| 97 | |||
| 98 | message = WebsocketMessage( | ||
| 99 | type=WebsocketMessageType.message.value, | ||
| 100 | identifier=self.id, | ||
| 101 | payload=WebsocketMessageData( | ||
| 102 | body=b64encode(data_bytes).decode(), | ||
| 103 | ) | ||
| 104 | ) | ||
| 105 | |||
| 106 | async with self.proxy(websocket, message): | ||
| 107 | pass | ||
| 108 | |||
| 109 | async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: | ||
| 110 | message = WebsocketMessage( | ||
| 111 | type=WebsocketMessageType.disconnect.value, | ||
| 112 | identifier=self.id, | ||
| 113 | payload=WebsocketDisconnectData( | ||
| 114 | close_code=close_code, | ||
| 115 | ) | ||
| 116 | ) | ||
| 117 | |||
| 118 | async with self.proxy(websocket, message): | ||
| 119 | if self.websocket_listen_task is not None: | ||
| 120 | self.websocket_listen_task.cancel() | ||
| 121 | |||
| 122 | class Tunnel(WebSocketEndpoint): | ||
| 123 | encoding = 'json' | ||
| 124 | |||
| 125 | def __init__(self, scope: Scope, receive: Receive, send: Send): | ||
| 126 | super().__init__(scope, receive, send) | ||
| 127 | self.request_task = None | ||
| 128 | self.config: Optional[Config] = None | ||
| 129 | |||
| 130 | async def handle_requests(self, websocket: WebSocket): | ||
| 131 | while request := await self.proxy_queue.dequeue(): | ||
| 132 | create_task(websocket.send_json(request)) | ||
| 133 | |||
| 134 | async def on_connect(self, websocket: WebSocket) -> None: | ||
| 135 | await websocket.accept() | ||
| 136 | self.config = await websocket.receive_json() | ||
| 137 | |||
| 138 | client_version = self.config.get('version', '1.0.0') | ||
| 139 | logger.debug('client_version %s', client_version) | ||
| 140 | |||
| 141 | if 'git' not in client_version and ttun_server.__version__ != 'development': | ||
| 142 | [client_major, *_] = [int(i) for i in client_version.split('.')[:3]] | ||
| 143 | [server_major, *_] = [int(i) for i in ttun_server.__version__.split('.')] | ||
| 144 | |||
| 145 | if client_major < server_major: | ||
| 146 | await websocket.close(4000, 'Your client is too old') | ||
| 147 | |||
| 148 | if client_major > server_major: | ||
| 149 | await websocket.close(4001, 'Your client is too new') | ||
| 150 | |||
| 151 | |||
| 152 | if self.config['subdomain'] is None \ | ||
| 153 | or await ProxyQueue.has_connection(self.config['subdomain']): | ||
| 154 | self.config['subdomain'] = uuid4().hex | ||
| 155 | |||
| 156 | |||
| 157 | self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain']) | ||
| 158 | |||
| 159 | hostname = os.environ.get("TUNNEL_DOMAIN") | ||
| 160 | protocol = "https" if os.environ.get("SECURE", False) else "http" | ||
| 161 | |||
| 162 | await websocket.send_json({ | ||
| 163 | 'url': f'{protocol}://{self.config["subdomain"]}.{hostname}' | ||
| 164 | }) | ||
| 165 | |||
| 166 | self.request_task = asyncio.create_task(self.handle_requests(websocket)) | ||
| 167 | |||
| 168 | async def on_receive(self, websocket: WebSocket, data: Message): | ||
| 169 | try: | ||
| 170 | response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}") | ||
| 171 | await response_queue.enqueue(data) | ||
| 172 | except AssertionError: | ||
| 173 | pass | ||
| 174 | |||
| 175 | async def on_disconnect(self, websocket: WebSocket, close_code: int): | ||
| 176 | await self.proxy_queue.delete() | ||
| 177 | |||
| 178 | if self.request_task is not None: | ||
| 179 | self.request_task.cancel() | ||
