summaryrefslogtreecommitdiffstats
path: root/ttun_server/endpoints.py
diff options
context:
space:
mode:
authorGravatar Tom van der Lee <tom@vanderlee.io>2024-03-20 21:48:45 +0100
committerGravatar Tom van der Lee <tom@vanderlee.io>2024-03-20 21:48:45 +0100
commit486087cdb349dbc07b479d2286a02bdca310ea38 (patch)
tree43f7baed542a1bc819884ddeebc6e15dfcbd42b0 /ttun_server/endpoints.py
parent53a8f300859a50d9f99f1821c35bca999fced6d8 (diff)
downloadserver-486087cdb349dbc07b479d2286a02bdca310ea38.tar.gz
server-486087cdb349dbc07b479d2286a02bdca310ea38.tar.bz2
server-486087cdb349dbc07b479d2286a02bdca310ea38.zip
Added websocket support
Diffstat (limited to 'ttun_server/endpoints.py')
-rw-r--r--ttun_server/endpoints.py75
1 files changed, 5 insertions, 70 deletions
diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py
index 3e263da..eae0ebe 100644
--- a/ttun_server/endpoints.py
+++ b/ttun_server/endpoints.py
@@ -1,20 +1,13 @@
1import asyncio
2import logging 1import logging
3import os
4from asyncio import create_task
5from base64 import b64decode, b64encode 2from base64 import b64decode, b64encode
6from typing import Optional, Any
7from uuid import uuid4 3from uuid import uuid4
8 4
9from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint 5from starlette.endpoints import HTTPEndpoint
10from starlette.requests import Request 6from starlette.requests import Request
11from starlette.responses import Response 7from starlette.responses import Response
12from starlette.types import Scope, Receive, Send
13from starlette.websockets import WebSocket
14 8
15import ttun_server
16from ttun_server.proxy_queue import ProxyQueue 9from ttun_server.proxy_queue import ProxyQueue
17from ttun_server.types import RequestData, Config, Message, MessageType 10from ttun_server.types import HttpRequestData, Message, HttpMessageType, HttpMessage
18 11
19logger = logging.getLogger(__name__) 12logger = logging.getLogger(__name__)
20 13
@@ -44,11 +37,11 @@ class Proxy(HTTPEndpoint):
44 37
45 logger.debug('PROXY %s%s ', subdomain, request.url) 38 logger.debug('PROXY %s%s ', subdomain, request.url)
46 await request_queue.enqueue( 39 await request_queue.enqueue(
47 Message( 40 HttpMessage(
48 type=MessageType.request.value, 41 type=HttpMessageType.request.value,
49 identifier=identifier, 42 identifier=identifier,
50 payload= 43 payload=
51 RequestData( 44 HttpRequestData(
52 method=request.method, 45 method=request.method,
53 path=str(request.url).replace(str(request.base_url), '/'), 46 path=str(request.url).replace(str(request.base_url), '/'),
54 headers=list(request.headers.items()), 47 headers=list(request.headers.items()),
@@ -78,61 +71,3 @@ class Health(HTTPEndpoint):
78 await response(self.scope, self.receive, self.send) 71 await response(self.scope, self.receive, self.send)
79 72
80 73
81class Tunnel(WebSocketEndpoint):
82 encoding = 'json'
83
84 def __init__(self, scope: Scope, receive: Receive, send: Send):
85 super().__init__(scope, receive, send)
86 self.request_task = None
87 self.config: Optional[Config] = None
88
89 async def handle_requests(self, websocket: WebSocket):
90 while request := await self.proxy_queue.dequeue():
91 create_task(websocket.send_json(request))
92
93 async def on_connect(self, websocket: WebSocket) -> None:
94 await websocket.accept()
95 self.config = await websocket.receive_json()
96
97 client_version = self.config.get('version', '1.0.0')
98 logger.debug('client_version %s', client_version)
99
100 if 'git' not in client_version and ttun_server.__version__ != 'development':
101 [client_major, *_] = [int(i) for i in client_version.split('.')[:3]]
102 [server_major, *_] = [int(i) for i in ttun_server.__version__.split('.')]
103
104 if client_major < server_major:
105 await websocket.close(4000, 'Your client is too old')
106
107 if client_major > server_major:
108 await websocket.close(4001, 'Your client is too new')
109
110
111 if self.config['subdomain'] is None \
112 or await ProxyQueue.has_connection(self.config['subdomain']):
113 self.config['subdomain'] = uuid4().hex
114
115
116 self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain'])
117
118 hostname = os.environ.get("TUNNEL_DOMAIN")
119 protocol = "https" if os.environ.get("SECURE", False) else "http"
120
121 await websocket.send_json({
122 'url': f'{protocol}://{self.config["subdomain"]}.{hostname}'
123 })
124
125 self.request_task = asyncio.create_task(self.handle_requests(websocket))
126
127 async def on_receive(self, websocket: WebSocket, data: Message):
128 try:
129 response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}")
130 await response_queue.enqueue(data)
131 except AssertionError:
132 pass
133
134 async def on_disconnect(self, websocket: WebSocket, close_code: int):
135 await self.proxy_queue.delete()
136
137 if self.request_task is not None:
138 self.request_task.cancel()