summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Tom van der Lee <tomvanderlee@users.noreply.github.com>2024-08-30 15:54:40 +0200
committerGravatar GitHub <noreply@github.com>2024-08-30 15:54:40 +0200
commit0f7c975fe61dab4efb11b49ddc87331c30c26942 (patch)
tree68f4b351b337b9a2269ddb2cb512016c93e7cbbc
parent53a8f300859a50d9f99f1821c35bca999fced6d8 (diff)
parenta72a0485ef8761b95c73cc420723247fafbb6f1c (diff)
downloadserver-0f7c975fe61dab4efb11b49ddc87331c30c26942.tar.gz
server-0f7c975fe61dab4efb11b49ddc87331c30c26942.tar.bz2
server-0f7c975fe61dab4efb11b49ddc87331c30c26942.zip
Merge pull request #6 from tomvanderlee/feature/websocketsv2.1.0main
Added websocket support
-rw-r--r--.github/workflows/docker-image.yml8
-rw-r--r--requirements.txt5
-rw-r--r--ttun_server/__init__.py4
-rw-r--r--ttun_server/endpoints.py75
-rw-r--r--ttun_server/proxy_queue.py1
-rw-r--r--ttun_server/redis.py7
-rw-r--r--ttun_server/types.py51
-rw-r--r--ttun_server/websockets.py179
8 files changed, 243 insertions, 87 deletions
diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml
index 3f83351..c8c8930 100644
--- a/.github/workflows/docker-image.yml
+++ b/.github/workflows/docker-image.yml
@@ -13,10 +13,10 @@ jobs:
13 runs-on: ubuntu-latest 13 runs-on: ubuntu-latest
14 steps: 14 steps:
15 - name: Checkout 15 - name: Checkout
16 uses: actions/checkout@v2 16 uses: actions/checkout@v4
17 - name: Docker meta 17 - name: Docker meta
18 id: meta 18 id: meta
19 uses: docker/metadata-action@v4 19 uses: docker/metadata-action@v5
20 with: 20 with:
21 images: ghcr.io/tomvanderlee/ttun-server 21 images: ghcr.io/tomvanderlee/ttun-server
22 tags: | 22 tags: |
@@ -25,13 +25,13 @@ jobs:
25 25
26 - name: Login to DockerHub 26 - name: Login to DockerHub
27 if: github.event_name != 'pull_request' 27 if: github.event_name != 'pull_request'
28 uses: docker/login-action@v1 28 uses: docker/login-action@v3
29 with: 29 with:
30 registry: ghcr.io 30 registry: ghcr.io
31 username: ${{ github.actor }} 31 username: ${{ github.actor }}
32 password: ${{ secrets.GITHUB_TOKEN }} 32 password: ${{ secrets.GITHUB_TOKEN }}
33 - name: Build and push 33 - name: Build and push
34 uses: docker/build-push-action@v4 34 uses: docker/build-push-action@v6
35 with: 35 with:
36 context: . 36 context: .
37 push: ${{ github.event_name != 'pull_request' }} 37 push: ${{ github.event_name != 'pull_request' }}
diff --git a/requirements.txt b/requirements.txt
index 34c860e..1ab8e8f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
1starlette ~= 0.17 1starlette ~= 0.37
2uvicorn[standard] ~= 0.16 2uvicorn[standard] ~= 0.16
3aioredis ~= 2.0 3redis
4setuptools
diff --git a/ttun_server/__init__.py b/ttun_server/__init__.py
index 81f8cd4..2f8fed0 100644
--- a/ttun_server/__init__.py
+++ b/ttun_server/__init__.py
@@ -4,7 +4,8 @@ import os
4from starlette.applications import Starlette 4from starlette.applications import Starlette
5from starlette.routing import Route, WebSocketRoute, Host, Router 5from starlette.routing import Route, WebSocketRoute, Host, Router
6 6
7from ttun_server.endpoints import Proxy, Tunnel, Health 7from ttun_server.endpoints import Proxy, Health
8from .websockets import WebsocketProxy, Tunnel
8 9
9logging.basicConfig(level=getattr(logging, os.environ.get('LOG_LEVEL', 'INFO'))) 10logging.basicConfig(level=getattr(logging, os.environ.get('LOG_LEVEL', 'INFO')))
10 11
@@ -18,6 +19,7 @@ server = Starlette(
18 routes=[ 19 routes=[
19 Host(os.environ['TUNNEL_DOMAIN'], base_router, 'base'), 20 Host(os.environ['TUNNEL_DOMAIN'], base_router, 'base'),
20 Route('/{path:path}', Proxy), 21 Route('/{path:path}', Proxy),
22 WebSocketRoute('/{path:path}', WebsocketProxy)
21 ] 23 ]
22) 24)
23 25
diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py
index 3e263da..eae0ebe 100644
--- a/ttun_server/endpoints.py
+++ b/ttun_server/endpoints.py
@@ -1,20 +1,13 @@
1import asyncio
2import logging 1import logging
3import os
4from asyncio import create_task
5from base64 import b64decode, b64encode 2from base64 import b64decode, b64encode
6from typing import Optional, Any
7from uuid import uuid4 3from uuid import uuid4
8 4
9from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint 5from starlette.endpoints import HTTPEndpoint
10from starlette.requests import Request 6from starlette.requests import Request
11from starlette.responses import Response 7from starlette.responses import Response
12from starlette.types import Scope, Receive, Send
13from starlette.websockets import WebSocket
14 8
15import ttun_server
16from ttun_server.proxy_queue import ProxyQueue 9from ttun_server.proxy_queue import ProxyQueue
17from ttun_server.types import RequestData, Config, Message, MessageType 10from ttun_server.types import HttpRequestData, Message, HttpMessageType, HttpMessage
18 11
19logger = logging.getLogger(__name__) 12logger = logging.getLogger(__name__)
20 13
@@ -44,11 +37,11 @@ class Proxy(HTTPEndpoint):
44 37
45 logger.debug('PROXY %s%s ', subdomain, request.url) 38 logger.debug('PROXY %s%s ', subdomain, request.url)
46 await request_queue.enqueue( 39 await request_queue.enqueue(
47 Message( 40 HttpMessage(
48 type=MessageType.request.value, 41 type=HttpMessageType.request.value,
49 identifier=identifier, 42 identifier=identifier,
50 payload= 43 payload=
51 RequestData( 44 HttpRequestData(
52 method=request.method, 45 method=request.method,
53 path=str(request.url).replace(str(request.base_url), '/'), 46 path=str(request.url).replace(str(request.base_url), '/'),
54 headers=list(request.headers.items()), 47 headers=list(request.headers.items()),
@@ -78,61 +71,3 @@ class Health(HTTPEndpoint):
78 await response(self.scope, self.receive, self.send) 71 await response(self.scope, self.receive, self.send)
79 72
80 73
81class Tunnel(WebSocketEndpoint):
82 encoding = 'json'
83
84 def __init__(self, scope: Scope, receive: Receive, send: Send):
85 super().__init__(scope, receive, send)
86 self.request_task = None
87 self.config: Optional[Config] = None
88
89 async def handle_requests(self, websocket: WebSocket):
90 while request := await self.proxy_queue.dequeue():
91 create_task(websocket.send_json(request))
92
93 async def on_connect(self, websocket: WebSocket) -> None:
94 await websocket.accept()
95 self.config = await websocket.receive_json()
96
97 client_version = self.config.get('version', '1.0.0')
98 logger.debug('client_version %s', client_version)
99
100 if 'git' not in client_version and ttun_server.__version__ != 'development':
101 [client_major, *_] = [int(i) for i in client_version.split('.')[:3]]
102 [server_major, *_] = [int(i) for i in ttun_server.__version__.split('.')]
103
104 if client_major < server_major:
105 await websocket.close(4000, 'Your client is too old')
106
107 if client_major > server_major:
108 await websocket.close(4001, 'Your client is too new')
109
110
111 if self.config['subdomain'] is None \
112 or await ProxyQueue.has_connection(self.config['subdomain']):
113 self.config['subdomain'] = uuid4().hex
114
115
116 self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain'])
117
118 hostname = os.environ.get("TUNNEL_DOMAIN")
119 protocol = "https" if os.environ.get("SECURE", False) else "http"
120
121 await websocket.send_json({
122 'url': f'{protocol}://{self.config["subdomain"]}.{hostname}'
123 })
124
125 self.request_task = asyncio.create_task(self.handle_requests(websocket))
126
127 async def on_receive(self, websocket: WebSocket, data: Message):
128 try:
129 response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}")
130 await response_queue.enqueue(data)
131 except AssertionError:
132 pass
133
134 async def on_disconnect(self, websocket: WebSocket, close_code: int):
135 await self.proxy_queue.delete()
136
137 if self.request_task is not None:
138 self.request_task.cancel()
diff --git a/ttun_server/proxy_queue.py b/ttun_server/proxy_queue.py
index e521886..c6c8067 100644
--- a/ttun_server/proxy_queue.py
+++ b/ttun_server/proxy_queue.py
@@ -2,6 +2,7 @@ import asyncio
2import json 2import json
3import logging 3import logging
4import os 4import os
5import traceback
5from typing import Type 6from typing import Type
6 7
7from ttun_server.redis import RedisConnectionPool 8from ttun_server.redis import RedisConnectionPool
diff --git a/ttun_server/redis.py b/ttun_server/redis.py
index 3065dec..18fbca2 100644
--- a/ttun_server/redis.py
+++ b/ttun_server/redis.py
@@ -1,6 +1,8 @@
1import asyncio
1import os 2import os
3from asyncio import get_running_loop
2 4
3from aioredis import ConnectionPool, Redis 5from redis.asyncio import ConnectionPool, Redis
4 6
5 7
6class RedisConnectionPool: 8class RedisConnectionPool:
@@ -9,9 +11,6 @@ class RedisConnectionPool:
9 def __init__(self): 11 def __init__(self):
10 self.pool = ConnectionPool.from_url(os.environ.get('REDIS_URL')) 12 self.pool = ConnectionPool.from_url(os.environ.get('REDIS_URL'))
11 13
12 def __del__(self):
13 self.pool.disconnect()
14
15 @classmethod 14 @classmethod
16 def get_connection(cls) -> Redis: 15 def get_connection(cls) -> Redis:
17 if cls.instance is None: 16 if cls.instance is None:
diff --git a/ttun_server/types.py b/ttun_server/types.py
index 8a4d929..8591e7d 100644
--- a/ttun_server/types.py
+++ b/ttun_server/types.py
@@ -3,7 +3,7 @@ from enum import Enum
3from typing import TypedDict, Optional 3from