summaryrefslogtreecommitdiffstats
path: root/ttun_server/endpoints.py
blob: 3e263da912cafc57b6524316521ee2842e03ba65 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import asyncio
import logging
import os
from asyncio import create_task
from base64 import b64decode, b64encode
from typing import Optional, Any
from uuid import uuid4

from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket

import ttun_server
from ttun_server.proxy_queue import ProxyQueue
from ttun_server.types import RequestData, Config, Message, MessageType

# Module-level logger; used below for debug tracing of proxied requests and
# of the version reported by connecting tunnel clients.
logger = logging.getLogger(__name__)


class HeaderMapping:
    """Header container that exposes a list of ``(name, value)`` pairs via ``items()``.

    Unlike a ``dict``, repeated header names (e.g. several ``set-cookie``
    entries) are preserved in their original order. It is passed as
    ``Response(headers=...)``; presumably Starlette only needs ``items()``
    from it — confirm when upgrading Starlette.
    """

    def __init__(self, headers: list[tuple[str, str]]):
        self._headers = headers

    def items(self):
        """Yield every stored ``(name, value)`` pair, duplicates included."""
        yield from self._headers


class Proxy(HTTPEndpoint):
    """Forward an incoming HTTP request through the tunnel of its subdomain.

    The first label of the ``Host`` header selects the tunnel. The request is
    serialized into a ``Message`` and enqueued on that subdomain's queue; the
    reply is awaited on a per-request queue named ``<subdomain>_<uuid>``,
    which the ``Tunnel`` endpoint fills from the client's answer. If an
    ``AssertionError`` is raised while proxying (presumably ``ProxyQueue``
    signalling that no tunnel exists for the subdomain — confirm), a 404
    ``Not Found`` is sent instead.
    """

    async def dispatch(self) -> None:
        request = Request(self.scope, self.receive)

        [subdomain, *_] = request.headers['host'].split('.')
        # Default answer, replaced only once the tunnel client responds.
        response = Response(content='Not Found', status_code=404)

        identifier = str(uuid4())
        response_queue = await ProxyQueue.create_for_identifier(f'{subdomain}_{identifier}')

        try:
            request_queue = await ProxyQueue.get_for_identifier(subdomain)

            logger.debug('PROXY %s%s ', subdomain, request.url)
            await request_queue.enqueue(
                Message(
                    type=MessageType.request.value,
                    identifier=identifier,
                    payload=RequestData(
                        method=request.method,
                        # Strip only the leading base URL: the query string may
                        # itself contain the base URL (e.g. ``?next=...``) and
                        # must be forwarded unchanged.
                        path=str(request.url).replace(str(request.base_url), '/', 1),
                        headers=list(request.headers.items()),
                        body=b64encode(await request.body()).decode(),
                    ),
                )
            )

            _response = await response_queue.dequeue()
            payload = _response['payload']
            response = Response(
                status_code=payload['status'],
                # HeaderMapping keeps duplicate header names (e.g. set-cookie).
                headers=HeaderMapping(payload['headers']),
                content=b64decode(payload['body'].encode()),
            )
        except AssertionError:
            # No reachable tunnel: fall through and send the default 404.
            pass
        finally:
            try:
                await response(self.scope, self.receive, self.send)
            finally:
                # Release the per-request queue even if sending the response
                # fails (e.g. the HTTP client already disconnected).
                await response_queue.delete()


class Health(HTTPEndpoint):
    """Health-check endpoint that answers ``GET`` with ``200 OK``."""

    async def get(self, _) -> Response:
        """Return a plain ``OK`` response.

        ``HTTPEndpoint.dispatch`` awaits this handler and then sends whatever
        it returns, so the response must be returned rather than sent here.
        Sending it directly and returning ``None`` made dispatch call ``None``
        afterwards and raise ``TypeError`` on every request.
        """
        return Response(content='OK', status_code=200)


class Tunnel(WebSocketEndpoint):
    """WebSocket endpoint through which a ttun client serves one subdomain.

    Handshake: the client sends a JSON config (``version``, ``subdomain``) and
    the server replies with the public URL. Afterwards, messages dequeued from
    the subdomain's ``ProxyQueue`` are forwarded to the client, and each JSON
    message the client sends back is routed to the per-request response queue
    ``<subdomain>_<identifier>`` that ``Proxy`` waits on.
    """

    encoding = 'json'

    def __init__(self, scope: Scope, receive: Receive, send: Send):
        super().__init__(scope, receive, send)
        self.request_task: Optional[asyncio.Task] = None
        self.config: Optional[Config] = None
        # Set in on_connect; stays None if the handshake is rejected, which
        # on_disconnect must tolerate.
        self.proxy_queue: Optional[ProxyQueue] = None
        # Strong references to in-flight sends: the event loop only keeps weak
        # references to tasks, so unreferenced ones may be garbage-collected.
        self._send_tasks: set[asyncio.Task] = set()

    async def handle_requests(self, websocket: WebSocket):
        """Forward queued proxy requests until ``dequeue`` returns a falsy value."""
        while request := await self.proxy_queue.dequeue():
            # Send concurrently so the dequeue loop is not blocked, but keep a
            # reference until the send completes.
            task = create_task(websocket.send_json(request))
            self._send_tasks.add(task)
            task.add_done_callback(self._send_tasks.discard)

    async def on_connect(self, websocket: WebSocket) -> None:
        """Accept the socket, validate the client and register its subdomain.

        Closes with code 4000/4001 when the client's major version is older or
        newer than the server's; in that case nothing is registered.
        """
        await websocket.accept()
        self.config = await websocket.receive_json()

        client_version = self.config.get('version', '1.0.0')
        logger.debug('client_version %s', client_version)

        if 'git' not in client_version and ttun_server.__version__ != 'development':
            # Only the major component is compared, so only it is parsed; this
            # also tolerates suffixes such as ``1.2.0rc1`` in later parts.
            client_major = int(client_version.split('.')[0])
            server_major = int(ttun_server.__version__.split('.')[0])

            if client_major < server_major:
                await websocket.close(4000, 'Your client is too old')
                return

            if client_major > server_major:
                await websocket.close(4001, 'Your client is too new')
                return

        # Assign a random subdomain when none was requested or the requested
        # one already has a live connection.
        if not self.config.get('subdomain') \
                or await ProxyQueue.has_connection(self.config['subdomain']):
            self.config['subdomain'] = uuid4().hex

        self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain'])

        hostname = os.environ.get("TUNNEL_DOMAIN")
        # Any non-empty SECURE value enables https except common "off"
        # spellings, so SECURE=false / SECURE=0 really mean http.
        secure = os.environ.get("SECURE", "").strip().lower()
        protocol = "https" if secure not in ("", "0", "false", "no", "off") else "http"

        await websocket.send_json({
            'url': f'{protocol}://{self.config["subdomain"]}.{hostname}'
        })

        self.request_task = asyncio.create_task(self.handle_requests(websocket))

    async def on_receive(self, websocket: WebSocket, data: Message):
        """Route a client response to the queue of the request it answers."""
        try:
            response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}")
            await response_queue.enqueue(data)
        except AssertionError:
            # Presumably the waiting Proxy request is already gone (its queue
            # was deleted); the late response is dropped — confirm.
            pass

    async def on_disconnect(self, websocket: WebSocket, close_code: int):
        """Release the subdomain queue and stop all forwarding work."""
        try:
            # proxy_queue is None when on_connect rejected the client.
            if self.proxy_queue is not None:
                await self.proxy_queue.delete()
        finally:
            if self.request_task is not None:
                self.request_task.cancel()

            for task in list(self._send_tasks):
                task.cancel()