GramCore/handler/limiterhandler.py
2024-01-17 13:53:09 +08:00

69 lines
2.9 KiB
Python

import asyncio
from typing import Optional
from telegram import Update
from telegram.ext import ContextTypes, ApplicationHandlerStop, TypeHandler
from utils.log import logger
class LimiterHandler(TypeHandler):
    """Per-user flood limiter implemented as a :class:`telegram.ext.TypeHandler`.

    Every incoming :class:`telegram.Update` (except inline queries) is metered
    against a leaky bucket kept in ``context.user_data``. When the bucket would
    overflow, :class:`telegram.ext.ApplicationHandlerStop` is raised and the
    user stays blocked until the stored deadline has passed.
    """

    # Class-level lock shared by all instances; guards the read-modify-write of
    # the per-user bucket state in ``limiter_callback``.
    _lock = asyncio.Lock()

    def __init__(
        self, max_rate: float = 5, time_period: float = 10, amount: float = 1, limit_time: Optional[float] = None
    ):
        """Limiter Handler controls user input precisely by means of the
        `Leaky bucket algorithm <https://en.wikipedia.org/wiki/Leaky_bucket>`_.

        Once input exceeds a certain rate, the handler raises
        :class:`telegram.ext.ApplicationHandlerStop`
        and prevents the user from performing any other action for a while.

        :param max_rate: Bucket capacity; the level may not exceed this value
            before the exception is raised.
        :param time_period: Duration (seconds) over which ``max_rate`` drains,
            i.e. the bucket leaks ``max_rate / time_period`` per second.
        :param amount: Capacity consumed by each accepted update.
        :param limit_time: How long (seconds) a user stays blocked. If not
            provided (or falsy), the block lasts
            ``time_period / max_rate * amount`` seconds.
        :raises ZeroDivisionError: If ``time_period`` is zero.
        """
        self.max_rate = max_rate
        self.amount = amount
        self._rate_per_sec = max_rate / time_period
        self.limit_time = limit_time
        super().__init__(Update, self.limiter_callback)

    async def limiter_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
        """Meter one update against the sending user's leaky bucket.

        State is stored in ``context.user_data`` under the keys
        ``"limit_time"`` (deadline of an active block), ``"last_task_time"``
        (timestamp of the last accepted update) and ``"task_level"`` (current
        bucket level). All timestamps come from the running loop's ``time()``.

        :param update: The incoming update; inline queries are never limited.
        :param context: Callback context; nothing is done when it carries no
            ``user_data``.
        :raises ApplicationHandlerStop: If the user is currently blocked or
            this update would overflow the bucket.
        """
        if update.inline_query is not None:
            return
        loop = asyncio.get_running_loop()
        async with self._lock:
            now = loop.time()
            user_data = context.user_data
            if user_data is None:
                return

            # An active block either expires here or rejects the update outright.
            blocked_until = user_data.get("limit_time")
            if blocked_until is not None:
                if now >= blocked_until:
                    del user_data["limit_time"]
                else:
                    raise ApplicationHandlerStop

            last_task_time = user_data.get("last_task_time", 0)
            if last_task_time:
                # Drain the bucket for the time that passed since the last
                # accepted update. The loop clock is monotonic but its origin
                # is arbitrary, so a timestamp that outlived the clock it was
                # taken from (NOTE(review): e.g. persisted user_data after a
                # restart - confirm whether persistence is enabled) could lie
                # in the future; clamp so the level can never grow here.
                elapsed = max(now - last_task_time, 0)
                leaked = elapsed * self._rate_per_sec
                task_level = max(user_data.get("task_level", 0) - leaked, 0)
                user_data["task_level"] = task_level

                if task_level + self.amount > self.max_rate:
                    if self.limit_time:
                        limit_time = self.limit_time
                    else:
                        # Default: the time the bucket needs to leak one `amount`.
                        limit_time = 1 / self._rate_per_sec * self.amount
                    user_data["limit_time"] = now + limit_time
                    user = update.effective_user
                    logger.warning("用户 %s[%s] 触发洪水限制 已被限制 %s", user.full_name, user.id, limit_time)
                    raise ApplicationHandlerStop

            # Accepted: remember when, and pour this update into the bucket.
            user_data["last_task_time"] = now
            user_data["task_level"] = user_data.get("task_level", 0) + self.amount