diff options
| author | 2024-08-30 15:54:40 +0200 | |
|---|---|---|
| committer | 2024-08-30 15:54:40 +0200 | |
| commit | 0f7c975fe61dab4efb11b49ddc87331c30c26942 (patch) | |
| tree | 68f4b351b337b9a2269ddb2cb512016c93e7cbbc | |
| parent | 53a8f300859a50d9f99f1821c35bca999fced6d8 (diff) | |
| parent | a72a0485ef8761b95c73cc420723247fafbb6f1c (diff) | |
| download | server-2.1.0.tar.gz server-2.1.0.tar.bz2 server-2.1.0.zip | |
Added websocket support
| -rw-r--r-- | .github/workflows/docker-image.yml | 8 | ||||
| -rw-r--r-- | requirements.txt | 5 | ||||
| -rw-r--r-- | ttun_server/__init__.py | 4 | ||||
| -rw-r--r-- | ttun_server/endpoints.py | 75 | ||||
| -rw-r--r-- | ttun_server/proxy_queue.py | 1 | ||||
| -rw-r--r-- | ttun_server/redis.py | 7 | ||||
| -rw-r--r-- | ttun_server/types.py | 51 | ||||
| -rw-r--r-- | ttun_server/websockets.py | 179 |
8 files changed, 243 insertions, 87 deletions
diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 3f83351..c8c8930 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml | |||
| @@ -13,10 +13,10 @@ jobs: | |||
| 13 | runs-on: ubuntu-latest | 13 | runs-on: ubuntu-latest |
| 14 | steps: | 14 | steps: |
| 15 | - name: Checkout | 15 | - name: Checkout |
| 16 | uses: actions/checkout@v2 | 16 | uses: actions/checkout@v4 |
| 17 | - name: Docker meta | 17 | - name: Docker meta |
| 18 | id: meta | 18 | id: meta |
| 19 | uses: docker/metadata-action@v4 | 19 | uses: docker/metadata-action@v5 |
| 20 | with: | 20 | with: |
| 21 | images: ghcr.io/tomvanderlee/ttun-server | 21 | images: ghcr.io/tomvanderlee/ttun-server |
| 22 | tags: | | 22 | tags: | |
| @@ -25,13 +25,13 @@ jobs: | |||
| 25 | 25 | ||
| 26 | - name: Login to DockerHub | 26 | - name: Login to DockerHub |
| 27 | if: github.event_name != 'pull_request' | 27 | if: github.event_name != 'pull_request' |
| 28 | uses: docker/login-action@v1 | 28 | uses: docker/login-action@v3 |
| 29 | with: | 29 | with: |
| 30 | registry: ghcr.io | 30 | registry: ghcr.io |
| 31 | username: ${{ github.actor }} | 31 | username: ${{ github.actor }} |
| 32 | password: ${{ secrets.GITHUB_TOKEN }} | 32 | password: ${{ secrets.GITHUB_TOKEN }} |
| 33 | - name: Build and push | 33 | - name: Build and push |
| 34 | uses: docker/build-push-action@v4 | 34 | uses: docker/build-push-action@v6 |
| 35 | with: | 35 | with: |
| 36 | context: . | 36 | context: . |
| 37 | push: ${{ github.event_name != 'pull_request' }} | 37 | push: ${{ github.event_name != 'pull_request' }} |
diff --git a/requirements.txt b/requirements.txt index 34c860e..1ab8e8f 100644 --- a/requirements.txt +++ b/requirements.txt | |||
| @@ -1,3 +1,4 @@ | |||
| 1 | starlette ~= 0.17 | 1 | starlette ~= 0.37 |
| 2 | uvicorn[standard] ~= 0.16 | 2 | uvicorn[standard] ~= 0.16 |
| 3 | aioredis ~= 2.0 | 3 | redis |
| 4 | setuptools | ||
diff --git a/ttun_server/__init__.py b/ttun_server/__init__.py index 81f8cd4..2f8fed0 100644 --- a/ttun_server/__init__.py +++ b/ttun_server/__init__.py | |||
| @@ -4,7 +4,8 @@ import os | |||
| 4 | from starlette.applications import Starlette | 4 | from starlette.applications import Starlette |
| 5 | from starlette.routing import Route, WebSocketRoute, Host, Router | 5 | from starlette.routing import Route, WebSocketRoute, Host, Router |
| 6 | 6 | ||
| 7 | from ttun_server.endpoints import Proxy, Tunnel, Health | 7 | from ttun_server.endpoints import Proxy, Health |
| 8 | from .websockets import WebsocketProxy, Tunnel | ||
| 8 | 9 | ||
| 9 | logging.basicConfig(level=getattr(logging, os.environ.get('LOG_LEVEL', 'INFO'))) | 10 | logging.basicConfig(level=getattr(logging, os.environ.get('LOG_LEVEL', 'INFO'))) |
| 10 | 11 | ||
| @@ -18,6 +19,7 @@ server = Starlette( | |||
| 18 | routes=[ | 19 | routes=[ |
| 19 | Host(os.environ['TUNNEL_DOMAIN'], base_router, 'base'), | 20 | Host(os.environ['TUNNEL_DOMAIN'], base_router, 'base'), |
| 20 | Route('/{path:path}', Proxy), | 21 | Route('/{path:path}', Proxy), |
| 22 | WebSocketRoute('/{path:path}', WebsocketProxy) | ||
| 21 | ] | 23 | ] |
| 22 | ) | 24 | ) |
| 23 | 25 | ||
diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py index 3e263da..eae0ebe 100644 --- a/ttun_server/endpoints.py +++ b/ttun_server/endpoints.py | |||
| @@ -1,20 +1,13 @@ | |||
| 1 | import asyncio | ||
| 2 | import logging | 1 | import logging |
| 3 | import os | ||
| 4 | from asyncio import create_task | ||
| 5 | from base64 import b64decode, b64encode | 2 | from base64 import b64decode, b64encode |
| 6 | from typing import Optional, Any | ||
| 7 | from uuid import uuid4 | 3 | from uuid import uuid4 |
| 8 | 4 | ||
| 9 | from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint | 5 | from starlette.endpoints import HTTPEndpoint |
| 10 | from starlette.requests import Request | 6 | from starlette.requests import Request |
| 11 | from starlette.responses import Response | 7 | from starlette.responses import Response |
| 12 | from starlette.types import Scope, Receive, Send | ||
| 13 | from starlette.websockets import WebSocket | ||
| 14 | 8 | ||
| 15 | import ttun_server | ||
| 16 | from ttun_server.proxy_queue import ProxyQueue | 9 | from ttun_server.proxy_queue import ProxyQueue |
| 17 | from ttun_server.types import RequestData, Config, Message, MessageType | 10 | from ttun_server.types import HttpRequestData, Message, HttpMessageType, HttpMessage |
| 18 | 11 | ||
| 19 | logger = logging.getLogger(__name__) | 12 | logger = logging.getLogger(__name__) |
| 20 | 13 | ||
| @@ -44,11 +37,11 @@ class Proxy(HTTPEndpoint): | |||
| 44 | 37 | ||
| 45 | logger.debug('PROXY %s%s ', subdomain, request.url) | 38 | logger.debug('PROXY %s%s ', subdomain, request.url) |
| 46 | await request_queue.enqueue( | 39 | await request_queue.enqueue( |
| 47 | Message( | 40 | HttpMessage( |
| 48 | type=MessageType.request.value, | 41 | type=HttpMessageType.request.value, |
| 49 | identifier=identifier, | 42 | identifier=identifier, |
| 50 | payload= | 43 | payload= |
| 51 | RequestData( | 44 | HttpRequestData( |
| 52 | method=request.method, | 45 | method=request.method, |
| 53 | path=str(request.url).replace(str(request.base_url), '/'), | 46 | path=str(request.url).replace(str(request.base_url), '/'), |
| 54 | headers=list(request.headers.items()), | 47 | headers=list(request.headers.items()), |
| @@ -78,61 +71,3 @@ class Health(HTTPEndpoint): | |||
| 78 | await response(self.scope, self.receive, self.send) | 71 | await response(self.scope, self.receive, self.send) |
| 79 | 72 | ||
| 80 | 73 | ||
| 81 | class Tunnel(WebSocketEndpoint): | ||
| 82 | encoding = 'json' | ||
| 83 | |||
| 84 | def __init__(self, scope: Scope, receive: Receive, send: Send): | ||
| 85 | super().__init__(scope, receive, send) | ||
| 86 | self.request_task = None | ||
| 87 | self.config: Optional[Config] = None | ||
| 88 | |||
| 89 | async def handle_requests(self, websocket: WebSocket): | ||
| 90 | while request := await self.proxy_queue.dequeue(): | ||
| 91 | create_task(websocket.send_json(request)) | ||
| 92 | |||
| 93 | async def on_connect(self, websocket: WebSocket) -> None: | ||
| 94 | await websocket.accept() | ||
| 95 | self.config = await websocket.receive_json() | ||
| 96 | |||
| 97 | client_version = self.config.get('version', '1.0.0') | ||
| 98 | logger.debug('client_version %s', client_version) | ||
| 99 | |||
| 100 | if 'git' not in client_version and ttun_server.__version__ != 'development': | ||
| 101 | [client_major, *_] = [int(i) for i in client_version.split('.')[:3]] | ||
| 102 | [server_major, *_] = [int(i) for i in ttun_server.__version__.split('.')] | ||
| 103 | |||
| 104 | if client_major < server_major: | ||
| 105 | await websocket.close(4000, 'Your client is too old') | ||
| 106 | |||
| 107 | if client_major > server_major: | ||
| 108 | await websocket.close(4001, 'Your client is too new') | ||
| 109 | |||
| 110 | |||
| 111 | if self.config['subdomain'] is None \ | ||
| 112 | or await ProxyQueue.has_connection(self.config['subdomain']): | ||
| 113 | self.config['subdomain'] = uuid4().hex | ||
| 114 | |||
| 115 | |||
| 116 | self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain']) | ||
| 117 | |||
| 118 | hostname = os.environ.get("TUNNEL_DOMAIN") | ||
| 119 | protocol = "https" if os.environ.get("SECURE", False) else "http" | ||
| 120 | |||
| 121 | await websocket.send_json({ | ||
| 122 | 'url': f'{protocol}://{self.config["subdomain"]}.{hostname}' | ||
| 123 | }) | ||
| 124 | |||
| 125 | self.request_task = asyncio.create_task(self.handle_requests(websocket)) | ||
| 126 | |||
| 127 | async def on_receive(self, websocket: WebSocket, data: Message): | ||
| 128 | try: | ||
| 129 | response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}") | ||
| 130 | await response_queue.enqueue(data) | ||
| 131 | except AssertionError: | ||
| 132 | pass | ||
| 133 | |||
| 134 | async def on_disconnect(self, websocket: WebSocket, close_code: int): | ||
| 135 | await self.proxy_queue.delete() | ||
| 136 | |||
| 137 | if self.request_task is not None: | ||
| 138 | self.request_task.cancel() | ||
diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py index e521886..c6c8067 100644 --- a/ttun_server/proxy_queue.py +++ b/ttun_server/proxy_queue.py | |||
| @@ -2,6 +2,7 @@ import asyncio | |||
| 2 | import json | 2 | import json |
| 3 | import logging | 3 | import logging |
| 4 | import os | 4 | import os |
| 5 | import traceback | ||
| 5 | from typing import Type | 6 | from typing import Type |
| 6 | 7 | ||
| 7 | from ttun_server.redis import RedisConnectionPool | 8 | from ttun_server.redis import RedisConnectionPool |
diff --git a/ttun_server/redis.py b/ttun_server/redis.py index 3065dec..18fbca2 100644 --- a/ttun_server/redis.py +++ b/ttun_server/redis.py | |||
| @@ -1,6 +1,8 @@ | |||
| 1 | import asyncio | ||
| 1 | import os | 2 | import os |
| 3 | from asyncio import get_running_loop | ||
| 2 | 4 | ||
| 3 | from aioredis import ConnectionPool, Redis | 5 | from redis.asyncio import ConnectionPool, Redis |
| 4 | 6 | ||
| 5 | 7 | ||
| 6 | class RedisConnectionPool: | 8 | class RedisConnectionPool: |
| @@ -9,9 +11,6 @@ class RedisConnectionPool: | |||
| 9 | def __init__(self): | 11 | def __init__(self): |
| 10 | self.pool = ConnectionPool.from_url(os.environ.get('REDIS_URL')) | 12 | self.pool = ConnectionPool.from_url(os.environ.get('REDIS_URL')) |
| 11 | 13 | ||
| 12 | def __del__(self): | ||
| 13 | self.pool.disconnect() | ||
| 14 | |||
| 15 | @classmethod | 14 | @classmethod |
| 16 | def get_connection(cls) -> Redis: | 15 | def get_connection(cls) -> Redis: |
| 17 | if cls.instance is None: | 16 | if cls.instance is None: |
diff --git a/ttun_server/types.py b/ttun_server/types.py index 8a4d929..8591e7d 100644 --- a/ttun_server/types.py +++ b/ttun_server/types.py | |||
| @@ -3,7 +3,7 @@ from enum import Enum | |||
| 3 | from typing import TypedDict, Optional | 3 | from typing import TypedDict, Optional |
| 4 | 4 | ||
| 5 | 5 | ||
| 6 | class MessageType(Enum): | 6 | class HttpMessageType(Enum): |
| 7 | request = 'request' | ||
