From 83638b03233cf92d572f0a61acade3a854083338 Mon Sep 17 00:00:00 2001 From: xtaodada Date: Wed, 12 Jun 2024 23:09:19 +0800 Subject: [PATCH] feat: support hoyolab --- .gitignore | 1 + env.py.example | 1 + main.py | 33 ++++++++++++++++++++------------- requirements.txt | 8 ++++---- 4 files changed, 26 insertions(+), 17 deletions(-) create mode 100644 env.py.example diff --git a/.gitignore b/.gitignore index 2dc53ca..aabac04 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ +env.py diff --git a/env.py.example b/env.py.example new file mode 100644 index 0000000..ec18309 --- /dev/null +++ b/env.py.example @@ -0,0 +1 @@ +proxy = "socks5://127.0.0.1:7676" diff --git a/main.py b/main.py index 0a17807..5acae14 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Optional from fastapi import FastAPI from httpx import ( @@ -10,9 +10,26 @@ from httpx import ( ) from starlette.requests import Request from starlette.responses import Response +from env import proxy app = FastAPI() -WHITE_LIST = ["mihoyo.com", "miyoushe.com"] +WHITE_LIST = ["mihoyo.com", "miyoushe.com", "hoyolab.com", "hoyoverse.com"] + + +async def req_client(method: str, target_url: str, headers, body, _proxy: Optional[str]) -> Response: + async with AsyncClient(timeout=120, follow_redirects=True, proxy=proxy) as client: + try: + async with client.stream( + method, target_url, headers=headers, data=body + ) as r: + headers = dict(r.headers) + r: HttpxResponse + _content = b"".join([part async for part in r.aiter_raw(1024 * 10)]) + return Response(_content, headers=headers, status_code=r.status_code) + except (RemoteProtocolError, UnsupportedProtocol, ConnectError): + if _proxy is not None: + return await req_client(method, target_url, headers, body, None) + return Response(content="UnsupportedProtocol", status_code=400) def rewrite_headers(old_headers: Dict[str, str]) -> Dict[str, str]: @@ -42,17 +59,7 @@ async def get_proxy(req: Request) -> Response: return Response(status_code=400, content=f"get request body info error: {e}") q = "?" + query if query else "" target_url = path + q - async with AsyncClient(timeout=120, follow_redirects=True) as client: - try: - async with client.stream( - method, target_url, headers=headers, data=body - ) as r: - headers = dict(r.headers) - r: HttpxResponse - _content = b"".join([part async for part in r.aiter_raw(1024 * 10)]) - return Response(_content, headers=headers, status_code=r.status_code) - except (RemoteProtocolError, UnsupportedProtocol, ConnectError): - return Response(content="UnsupportedProtocol", status_code=400) + return await req_client(method, target_url, headers, body, proxy) app.add_route( diff --git a/requirements.txt b/requirements.txt index f5a77d3..a9f51d9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -httpx==0.25.1 -fastapi~=0.104.1 -starlette~=0.32.0.post1 -uvicorn~=0.24.0.post1 +httpx +fastapi +starlette +uvicorn