From 520cc7ae3ec44b7db2b103045badc81a275eb245 Mon Sep 17 00:00:00 2001 From: xtaodada Date: Tue, 7 Nov 2023 14:14:47 +0800 Subject: [PATCH] feat: support fastapi proxy --- main.py | 77 +++++++++++++++++++++++++++++++++++------------- requirements.txt | 8 ++--- 2 files changed, 61 insertions(+), 24 deletions(-) diff --git a/main.py b/main.py index 8a36629..0a17807 100644 --- a/main.py +++ b/main.py @@ -1,31 +1,68 @@ +from typing import Dict + from fastapi import FastAPI -from httpx import AsyncClient +from httpx import ( + AsyncClient, + RemoteProtocolError, + UnsupportedProtocol, + Response as HttpxResponse, + ConnectError, +) from starlette.requests import Request from starlette.responses import Response app = FastAPI() -client = AsyncClient( - timeout=60.0, +WHITE_LIST = ["mihoyo.com", "miyoushe.com"] + + +def rewrite_headers(old_headers: Dict[str, str]) -> Dict[str, str]: + remove_keys = ["host"] + headers = {} + for k, v in old_headers.items(): + if k.lower() not in remove_keys: + headers[k] = v + return headers + + +async def get_proxy(req: Request) -> Response: + path = req.path_params.get("path") + if not path: + return Response(status_code=400, content="path is required") + for domain in WHITE_LIST: + if domain in path: + break + else: + return Response(status_code=400, content="domain is not allowed") + query = str(req.query_params) + headers = rewrite_headers(dict(req.headers)) + method = req.method + try: + body = await req.body() + except Exception as e: + return Response(status_code=400, content=f"get request body info error: {e}") + q = "?" + query if query else "" + target_url = path + q + async with AsyncClient(timeout=120, follow_redirects=True) as client: + try: + async with client.stream( + method, target_url, headers=headers, data=body + ) as r: + headers = dict(r.headers) + r: HttpxResponse + _content = b"".join([part async for part in r.aiter_raw(1024 * 10)]) + return Response(_content, headers=headers, status_code=r.status_code) + except (RemoteProtocolError, UnsupportedProtocol, ConnectError): + return Response(content="UnsupportedProtocol", status_code=400) + + +app.add_route( + "/{path:path}", + get_proxy, + methods=["OPTIONS", "HEAD", "GET", "POST", "PUT", "PATCH", "DELETE"], ) -@app.get('/upload/{file_path:path}') -async def proxy_cve_search_api(req: Request, file_path: str) -> Response: - headers = {} - for i in req.headers.items(): - headers[i[0]] = i[1] - headers["host"] = "upload-bbs.miyoushe.com" - resp = await client.get( - f'https://upload-bbs.miyoushe.com/upload/{file_path}', - params=req.query_params, - headers=headers, - follow_redirects=True, - ) - content = resp.content - return Response(content=content, status_code=resp.status_code, headers=dict(resp.headers)) - - -if __name__ == '__main__': +if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=5677) diff --git a/requirements.txt b/requirements.txt index 31beee3..f5a77d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -httpx==0.24.1 -fastapi~=0.95.2 -starlette~=0.27.0 -uvicorn~=0.22.0 +httpx==0.25.1 +fastapi~=0.104.1 +starlette~=0.32.0.post1 +uvicorn~=0.24.0.post1