1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
|
import asyncio
import json
import logging
import os
import typing
from asyncio import create_task
from base64 import b64encode, b64decode
from contextlib import asynccontextmanager
from typing import Optional
from uuid import uuid4
from starlette.endpoints import WebSocketEndpoint
from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket
import ttun_server
from ttun_server.proxy_queue import ProxyQueue
from ttun_server.types import Config, Message, WebsocketMessageType, \
WebsocketConnectData, WebsocketMessage, WebsocketMessageData, WebsocketDisconnectData, MessageType
# Module-level logger. Its level is forced to DEBUG on this logger itself,
# independently of whatever level the root logger is configured with.
logger = logging.getLogger(name=__name__)
logger.setLevel(logging.DEBUG)
class WebsocketProxy(WebSocketEndpoint):
encoding = 'json'
websocket_listen_task = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.id = str(uuid4())
@asynccontextmanager
async def proxy(self, websocket: WebSocket, message: WebsocketMessage):
[subdomain, *_] = websocket.url.hostname.split('.')
expect_ack = WebsocketMessageType(message['type']) == WebsocketMessageType.connect
try:
request_queue = await ProxyQueue.get_for_identifier(subdomain)
await request_queue.enqueue(message)
if expect_ack:
response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{message["identifier"]}')
yield await response_queue.dequeue()
await response_queue.delete()
else:
yield
except AssertionError:
pass
async def listen_for_messages(self, websocket: WebSocket):
[subdomain, *_] = websocket.url.hostname.split('.')
print('listen', self.id)
response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{self.id}')
while True:
message: WebsocketMessage = await response_queue.dequeue()
logger.debug(message)
await websocket.send_text(b64decode(message['payload']['body'].encode()).decode())
async def on_connect(self, websocket: WebSocket) -> None:
message = WebsocketMessage(
type=WebsocketMessageType.connect.value,
identifier=self.id,
payload=WebsocketConnectData(
path=websocket.path_params['path'],
headers=[
(k.decode(), v.decode())
for k, v
in websocket.scope['headers']
],
)
)
async with self.proxy(websocket, message) as m:
type = WebsocketMessageType(m['type'])
if type == WebsocketMessageType.ack:
await super().on_connect(websocket)
self.websocket_listen_task = asyncio.create_task(self.listen_for_messages(websocket))
def callback(*args, **kwargs):
self.websocket_listen_task = None
self.websocket_listen_task.add_done_callback(callback)
async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
match data:
case dict():
data_bytes = json.dumps(data).encode()
case bytes():
data_bytes = data
case _:
data_bytes = data.encode()
message = WebsocketMessage(
type=WebsocketMessageType.message.value,
identifier=self.id,
payload=WebsocketMessageData(
body=b64encode(data_bytes).decode(),
)
)
async with self.proxy(websocket, message):
pass
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
message = WebsocketMessage(
type=WebsocketMessageType.disconnect.value,
identifier=self.id,
payload=WebsocketDisconnectData(
close_code=close_code,
)
)
async with self.proxy(websocket, message):
if self.websocket_listen_task is not None:
self.websocket_listen_task.cancel()
class Tunnel(WebSocketEndpoint):
    """Websocket endpoint a ttun client connects to in order to open a tunnel.

    After a version handshake the client is assigned a subdomain. Requests
    queued under that subdomain are pushed to the client, and the client's
    replies are routed to the response queue ``<subdomain>_<identifier>``.
    """

    encoding = 'json'

    def __init__(self, scope: Scope, receive: Receive, send: Send):
        super().__init__(scope, receive, send)
        self.request_task = None
        self.config: Optional[Config] = None
        # Only set once on_connect registers the tunnel; stays None when the
        # handshake is rejected, so on_disconnect must check it.
        self.proxy_queue = None

    async def handle_requests(self, websocket: WebSocket):
        """Push each request queued for this tunnel to the client, in order."""
        while request := await self.proxy_queue.dequeue():
            # Awaited directly (not spawned with create_task): keeps frames in
            # queue order and avoids unreferenced tasks being collected mid-send.
            await websocket.send_json(request)

    async def on_connect(self, websocket: WebSocket) -> None:
        """Run the version handshake, register a subdomain and start forwarding."""
        await websocket.accept()
        self.config = await websocket.receive_json()

        client_version = self.config.get('version', '1.0.0')
        logger.debug('client_version %s', client_version)

        # Only the major version has to match; 'git' client builds and a
        # 'development' server skip the check.
        if 'git' not in client_version and ttun_server.__version__ != 'development':
            [client_major, *_] = [int(i) for i in client_version.split('.')[:3]]
            [server_major, *_] = [int(i) for i in ttun_server.__version__.split('.')]

            if client_major < server_major:
                await websocket.close(4000, 'Your client is too old')
                return
            if client_major > server_major:
                await websocket.close(4001, 'Your client is too new')
                return

        # Fall back to a random subdomain when none was requested or the
        # requested one is already in use.
        if self.config.get('subdomain') is None \
                or await ProxyQueue.has_connection(self.config['subdomain']):
            self.config['subdomain'] = uuid4().hex

        self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain'])

        hostname = os.environ.get("TUNNEL_DOMAIN")
        # Treat common "off" spellings as false; previously any non-empty
        # value (even "0" or "false") selected https.
        secure = os.environ.get("SECURE", "").strip().lower() not in ("", "0", "false", "no", "off")
        protocol = "https" if secure else "http"

        await websocket.send_json({
            'url': f'{protocol}://{self.config["subdomain"]}.{hostname}'
        })

        self.request_task = asyncio.create_task(self.handle_requests(websocket))

    async def on_receive(self, websocket: WebSocket, data: Message):
        """Route a reply from the client to the response queue it belongs to."""
        try:
            response_queue = await ProxyQueue.get_for_identifier(
                f"{self.config['subdomain']}_{data['identifier']}"
            )
            await response_queue.enqueue(data)
        except AssertionError:
            # Deliberate best effort: the waiting side has already gone away.
            logger.debug('dropping reply for unknown identifier %s', data.get('identifier'))

    async def on_disconnect(self, websocket: WebSocket, close_code: int):
        """Stop forwarding requests and release this tunnel's queue."""
        # Cancel the consumer before deleting the queue it reads from.
        if self.request_task is not None:
            self.request_task.cancel()
        if self.proxy_queue is not None:
            await self.proxy_queue.delete()
|