1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
import asyncio
import logging
import os
from base64 import b64decode, b64encode
from typing import Optional, Any
from uuid import uuid4
from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket
from ttun_server.proxy_queue import ProxyQueue
from ttun_server.types import RequestData, Config
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class HeaderMapping:
    """Read-only, ``items()``-only view over a list of ``(name, value)`` pairs.

    Instances are passed as ``headers=`` to a Starlette ``Response``. Keeping
    the raw pair list means repeated header names are preserved in order,
    which a plain ``dict`` would collapse (presumably the reason this wrapper
    exists instead of ``dict(headers)``).
    NOTE(review): assumes the consumer only needs ``items()`` -- confirm
    against the Starlette version in use.
    """

    def __init__(self, headers: list[tuple[str, str]]):
        # Stored as-is; no copying or normalisation of names/values.
        self._headers = headers

    def items(self):
        """Yield each stored ``(name, value)`` pair in its original order."""
        yield from self._headers
class Proxy(HTTPEndpoint):
    """HTTP endpoint that forwards an incoming request through a tunnel.

    The left-most label of the ``Host`` header selects the tunnel. The request
    (method, path + query, headers, base64-encoded body) is pushed onto that
    tunnel's ``ProxyQueue``; the reply obtained from ``queue.handle_response()``
    (status, header pairs, base64-encoded body) is sent back to the caller.
    A plain-text 404 is returned when there is no ``Host`` header or when the
    queue layer raises ``AssertionError``.
    """

    async def dispatch(self) -> None:
        request = Request(self.scope, self.receive)
        response = Response(content='Not Found', status_code=404)

        host = request.headers.get('host')
        # Without a Host header no tunnel can be selected; answer 404 instead
        # of letting a KeyError surface as a 500.
        if host:
            # Only the first label is used, e.g. "abc.example.com" -> "abc".
            subdomain = host.split('.', 1)[0]
            try:
                queue = await ProxyQueue.get_for_identifier(subdomain)
                await queue.send_request(RequestData(
                    method=request.method,
                    path=self._upstream_path(request),
                    headers=list(request.headers.items()),
                    body=b64encode(await request.body()).decode(),
                ))
                _response = await queue.handle_response()
                response = Response(
                    status_code=_response['status'],
                    # HeaderMapping keeps repeated header names intact.
                    headers=HeaderMapping(_response['headers']),
                    content=b64decode(_response['body'].encode()),
                )
            except AssertionError:
                # NOTE(review): ProxyQueue appears to signal "no such tunnel"
                # with AssertionError (presumably from get_for_identifier);
                # confirm. The default 404 response is kept.
                logger.debug('No tunnel available for subdomain %r', subdomain)

        await response(self.scope, self.receive, self.send)

    @staticmethod
    def _upstream_path(request: Request) -> str:
        """Return the request URL relative to the app root, starting with '/'.

        Only the leading base URL is stripped; ``str.replace`` would also
        rewrite any copy of it embedded later in the URL (for example an
        unencoded redirect target in the query string).
        """
        url = str(request.url)
        base_url = str(request.base_url)
        if url.startswith(base_url):
            return '/' + url[len(base_url):]
        # Fallback mirrors the old behaviour when the base URL is absent.
        return url
class Tunnel(WebSocketEndpoint):
    """WebSocket endpoint through which a tunnel client registers and serves requests.

    On connect the client sends a JSON config. A ``ProxyQueue`` is created for
    the requested subdomain, or for a random hex id when none was requested or
    the requested one is already connected. The public URL is then sent back.
    A background task forwards queued requests to the client; every JSON
    message received from the client is passed to ``send_response``.
    """

    encoding = 'json'

    # SECURE values (after strip/lower) treated as "off"; any other non-empty
    # value selects https. Unset SECURE also means http.
    _FALSE_ENV_VALUES = frozenset({'', '0', 'false', 'no', 'off'})

    def __init__(self, scope: Scope, receive: Receive, send: Send):
        super().__init__(scope, receive, send)
        self.request_task: Optional[asyncio.Task] = None
        self.config: Optional[Config] = None
        # Assigned in on_connect; stays None until the queue has been created.
        self.proxy_queue: Optional[ProxyQueue] = None

    @classmethod
    def _public_url(cls, subdomain: str) -> str:
        """Build the externally visible URL for *subdomain* from the environment."""
        hostname = os.environ.get("TUNNEL_DOMAIN")
        secure_flag = os.environ.get("SECURE", "").strip().lower()
        # Previously any non-empty string (even "0"/"false") enabled https.
        protocol = "http" if secure_flag in cls._FALSE_ENV_VALUES else "https"
        return f'{protocol}://{subdomain}.{hostname}'

    async def handle_requests(self, websocket: WebSocket):
        """Forward requests from the queue to the client until a falsy one arrives."""
        while request := await self.proxy_queue.handle_request():
            await websocket.send_json(request)

    async def on_connect(self, websocket: WebSocket) -> None:
        await websocket.accept()
        self.config = await websocket.receive_json()

        requested = self.config.get('subdomain')
        # Missing, null or empty subdomain -> random id; same for one already in use.
        if not requested or await ProxyQueue.has_connection(requested):
            self.config['subdomain'] = uuid4().hex

        self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain'])
        try:
            await websocket.send_json({
                'url': self._public_url(self.config["subdomain"])
            })
            self.request_task = asyncio.create_task(self.handle_requests(websocket))
        except BaseException:
            # NOTE(review): Starlette's WebSocketEndpoint.dispatch appears to
            # call on_connect outside the try/finally that runs on_disconnect
            # (confirm for the pinned version), so release the queue here to
            # avoid leaving the subdomain registered forever.
            await self.proxy_queue.delete()
            self.proxy_queue = None
            raise

    async def on_receive(self, websocket: WebSocket, data: Any):
        # Each decoded JSON message from the client is a response to a proxied request.
        await self.proxy_queue.send_response(data)

    async def on_disconnect(self, websocket: WebSocket, close_code: int):
        # Cancel first so the forwarding task stops using the queue, and so a
        # failing delete() can no longer leave the task running.
        if self.request_task is not None:
            self.request_task.cancel()
        if self.proxy_queue is not None:
            await self.proxy_queue.delete()
|