summaryrefslogtreecommitdiffstats
path: root/ttun_server/endpoints.py
diff options
context:
space:
mode:
Diffstat (limited to 'ttun_server/endpoints.py')
-rw-r--r--ttun_server/endpoints.py43
1 files changed, 17 insertions, 26 deletions
diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py
index d59cb7c..5b9e57f 100644
--- a/ttun_server/endpoints.py
+++ b/ttun_server/endpoints.py
@@ -1,4 +1,5 @@
1import asyncio 1import asyncio
2import logging
2import os 3import os
3from asyncio import Queue 4from asyncio import Queue
4from base64 import b64decode, b64encode 5from base64 import b64decode, b64encode
@@ -11,9 +12,10 @@ from starlette.responses import Response
11from starlette.types import Scope, Receive, Send 12from starlette.types import Scope, Receive, Send
12from starlette.websockets import WebSocket 13from starlette.websockets import WebSocket
13 14
14from ttun_server.types import Connection, RequestData, Config, ResponseData 15from ttun_server.proxy_queue import ProxyQueue
16from ttun_server.types import RequestData, Config, ResponseData
15 17
16from ttun_server.connections import connections 18logger = logging.getLogger(__name__)
17 19
18 20
19class Proxy(HTTPEndpoint): 21class Proxy(HTTPEndpoint):
@@ -23,10 +25,10 @@ class Proxy(HTTPEndpoint):
23 [subdomain, *_] = request.headers['host'].split('.') 25 [subdomain, *_] = request.headers['host'].split('.')
24 response = Response(content='Not Found', status_code=404) 26 response = Response(content='Not Found', status_code=404)
25 27
26 if subdomain in connections: 28 try:
27 connection = connections[subdomain] 29 queue = await ProxyQueue.get_for_identifier(subdomain)
28 30
29 await connection['requests'].put(RequestData( 31 await queue.send_request(RequestData(
30 method=request.method, 32 method=request.method,
31 path=str(request.url).replace(str(request.base_url), '/'), 33 path=str(request.url).replace(str(request.base_url), '/'),
32 headers=dict(request.headers), 34 headers=dict(request.headers),
@@ -34,12 +36,14 @@ class Proxy(HTTPEndpoint):
34 body=b64encode(await request.body()).decode() 36 body=b64encode(await request.body()).decode()
35 )) 37 ))
36 38
37 _response = await connection['responses'].get() 39 _response = await queue.handle_response()
38 response = Response( 40 response = Response(
39 status_code=_response['status'], 41 status_code=_response['status'],
40 headers=_response['headers'], 42 headers=_response['headers'],
41 content=b64decode(_response['body'].encode()) 43 content=b64decode(_response['body'].encode())
42 ) 44 )
45 except AssertionError:
46 pass
43 47
44 await response(self.scope, self.receive, self.send) 48 await response(self.scope, self.receive, self.send)
45 49
@@ -52,16 +56,8 @@ class Tunnel(WebSocketEndpoint):
52 self.request_task = None 56 self.request_task = None
53 self.config: Optional[Config] = None 57 self.config: Optional[Config] = None
54 58
55 @property
56 def requests(self) -> Queue[RequestData]:
57 return connections[self.config['subdomain']]['requests']
58
59 @property
60 def responses(self) -> Queue[ResponseData]:
61 return connections[self.config['subdomain']]['responses']
62
63 async def handle_requests(self, websocket: WebSocket): 59 async def handle_requests(self, websocket: WebSocket):
64 while request := await self.requests.get(): 60 while request := await self.proxy_queue.handle_request():
65 await websocket.send_json(request) 61 await websocket.send_json(request)
66 62
67 async def on_connect(self, websocket: WebSocket) -> None: 63 async def on_connect(self, websocket: WebSocket) -> None:
@@ -69,14 +65,10 @@ class Tunnel(WebSocketEndpoint):
69 self.config = await websocket.receive_json() 65 self.config = await websocket.receive_json()
70 66
71 if self.config['subdomain'] is None \ 67 if self.config['subdomain'] is None \
72 or self.config['subdomain'] in connections: 68 or await ProxyQueue.has_connection(self.config['subdomain']):
73 self.config['subdomain'] = uuid4().hex 69 self.config['subdomain'] = uuid4().hex
74 70
75 71 self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain'])
76 connections[self.config['subdomain']] = Connection(
77 requests=Queue(),
78 responses=Queue(),
79 )
80 72
81 hostname = os.environ.get("TUNNEL_DOMAIN") 73 hostname = os.environ.get("TUNNEL_DOMAIN")
82 protocol = "https" if os.environ.get("SECURE", False) else "http" 74 protocol = "https" if os.environ.get("SECURE", False) else "http"
@@ -87,12 +79,11 @@ class Tunnel(WebSocketEndpoint):
87 79
88 self.request_task = asyncio.create_task(self.handle_requests(websocket)) 80 self.request_task = asyncio.create_task(self.handle_requests(websocket))
89 81
90 async def on_receive(self, websocket: WebSocket, data: Any) -> None: 82 async def on_receive(self, websocket: WebSocket, data: Any):
91 await self.responses.put(data) 83 await self.proxy_queue.send_response(data)
92 84
93 async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: 85 async def on_disconnect(self, websocket: WebSocket, close_code: int):
94 if self.config is not None and self.config['subdomain'] in connections: 86 await self.proxy_queue.delete()
95 del connections[self.config['subdomain']]
96 87
97 if self.request_task is not None: 88 if self.request_task is not None:
98 self.request_task.cancel() 89 self.request_task.cancel()