diff options
| author | 2022-01-17 19:41:55 +0100 | |
|---|---|---|
| committer | 2022-01-17 19:41:55 +0100 | |
| commit | 5a38ebf365bfa0718dcbd7ab013af5f2da4610f6 (patch) | |
| tree | 72d762c3f0e081c24239522a6281425789e2e608 /ttun_server | |
| parent | 46af86f8ace136dd1d1d94590d3423e6b12e3f7b (diff) | |
| download | server-1.1.0-rc1.tar.gz server-1.1.0-rc1.tar.bz2 server-1.1.0-rc1.zip | |
Added scaling support via redisv1.1.0-rc1
Diffstat (limited to 'ttun_server')
| -rw-r--r-- | ttun_server/__init__.py | 5 | ||||
| -rw-r--r-- | ttun_server/connections.py | 3 | ||||
| -rw-r--r-- | ttun_server/endpoints.py | 43 | ||||
| -rw-r--r-- | ttun_server/proxy_queue.py | 163 | ||||
| -rw-r--r-- | ttun_server/redis.py | 20 | ||||
| -rw-r--r-- | ttun_server/types.py | 2 |
6 files changed, 206 insertions, 30 deletions
diff --git a/ttun_server/__init__.py b/ttun_server/__init__.py index b8fd114..cf589cc 100644 --- a/ttun_server/__init__.py +++ b/ttun_server/__init__.py | |||
| @@ -1,8 +1,13 @@ | |||
| 1 | import logging | ||
| 2 | import os | ||
| 3 | |||
| 1 | from starlette.applications import Starlette | 4 | from starlette.applications import Starlette |
| 2 | from starlette.routing import Route, WebSocketRoute | 5 | from starlette.routing import Route, WebSocketRoute |
| 3 | 6 | ||
| 4 | from ttun_server.endpoints import Proxy, Tunnel | 7 | from ttun_server.endpoints import Proxy, Tunnel |
| 5 | 8 | ||
| 9 | logging.basicConfig(level=getattr(logging, os.environ.get('LOG_LEVEL', 'INFO'))) | ||
| 10 | |||
| 6 | server = Starlette( | 11 | server = Starlette( |
| 7 | debug=True, | 12 | debug=True, |
| 8 | routes=[ | 13 | routes=[ |
diff --git a/ttun_server/connections.py b/ttun_server/connections.py deleted file mode 100644 index a8dabcf..0000000 --- a/ttun_server/connections.py +++ /dev/null | |||
| @@ -1,3 +0,0 @@ | |||
| 1 | from ttun_server.types import Connection | ||
| 2 | |||
| 3 | connections: dict[str, Connection] = {} | ||
diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py index d59cb7c..5b9e57f 100644 --- a/ttun_server/endpoints.py +++ b/ttun_server/endpoints.py | |||
| @@ -1,4 +1,5 @@ | |||
| 1 | import asyncio | 1 | import asyncio |
| 2 | import logging | ||
| 2 | import os | 3 | import os |
| 3 | from asyncio import Queue | 4 | from asyncio import Queue |
| 4 | from base64 import b64decode, b64encode | 5 | from base64 import b64decode, b64encode |
| @@ -11,9 +12,10 @@ from starlette.responses import Response | |||
| 11 | from starlette.types import Scope, Receive, Send | 12 | from starlette.types import Scope, Receive, Send |
| 12 | from starlette.websockets import WebSocket | 13 | from starlette.websockets import WebSocket |
| 13 | 14 | ||
| 14 | from ttun_server.types import Connection, RequestData, Config, ResponseData | 15 | from ttun_server.proxy_queue import ProxyQueue |
| 16 | from ttun_server.types import RequestData, Config, ResponseData | ||
| 15 | 17 | ||
| 16 | from ttun_server.connections import connections | 18 | logger = logging.getLogger(__name__) |
| 17 | 19 | ||
| 18 | 20 | ||
| 19 | class Proxy(HTTPEndpoint): | 21 | class Proxy(HTTPEndpoint): |
| @@ -23,10 +25,10 @@ class Proxy(HTTPEndpoint): | |||
| 23 | [subdomain, *_] = request.headers['host'].split('.') | 25 | [subdomain, *_] = request.headers['host'].split('.') |
| 24 | response = Response(content='Not Found', status_code=404) | 26 | response = Response(content='Not Found', status_code=404) |
| 25 | 27 | ||
| 26 | if subdomain in connections: | 28 | try: |
| 27 | connection = connections[subdomain] | 29 | queue = await ProxyQueue.get_for_identifier(subdomain) |
| 28 | 30 | ||
| 29 | await connection['requests'].put(RequestData( | 31 | await queue.send_request(RequestData( |
| 30 | method=request.method, | 32 | method=request.method, |
| 31 | path=str(request.url).replace(str(request.base_url), '/'), | 33 | path=str(request.url).replace(str(request.base_url), '/'), |
| 32 | headers=dict(request.headers), | 34 | headers=dict(request.headers), |
| @@ -34,12 +36,14 @@ class Proxy(HTTPEndpoint): | |||
| 34 | body=b64encode(await request.body()).decode() | 36 | body=b64encode(await request.body()).decode() |
| 35 | )) | 37 | )) |
| 36 | 38 | ||
| 37 | _response = await connection['responses'].get() | 39 | _response = await queue.handle_response() |
| 38 | response = Response( | 40 | response = Response( |
| 39 | status_code=_response['status'], | 41 | status_code=_response['status'], |
| 40 | headers=_response['headers'], | 42 | headers=_response['headers'], |
| 41 | content=b64decode(_response['body'].encode()) | 43 | content=b64decode(_response['body'].encode()) |
| 42 | ) | 44 | ) |
| 45 | except AssertionError: | ||
| 46 | pass | ||
| 43 | 47 | ||
| 44 | await response(self.scope, self.receive, self.send) | 48 | await response(self.scope, self.receive, self.send) |
| 45 | 49 | ||
| @@ -52,16 +56,8 @@ class Tunnel(WebSocketEndpoint): | |||
| 52 | self.request_task = None | 56 | self.request_task = None |
| 53 | self.config: Optional[Config] = None | 57 | self.config: Optional[Config] = None |
| 54 | 58 | ||
| 55 | @property | ||
| 56 | def requests(self) -> Queue[RequestData]: | ||
| 57 | return connections[self.config['subdomain']]['requests'] | ||
| 58 | |||
| 59 | @property | ||
| 60 | def responses(self) -> Queue[ResponseData]: | ||
| 61 | return connections[self.config['subdomain']]['responses'] | ||
| 62 | |||
| 63 | async def handle_requests(self, websocket: WebSocket): | 59 | async def handle_requests(self, websocket: WebSocket): |
| 64 | while request := await self.requests.get(): | 60 | while request := await self.proxy_queue.handle_request(): |
| 65 | await websocket.send_json(request) | 61 | await websocket.send_json(request) |
| 66 | 62 | ||
| 67 | async def on_connect(self, websocket: WebSocket) -> None: | 63 | async def on_connect(self, websocket: WebSocket) -> None: |
| @@ -69,14 +65,10 @@ class Tunnel(WebSocketEndpoint): | |||
| 69 | self.config = await websocket.receive_json() | 65 | self.config = await websocket.receive_json() |
| 70 | 66 | ||
| 71 | if self.config['subdomain'] is None \ | 67 | if self.config['subdomain'] is None \ |
| 72 | or self.config['subdomain'] in connections: | 68 | or await ProxyQueue.has_connection(self.config['subdomain']): |
| 73 | self.config['subdomain'] = uuid4().hex | 69 | self.config['subdomain'] = uuid4().hex |
| 74 | 70 | ||
| 75 | 71 | self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain']) | |
| 76 | connections[self.config['subdomain']] = Connection( | ||
| 77 | requests=Queue(), | ||
| 78 | responses=Queue(), | ||
| 79 | ) | ||
| 80 | 72 | ||
| 81 | hostname = os.environ.get("TUNNEL_DOMAIN") | 73 | hostname = os.environ.get("TUNNEL_DOMAIN") |
| 82 | protocol = "https" if os.environ.get("SECURE", False) else "http" | 74 | protocol = "https" if os.environ.get("SECURE", False) else "http" |
| @@ -87,12 +79,11 @@ class Tunnel(WebSocketEndpoint): | |||
| 87 | 79 | ||
| 88 | self.request_task = asyncio.create_task(self.handle_requests(websocket)) | 80 | self.request_task = asyncio.create_task(self.handle_requests(websocket)) |
| 89 | 81 | ||
| 90 | async def on_receive(self, websocket: WebSocket, data: Any) -> None: | 82 | async def on_receive(self, websocket: WebSocket, data: Any): |
| 91 | await self.responses.put(data) | 83 | await self.proxy_queue.send_response(data) |
| 92 | 84 | ||
| 93 | async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: | 85 | async def on_disconnect(self, websocket: WebSocket, close_code: int): |
| 94 | if self.config is not None and self.config['subdomain'] in connections: | 86 | await self.proxy_queue.delete() |
| 95 | del connections[self.config['subdomain']] | ||
| 96 | 87 | ||
| 97 | if self.request_task is not None: | 88 | if self.request_task is not None: |
| 98 | self.request_task.cancel() | 89 | self.request_task.cancel() |
diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py new file mode 100644 index 0000000..07e16e0 --- /dev/null +++ b/ttun_server/proxy_queue.py | |||
| @@ -0,0 +1,163 @@ | |||
| 1 | import asyncio | ||
| 2 | import json | ||
| 3 | import logging | ||
| 4 | import os | ||
| 5 | from typing import Awaitable, Callable | ||
| 6 | |||
| 7 | from ttun_server.redis import RedisConnectionPool | ||
| 8 | from ttun_server.types import RequestData, ResponseData, MemoryConnection | ||
| 9 | |||
| 10 | logger = logging.getLogger(__name__) | ||
| 11 | |||
| 12 | class BaseProxyQueue: | ||
| 13 | def __init__(self, identifier: str): | ||
| 14 | self.identifier = identifier | ||
| 15 | |||
| 16 | @classmethod | ||
| 17 | async def create_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': | ||
| 18 | raise NotImplementedError(f'Please implement create_for_identifier') | ||
| 19 | |||
| 20 | @classmethod | ||
| 21 | async def get_for_identifier(cls, identifier: str) -> 'BaseProxyQueue': | ||
| 22 | assert await cls.has_connection(identifier) | ||
| 23 | return cls(identifier) | ||
| 24 | |||
| 25 | @classmethod | ||
| 26 | async def has_connection(cls, identifier) -> bool: | ||
| 27 | raise NotImplementedError(f'Please implement has_connection') | ||
| 28 | |||
| 29 | async def send_request(self, request_data: RequestData): | ||
| 30 | raise NotImplementedError(f'Please implement send_request') | ||
| 31 | |||
| 32 | async def handle_request(self) -> RequestData: | ||
| 33 | raise NotImplementedError(f'Please implement handle_requests') | ||
| 34 | |||
| 35 | async def send_response(self, response_data: ResponseData): | ||
| 36 | raise NotImplementedError(f'Please implement send_request') | ||
| 37 | |||
| 38 | async def handle_response(self) -> ResponseData: | ||
| 39 | raise NotImplementedError(f'Please implement handle_response') | ||
| 40 | |||
| 41 | async def delete(self): | ||
| 42 | raise NotImplementedError(f'Please implement delete') | ||
| 43 | |||
| 44 | class MemoryProxyQueue(BaseProxyQueue): | ||
| 45 | connections: dict[str, MemoryConnection] = {} | ||
| 46 | |||
| 47 | @classmethod | ||
| 48 | async def has_connection(cls, identifier) -> bool: | ||
| 49 | return identifier in cls.connections | ||
| 50 | |||
| 51 | @classmethod | ||
| 52 | async def create_for_identifier(cls, identifier: str) -> 'MemoryProxyQueue': | ||
| 53 | instance = cls(identifier) | ||
| 54 | |||
| 55 | cls.connections[identifier] = { | ||
| 56 | 'requests': asyncio.Queue(), | ||
| 57 | 'responses': asyncio.Queue(), | ||
| 58 | } | ||
| 59 | |||
| 60 | return instance | ||
| 61 | |||
| 62 | @property | ||
| 63 | def requests(self) -> asyncio.Queue[RequestData]: | ||
| 64 | return self.__class__.connections[self.identifier]['requests'] | ||
| 65 | |||
| 66 | @property | ||
| 67 | def responses(self) -> asyncio.Queue[ResponseData]: | ||
| 68 | return self.__class__.connections[self.identifier]['responses'] | ||
| 69 | |||
| 70 | async def send_request(self, request_data: RequestData): | ||
| 71 | await self.requests.put(request_data) | ||
| 72 | |||
| 73 | async def handle_request(self) -> RequestData: | ||
| 74 | return await self.requests.get() | ||
| 75 | |||
| 76 | async def send_response(self, response_data: ResponseData): | ||
| 77 | return await self.responses.put(response_data) | ||
