mirror of
https://github.com/PaiGramTeam/GramCore.git
synced 2024-12-04 02:43:35 +00:00
♻️ Use AIORateLimiter Refactor RateLimiter
This commit is contained in:
parent
0ae6a40433
commit
fdaeaa19e9
125
ratelimiter.py
125
ratelimiter.py
@ -3,7 +3,8 @@ import contextlib
|
|||||||
from typing import Callable, Coroutine, Any, Union, List, Dict, Optional, TypeVar, Type, TYPE_CHECKING, Awaitable
|
from typing import Callable, Coroutine, Any, Union, List, Dict, Optional, TypeVar, Type, TYPE_CHECKING, Awaitable
|
||||||
|
|
||||||
from telegram.error import RetryAfter
|
from telegram.error import RetryAfter
|
||||||
from telegram.ext import BaseRateLimiter, ApplicationHandlerStop
|
from telegram.ext import AIORateLimiter
|
||||||
|
from telegram.ext._aioratelimiter import null_context
|
||||||
|
|
||||||
from utils.log import logger
|
from utils.log import logger
|
||||||
|
|
||||||
@ -15,19 +16,62 @@ RL_ARGS = TypeVar("RL_ARGS")
|
|||||||
T_CalledAPIFunc = Callable[[str, Dict[str, Any], Union[bool, JSONDict, List[JSONDict]]], Awaitable[Any]]
|
T_CalledAPIFunc = Callable[[str, Dict[str, Any], Union[bool, JSONDict, List[JSONDict]]], Awaitable[Any]]
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter(BaseRateLimiter[int]):
|
class RateLimiter(AIORateLimiter):
|
||||||
_lock = asyncio.Lock()
|
_lock = asyncio.Lock()
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"_limiter_info",
|
"_retry_after_event_map",
|
||||||
"_retry_after_event",
|
|
||||||
"_application",
|
"_application",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(
|
||||||
self._limiter_info: Dict[Union[str, int], float] = {}
|
self,
|
||||||
self._retry_after_event = asyncio.Event()
|
max_retries: int = 5,
|
||||||
self._retry_after_event.set()
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
max_retries=max_retries,
|
||||||
|
)
|
||||||
self._application: Optional["Application"] = None
|
self._application: Optional["Application"] = None
|
||||||
|
self._retry_after_event_map: Dict[int, asyncio.Event] = {0: asyncio.Event()}
|
||||||
|
self._retry_after_event_map[0].set()
|
||||||
|
|
||||||
|
def clear_group_retry_after_event(self, group: Union[str, int]) -> None:
|
||||||
|
for key, retry_after_event in self._retry_after_event_map.copy().items():
|
||||||
|
if key == group:
|
||||||
|
continue
|
||||||
|
if retry_after_event.is_set():
|
||||||
|
del self._retry_after_event_map[key]
|
||||||
|
|
||||||
|
async def _get_group_retry_after_event(self, group: Union[str, int]) -> asyncio.Event:
|
||||||
|
async with self._lock:
|
||||||
|
event = self._retry_after_event_map.get(group)
|
||||||
|
if event:
|
||||||
|
return event
|
||||||
|
if isinstance(group, (str, int)):
|
||||||
|
if len(self._retry_after_event_map) > 512:
|
||||||
|
self.clear_group_retry_after_event(group)
|
||||||
|
|
||||||
|
if group not in self._retry_after_event_map:
|
||||||
|
event = asyncio.Event()
|
||||||
|
event.set()
|
||||||
|
self._retry_after_event_map[group] = event
|
||||||
|
event = self._retry_after_event_map[group]
|
||||||
|
if not event:
|
||||||
|
event = self._retry_after_event_map[0]
|
||||||
|
return event
|
||||||
|
|
||||||
|
async def _run_request(
|
||||||
|
self,
|
||||||
|
chat: bool,
|
||||||
|
group: Union[str, int, bool],
|
||||||
|
callback: Callable[..., Coroutine[Any, Any, Union[bool, JSONDict, List[JSONDict]]]],
|
||||||
|
args: Any,
|
||||||
|
kwargs: Dict[str, Any],
|
||||||
|
) -> Union[bool, JSONDict, List[JSONDict]]:
|
||||||
|
base_context = self._base_limiter if (chat and self._base_limiter) else null_context()
|
||||||
|
group_context = self._get_group_limiter(group) if group and self._group_max_rate else null_context()
|
||||||
|
|
||||||
|
async with group_context, base_context:
|
||||||
|
return await callback(*args, **kwargs)
|
||||||
|
|
||||||
async def process_request(
|
async def process_request(
|
||||||
self,
|
self,
|
||||||
@ -38,44 +82,49 @@ class RateLimiter(BaseRateLimiter[int]):
|
|||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
rate_limit_args: Optional[RL_ARGS],
|
rate_limit_args: Optional[RL_ARGS],
|
||||||
) -> Union[bool, JSONDict, List[JSONDict]]:
|
) -> Union[bool, JSONDict, List[JSONDict]]:
|
||||||
chat_id = data.get("chat_id")
|
max_retries = rate_limit_args or self._max_retries
|
||||||
|
|
||||||
|
group: Union[int, str, bool] = False
|
||||||
|
chat: bool = False
|
||||||
|
chat_id = data.get("chat_id")
|
||||||
|
if chat_id is not None:
|
||||||
|
chat = True
|
||||||
|
|
||||||
|
# In case user passes integer chat id as string
|
||||||
with contextlib.suppress(ValueError, TypeError):
|
with contextlib.suppress(ValueError, TypeError):
|
||||||
chat_id = int(chat_id)
|
chat_id = int(chat_id)
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
if (isinstance(chat_id, int) and chat_id < 0) or isinstance(chat_id, str):
|
||||||
time = loop.time()
|
# string chat_id only works for channels and supergroups
|
||||||
|
# We can't really tell channels from groups though ...
|
||||||
|
group = chat_id
|
||||||
|
|
||||||
await self._retry_after_event.wait()
|
_retry_after_event = await self._get_group_retry_after_event(group)
|
||||||
|
await _retry_after_event.wait()
|
||||||
|
|
||||||
async with self._lock:
|
for i in range(max_retries + 1):
|
||||||
chat_limit_time = self._limiter_info.get(chat_id)
|
try:
|
||||||
if chat_limit_time:
|
result = await self._run_request(chat=chat, group=group, callback=callback, args=args, kwargs=kwargs)
|
||||||
if time >= chat_limit_time:
|
await self._on_called_api(endpoint, data, result)
|
||||||
raise ApplicationHandlerStop
|
return result
|
||||||
del self._limiter_info[chat_id]
|
except RetryAfter as exc:
|
||||||
|
if endpoint == "setWebhook" and exc.retry_after == 1:
|
||||||
|
# webhook 已被正确设置
|
||||||
|
return True
|
||||||
|
|
||||||
try:
|
if i == max_retries:
|
||||||
result = await callback(*args, **kwargs)
|
logger.warning("chat_id[%s] 达到最大重试限制 max_retries[%s]", chat_id, exc)
|
||||||
await self._on_called_api(endpoint, data, result)
|
raise exc
|
||||||
return result
|
|
||||||
except RetryAfter as exc:
|
|
||||||
if endpoint == "setWebhook" and exc.retry_after == 1:
|
|
||||||
# webhook 已被正确设置
|
|
||||||
return True
|
|
||||||
logger.warning("chat_id[%s] 触发洪水限制 当前被服务器限制 retry_after[%s]秒", chat_id, exc.retry_after)
|
|
||||||
self._limiter_info[chat_id] = time + (exc.retry_after * 2)
|
|
||||||
sleep = exc.retry_after + 0.1
|
|
||||||
self._retry_after_event.clear()
|
|
||||||
await asyncio.sleep(sleep)
|
|
||||||
finally:
|
|
||||||
self._retry_after_event.set()
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
sleep = exc.retry_after + 0.1
|
||||||
pass
|
logger.warning("chat_id[%s] 触发洪水限制 当前被服务器限制 retry_after[%s]秒", chat_id, exc.retry_after)
|
||||||
|
# Make sure we don't allow other requests to be processed
|
||||||
async def shutdown(self) -> None:
|
_retry_after_event.clear()
|
||||||
pass
|
await asyncio.sleep(sleep)
|
||||||
|
finally:
|
||||||
|
# Allow other requests to be processed
|
||||||
|
_retry_after_event.set()
|
||||||
|
return None # type: ignore[return-value]
|
||||||
|
|
||||||
def set_application(self, application: "Application") -> None:
|
def set_application(self, application: "Application") -> None:
|
||||||
self._application = application
|
self._application = application
|
||||||
|
Loading…
Reference in New Issue
Block a user