summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Tom van der Lee <tom@vanderlee.io>2022-01-17 19:41:55 +0100
committerGravatar Tom van der Lee <tom@vanderlee.io>2022-01-17 19:41:55 +0100
commit5a38ebf365bfa0718dcbd7ab013af5f2da4610f6 (patch)
tree72d762c3f0e081c24239522a6281425789e2e608
parent46af86f8ace136dd1d1d94590d3423e6b12e3f7b (diff)
downloadserver-1.1.0-rc1.tar.gz
server-1.1.0-rc1.tar.bz2
server-1.1.0-rc1.zip
Added scaling support via redisv1.1.0-rc1
-rw-r--r--requirements.txt1
-rw-r--r--ttun_server/__init__.py5
-rw-r--r--ttun_server/connections.py3
-rw-r--r--ttun_server/endpoints.py43
-rw-r--r--ttun_server/proxy_queue.py163
-rw-r--r--ttun_server/redis.py20
-rw-r--r--ttun_server/types.py2
7 files changed, 207 insertions, 30 deletions
diff --git a/requirements.txt b/requirements.txt
index 95f7ad2..34c860e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
1starlette ~= 0.17 1starlette ~= 0.17
2uvicorn[standard] ~= 0.16 2uvicorn[standard] ~= 0.16
3aioredis ~= 2.0
diff --git a/ttun_server/__init__.py b/ttun_server/__init__.py
index b8fd114..cf589cc 100644
--- a/ttun_server/__init__.py
+++ b/ttun_server/__init__.py
@@ -1,8 +1,13 @@
1import logging
2import os
3
1from starlette.applications import Starlette 4from starlette.applications import Starlette
2from starlette.routing import Route, WebSocketRoute 5from starlette.routing import Route, WebSocketRoute
3 6
4from ttun_server.endpoints import Proxy, Tunnel 7from ttun_server.endpoints import Proxy, Tunnel
5 8
9logging.basicConfig(level=getattr(logging, os.environ.get('LOG_LEVEL', 'INFO')))
10
6server = Starlette( 11server = Starlette(
7 debug=True, 12 debug=True,
8 routes=[ 13 routes=[
diff --git a/ttun_server/connections.py b/ttun_server/connections.py
deleted file mode 100644
index a8dabcf..0000000
--- a/ttun_server/connections.py
+++ /dev/null
@@ -1,3 +0,0 @@
1from ttun_server.types import Connection
2
3connections: dict[str, Connection] = {}
diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py
index d59cb7c..5b9e57f 100644
--- a/ttun_server/endpoints.py
+++ b/ttun_server/endpoints.py
@@ -1,4 +1,5 @@
1import asyncio 1import asyncio
2import logging
2import os 3import os
3from asyncio import Queue 4from asyncio import Queue
4from base64 import b64decode, b64encode 5from base64 import b64decode, b64encode
@@ -11,9 +12,10 @@ from starlette.responses import Response
11from starlette.types import Scope, Receive, Send 12from starlette.types import Scope, Receive, Send
12from starlette.websockets import WebSocket 13from starlette.websockets import WebSocket
13 14
14from ttun_server.types import Connection, RequestData, Config, ResponseData 15from ttun_server.proxy_queue import ProxyQueue
16from ttun_server.types import RequestData, Config, ResponseData
15 17
16from ttun_server.connections import connections 18logger = logging.getLogger(__name__)
17 19
18 20
19class Proxy(HTTPEndpoint): 21class Proxy(HTTPEndpoint):
@@ -23,10 +25,10 @@ class Proxy(HTTPEndpoint):
23 [subdomain, *_] = request.headers['host'].split('.') 25 [subdomain, *_] = request.headers['host'].split('.')
24 response = Response(content='Not Found', status_code=404) 26 response = Response(content='Not Found', status_code=404)
25 27
26 if subdomain in connections: 28 try:
27 connection = connections[subdomain] 29 queue = await ProxyQueue.get_for_identifier(subdomain)
28 30
29 await connection['requests'].put(RequestData( 31 await queue.send_request(RequestData(
30 method=request.method, 32 method=request.method,
31 path=str(request.url).replace(str(request.base_url), '/'), 33 path=str(request.url).replace(str(request.base_url), '/'),
32 headers=dict(request.headers), 34 headers=dict(request.headers),
@@ -34,12 +36,14 @@ class Proxy(HTTPEndpoint):
34 body=b64encode(await request.body()).decode() 36 body=b64encode(await request.body()).decode()
35 )) 37 ))
36 38
37 _response = await connection['responses'].get() 39 _response = await queue.handle_response()
38 response = Response( 40 response = Response(
39 status_code=_response['status'], 41 status_code=_response['status'],
40 headers=_response['headers'], 42 headers=_response['headers'],
41 content=b64decode(_response['body'].encode()) 43 content=b64decode(_response['body'].encode())
42 ) 44 )
45 except AssertionError:
46 pass
43 47
44 await response(self.scope, self.receive, self.send) 48 await response(self.scope, self.receive, self.send)
45 49
@@ -52,16 +56,8 @@ class Tunnel(WebSocketEndpoint):
52 self.request_task = None 56 self.request_task = None
53 self.config: Optional[Config] = None 57 self.config: Optional[Config] = None
54 58
55 @property
56 def requests(self) -> Queue[RequestData]:
57 return connections[self.config['subdomain']]['requests']
58
59 @property
60 def responses(self) -> Queue[ResponseData]:
61 return connections[self.config['subdomain']]['responses']
62
63 async def handle_requests(self, websocket: WebSocket): 59 async def handle_requests(self, websocket: WebSocket):
64 while request := await self.requests.get(): 60 while request := await self.proxy_queue.handle_request():
65 await websocket.send_json(request) 61 await websocket.send_json(request)
66 62
67 async def on_connect(self, websocket: WebSocket) -> None: 63 async def on_connect(self, websocket: WebSocket) -> None:
@@ -69,14 +65,10 @@ class Tunnel(WebSocketEndpoint):
69 self.config = await websocket.receive_json() 65 self.config = await websocket.receive_json()
70 66
71 if self.config['subdomain'] is None \ 67 if self.config['subdomain'] is None \
72 or self.config['subdomain'] in connections: 68 or await ProxyQueue.has_connection(self.config['subdomain']):
73 self.config['subdomain'] = uuid4().hex 69 self.config['subdomain'] = uuid4().hex
74 70
75 71 self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain'])
76 connections[self.config['subdomain']] = Connection(
77 requests=Queue(),
78 responses=Queue(),
79 )
80 72
81 hostname = os.environ.get("TUNNEL_DOMAIN") 73 hostname = os.environ.get("TUNNEL_DOMAIN")
82 protocol = "https" if os.environ.get("SECURE", False) else "http" 74 protocol = "https" if os.environ.get("SECURE", False) else "http"
@@ -87,12 +79,11 @@ class Tunnel(WebSocketEndpoint):
87 79
88 self.request_task = asyncio.create_task(self.handle_requests(websocket)) 80 self.request_task = asyncio.create_task(self.handle_requests(websocket))
89 81
90 async def on_receive(self, websocket: WebSocket, data: Any) -> None: 82 async def on_receive(self, websocket: WebSocket, data: Any):
91 await self.responses.put(data) 83 await self.proxy_queue.send_response(data)
92 84
93 async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: 85 async def on_disconnect(self, websocket: WebSocket, close_code: int):
94 if self.config is not None and self.config['subdomain'] in connections: 86 await self.proxy_queue.delete()
95 del connections[self.config['subdomain']]
96 87
97 if self.request_task is not None: 88 if self.request_task is not None:
98 self.request_task.cancel() 89 self.request_task.cancel()
diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py
new file mode 100644
index 0000000..07e16e0
--- /dev/null
+++ b/ttun_server/proxy_queue.py
@@ -0,0 +1,163 @@
1import asyncio
2import json
3import logging
4import os
5from typing import Awaitable, Callable
6
7from ttun_server.redis import RedisConnectionPool
8from ttun_server.types import RequestData, ResponseData, MemoryConnection
9
10logger = logging.getLogger(__name__)
11
12class BaseProxyQueue:
13 def __init__(self, identifier: str):
14 self.identifier = identifier
15
16 @classmethod
17 async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
18 raise NotImplementedError(f'Please implement create_for_identifier')
19
20 @classmethod
21 async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue':
22 assert await cls.has_connection(identifier)
23 return cls(identifier)
24
25 @classmethod
26 async def has_connection(cls, identifier) -> bool:
27 raise NotImplementedError(f'Please implement has_connection')
28
29 async def send_request(self, request_data: RequestData):
30 raise NotImplementedError(f'Please implement send_request')
31
32 async def handle_request(self) -> RequestData:
33 raise NotImplementedError(f'Please implement handle_requests')
34
35 async def send_response(self, response_data: ResponseData):
36 raise NotImplementedError(f'Please implement send_request')
37
38 async def handle_response(self) -> ResponseData:
39 raise NotImplementedError(f'Please implement handle_response')
40
41 async def delete(self):
42 raise NotImplementedError(f'Please implement delete')
43
44class MemoryProxyQueue(BaseProxyQueue):
45 connections: dict[str, MemoryConnection] = {}
46
47 @classmethod
48 async def has_connection(cls, identifier) -> bool:
49 return identifier in cls.connections
50
51 @classmethod
52 async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue':
53 instance = cls(identifier)
54
55 cls.connections[identifier] = {
56 'requests': asyncio.Queue(),
57 'responses': asyncio.Queue(),
58 }
59
60 return instance
61
62 @property
63 def requests(self) -> asyncio.Queue[RequestData]:
64 return self.__class__.connections[self.identifier]['requests']
65
66 @property
67 def responses(self) -> asyncio.Queue[ResponseData]:
68 return self.__class__.connections[self.identifier]['responses']
69
70 async def send_request(self, request_data: RequestData):
71 await self.requests.put(request_data)