summaryrefslogtreecommitdiffstats
path: root/ttun_server/endpoints.py
diff options
context:
space:
mode:
Diffstat (limited to 'ttun_server/endpoints.py')
-rw-r--r--ttun_server/endpoints.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/ttun_server/endpoints.py b/ttun_server/endpoints.py
index 5b9e57f..b33fe65 100644
--- a/ttun_server/endpoints.py
+++ b/ttun_server/endpoints.py
@@ -1,7 +1,6 @@
1import asyncio 1import asyncio
2import logging 2import logging
3import os 3import os
4from asyncio import Queue
5from base64 import b64decode, b64encode 4from base64 import b64decode, b64encode
6from typing import Optional, Any 5from typing import Optional, Any
7from uuid import uuid4 6from uuid import uuid4
@@ -13,11 +12,20 @@ from starlette.types import Scope, Receive, Send
13from starlette.websockets import WebSocket 12from starlette.websockets import WebSocket
14 13
15from ttun_server.proxy_queue import ProxyQueue 14from ttun_server.proxy_queue import ProxyQueue
16from ttun_server.types import RequestData, Config, ResponseData 15from ttun_server.types import RequestData, Config
17 16
18logger = logging.getLogger(__name__) 17logger = logging.getLogger(__name__)
19 18
20 19
20class HeaderMapping:
21 def __init__(self, headers: list[tuple[str, str]]):
22 self._headers = headers
23
24 def items(self):
25 for header in self._headers:
26 yield header
27
28
21class Proxy(HTTPEndpoint): 29class Proxy(HTTPEndpoint):
22 async def dispatch(self) -> None: 30 async def dispatch(self) -> None:
23 request = Request(self.scope, self.receive) 31 request = Request(self.scope, self.receive)
@@ -31,15 +39,14 @@ class Proxy(HTTPEndpoint):
31 await queue.send_request(RequestData( 39 await queue.send_request(RequestData(
32 method=request.method, 40 method=request.method,
33 path=str(request.url).replace(str(request.base_url), '/'), 41 path=str(request.url).replace(str(request.base_url), '/'),
34 headers=dict(request.headers), 42 headers=list(request.headers.items()),
35 cookies=dict(request.cookies),
36 body=b64encode(await request.body()).decode() 43 body=b64encode(await request.body()).decode()
37 )) 44 ))
38 45
39 _response = await queue.handle_response() 46 _response = await queue.handle_response()
40 response = Response( 47 response = Response(
41 status_code=_response['status'], 48 status_code=_response['status'],
42 headers=_response['headers'], 49 headers=HeaderMapping(_response['headers']),
43 content=b64decode(_response['body'].encode()) 50 content=b64decode(_response['body'].encode())
44 ) 51 )
45 except AssertionError: 52 except AssertionError: