summary refs log tree commit diff stats
path: root/ttun_server/endpoints.py
blob: d59cb7cc88e824d3d90693b7365b65b8b5dbc1e1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import asyncio
import os
from asyncio import Queue
from base64 import b64decode, b64encode
from typing import Optional, Any
from uuid import uuid4

from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket

from ttun_server.types import Connection, RequestData, Config, ResponseData

from ttun_server.connections import connections


class Proxy(HTTPEndpoint):
    """HTTP endpoint that relays a request to the tunnel registered for its subdomain.

    The subdomain is the first dot-separated label of the ``Host`` header.
    When ``connections`` holds an entry under that label, the request is
    serialised into a ``RequestData`` and put on the entry's ``requests``
    queue; the next item taken from its ``responses`` queue is then sent back
    to the caller. In every other case a plain ``404 Not Found`` is returned.
    """

    async def dispatch(self) -> None:
        """Route one incoming HTTP request (overrides ``HTTPEndpoint.dispatch``)."""
        request = Request(self.scope, self.receive)

        # Default answer, used when no tunnel matches the request.
        response = Response(content='Not Found', status_code=404)

        # A request without a Host header cannot be routed; let it fall
        # through to the 404 above instead of raising KeyError.
        host = request.headers.get('host')
        subdomain = host.split('.')[0] if host else None

        if subdomain is not None and subdomain in connections:
            connection = connections[subdomain]

            await connection['requests'].put(RequestData(
                method=request.method,
                # Turn the absolute URL into a path (plus query string) by
                # swapping the leading base URL for '/'. Only the first
                # occurrence is replaced so that a query parameter which
                # happens to contain the base URL is forwarded untouched.
                path=str(request.url).replace(str(request.base_url), '/', 1),
                headers=dict(request.headers),
                cookies=dict(request.cookies),
                # The body is carried as base64 text so arbitrary bytes
                # survive the JSON transport used by the tunnel.
                body=b64encode(await request.body()).decode()
            ))

            # NOTE(review): the response is simply the next item on the
            # shared queue; nothing here correlates it with the request that
            # was just queued, and there is no timeout on the wait.
            _response = await connection['responses'].get()
            response = Response(
                status_code=_response['status'],
                headers=_response['headers'],
                content=b64decode(_response['body'].encode())
            )

        await response(self.scope, self.receive, self.send)


class Tunnel(WebSocketEndpoint):
    """WebSocket endpoint over which a tunnel client receives requests and returns responses.

    On connect the first JSON message from the client is stored as the
    tunnel's config, a pair of queues is registered in ``connections`` under
    the chosen subdomain, and the public URL is sent back to the client.
    A background task then forwards every item from the ``requests`` queue to
    the client, while each JSON message received from the client is put on
    the ``responses`` queue.
    """

    encoding = 'json'

    # Spellings of the SECURE environment variable that select plain http.
    # Any other non-empty value selects https.
    _FALSY_ENV_VALUES = ('', '0', 'false', 'no', 'off')

    def __init__(self, scope: Scope, receive: Receive, send: Send):
        super().__init__(scope, receive, send)
        # Task running handle_requests(); created in on_connect().
        self.request_task: Optional[asyncio.Task] = None
        # First JSON message sent by the client; set in on_connect().
        self.config: Optional[Config] = None

    @property
    def requests(self) -> Queue[RequestData]:
        """Queue of requests waiting to be forwarded to this tunnel's client."""
        return connections[self.config['subdomain']]['requests']

    @property
    def responses(self) -> Queue[ResponseData]:
        """Queue of responses received from this tunnel's client."""
        return connections[self.config['subdomain']]['responses']

    async def handle_requests(self, websocket: WebSocket):
        """Forward queued requests to the client until cancelled.

        The loop also ends if a falsy item is taken from the queue; the
        task is otherwise stopped by cancellation in on_disconnect().
        """
        while request := await self.requests.get():
            await websocket.send_json(request)

    async def on_connect(self, websocket: WebSocket) -> None:
        """Accept the socket, register the tunnel and report its public URL."""
        await websocket.accept()
        self.config = await websocket.receive_json()

        # Fall back to a random subdomain when the client did not supply a
        # usable one (missing key, None, non-string or empty) or when the
        # requested one is already taken by another tunnel.
        subdomain = self.config.get('subdomain')
        if not isinstance(subdomain, str) or not subdomain \
                or subdomain in connections:
            subdomain = uuid4().hex
        self.config['subdomain'] = subdomain

        connections[subdomain] = Connection(
            requests=Queue(),
            responses=Queue(),
        )

        # If TUNNEL_DOMAIN is unset the URL below contains the text "None".
        hostname = os.environ.get("TUNNEL_DOMAIN")

        # Environment values are strings, so "0" or "false" would be truthy
        # if tested directly; compare against the known falsy spellings.
        secure = os.environ.get("SECURE", "").strip().lower()
        protocol = "http" if secure in self._FALSY_ENV_VALUES else "https"

        await websocket.send_json({
            'url': f'{protocol}://{subdomain}.{hostname}'
        })

        self.request_task = asyncio.create_task(self.handle_requests(websocket))

    async def on_receive(self, websocket: WebSocket, data: Any) -> None:
        """Queue a JSON message from the client as a response."""
        await self.responses.put(data)

    async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
        """Unregister the tunnel and stop the forwarding task."""
        # NOTE(review): nothing here wakes callers that are still waiting on
        # this connection's ``responses`` queue.
        if self.config is not None and self.config['subdomain'] in connections:
            del connections[self.config['subdomain']]

        if self.request_task is not None:
            self.request_task.cancel()