feat: support fastapi proxy

This commit is contained in:
xtaodada 2023-11-07 14:14:47 +08:00
parent 8df84b165f
commit 520cc7ae3e
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
2 changed files with 61 additions and 24 deletions

77
main.py
View File

@ -1,31 +1,68 @@
from typing import Dict
from fastapi import FastAPI from fastapi import FastAPI
from httpx import AsyncClient from httpx import (
AsyncClient,
RemoteProtocolError,
UnsupportedProtocol,
Response as HttpxResponse,
ConnectError,
)
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
app = FastAPI() app = FastAPI()
client = AsyncClient( WHITE_LIST = ["mihoyo.com", "miyoushe.com"]
timeout=60.0,
def rewrite_headers(old_headers: Dict[str, str]) -> Dict[str, str]:
remove_keys = ["host"]
headers = {}
for k, v in old_headers.items():
if k.lower() not in remove_keys:
headers[k] = v
return headers
async def get_proxy(req: Request) -> Response:
path = req.path_params.get("path")
if not path:
return Response(status_code=400, content="path is required")
for domain in WHITE_LIST:
if domain in path:
break
else:
return Response(status_code=400, content="domain is not allowed")
query = str(req.query_params)
headers = rewrite_headers(dict(req.headers))
method = req.method
try:
body = await req.body()
except Exception as e:
return Response(status_code=400, content=f"get request body info error: {e}")
q = "?" + query if query else ""
target_url = path + q
async with AsyncClient(timeout=120, follow_redirects=True) as client:
try:
async with client.stream(
method, target_url, headers=headers, data=body
) as r:
headers = dict(r.headers)
r: HttpxResponse
_content = b"".join([part async for part in r.aiter_raw(1024 * 10)])
return Response(_content, headers=headers, status_code=r.status_code)
except (RemoteProtocolError, UnsupportedProtocol, ConnectError):
return Response(content="UnsupportedProtocol", status_code=400)
app.add_route(
"/{path:path}",
get_proxy,
methods=["OPTIONS", "HEAD", "GET", "POST", "PUT", "PATCH", "DELETE"],
) )
@app.get('/upload/{file_path:path}') if __name__ == "__main__":
async def proxy_cve_search_api(req: Request, file_path: str) -> Response:
headers = {}
for i in req.headers.items():
headers[i[0]] = i[1]
headers["host"] = "upload-bbs.miyoushe.com"
resp = await client.get(
f'https://upload-bbs.miyoushe.com/upload/{file_path}',
params=req.query_params,
headers=headers,
follow_redirects=True,
)
content = resp.content
return Response(content=content, status_code=resp.status_code, headers=dict(resp.headers))
if __name__ == '__main__':
import uvicorn import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5677) uvicorn.run(app, host="0.0.0.0", port=5677)

View File

@ -1,4 +1,4 @@
httpx==0.24.1 httpx==0.25.1
fastapi~=0.95.2 fastapi~=0.104.1
starlette~=0.27.0 starlette~=0.32.0.post1
uvicorn~=0.22.0 uvicorn~=0.24.0.post1