2022-09-10 11:57:23 +00:00
|
|
|
|
import asyncio
|
2022-07-26 10:07:31 +00:00
|
|
|
|
import time
|
|
|
|
|
from functools import wraps
|
2022-09-10 11:57:23 +00:00
|
|
|
|
from typing import Callable, cast, Optional, Any
|
2022-07-26 10:07:31 +00:00
|
|
|
|
|
|
|
|
|
from telegram import Update
|
|
|
|
|
from telegram.ext import filters, CallbackContext
|
|
|
|
|
|
2022-09-08 01:08:37 +00:00
|
|
|
|
from utils.log import logger
|
2022-07-26 10:07:31 +00:00
|
|
|
|
|
2022-09-10 11:57:23 +00:00
|
|
|
|
_lock = asyncio.Lock()
|
2022-07-26 10:07:31 +00:00
|
|
|
|
|
2022-09-10 11:57:23 +00:00
|
|
|
|
|
2022-10-10 11:07:28 +00:00
|
|
|
|
def restricts(
|
|
|
|
|
restricts_time: int = 9,
|
|
|
|
|
restricts_time_of_groups: Optional[int] = None,
|
|
|
|
|
return_data: Any = None,
|
|
|
|
|
without_overlapping: bool = False,
|
|
|
|
|
):
|
2022-07-26 10:07:31 +00:00
|
|
|
|
"""用于装饰在指定函数预防洪水攻击的装饰器
|
|
|
|
|
|
|
|
|
|
被修饰的函数生声明必须为
|
|
|
|
|
|
|
|
|
|
async def command_func(update, context)
|
|
|
|
|
或
|
|
|
|
|
async def command_func(self, update, context
|
|
|
|
|
|
|
|
|
|
如果修饰的函数属于
|
|
|
|
|
ConversationHandler
|
|
|
|
|
参数
|
|
|
|
|
return_data
|
|
|
|
|
必须传入
|
|
|
|
|
ConversationHandler.END
|
|
|
|
|
|
|
|
|
|
我真™是服了某些闲着没事干的群友了
|
|
|
|
|
|
2022-09-10 11:57:23 +00:00
|
|
|
|
:param restricts_time: 基础限制时间
|
|
|
|
|
:param restricts_time_of_groups: 对群限制的时间
|
|
|
|
|
:param return_data: 返回的数据对于 ConversationHandler 需要传入 ConversationHandler.END
|
|
|
|
|
:param without_overlapping: 两次命令时间不覆盖,在上一条一样的命令返回之前,忽略重复调用
|
2022-07-26 10:07:31 +00:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def decorator(func: Callable):
|
|
|
|
|
@wraps(func)
|
|
|
|
|
async def restricts_func(*args, **kwargs):
|
|
|
|
|
if len(args) == 3:
|
2022-09-08 01:08:37 +00:00
|
|
|
|
# self update context
|
|
|
|
|
_, update, context = args
|
2022-07-26 10:07:31 +00:00
|
|
|
|
elif len(args) == 2:
|
2022-09-08 01:08:37 +00:00
|
|
|
|
# update context
|
|
|
|
|
update, context = args
|
2022-07-26 10:07:31 +00:00
|
|
|
|
else:
|
|
|
|
|
return await func(*args, **kwargs)
|
2022-09-08 01:08:37 +00:00
|
|
|
|
update = cast(Update, update)
|
|
|
|
|
context = cast(CallbackContext, context)
|
|
|
|
|
message = update.effective_message
|
2022-07-26 10:07:31 +00:00
|
|
|
|
user = update.effective_user
|
2022-09-10 11:57:23 +00:00
|
|
|
|
|
|
|
|
|
_restricts_time = restricts_time
|
|
|
|
|
if restricts_time_of_groups is not None:
|
|
|
|
|
if filters.ChatType.GROUPS.filter(message):
|
|
|
|
|
_restricts_time = restricts_time_of_groups
|
|
|
|
|
|
|
|
|
|
async with _lock:
|
|
|
|
|
user_lock = context.user_data.get("lock")
|
|
|
|
|
if user_lock is None:
|
|
|
|
|
user_lock = context.user_data["lock"] = asyncio.Lock()
|
|
|
|
|
|
|
|
|
|
# 如果上一个命令还未完成,忽略后续重复调用
|
|
|
|
|
if without_overlapping and user_lock.locked():
|
|
|
|
|
logger.warning(f"用户 {user.full_name}[{user.id}] 触发 overlapping 该次命令已忽略")
|
|
|
|
|
return return_data
|
|
|
|
|
|
|
|
|
|
async with user_lock:
|
2022-07-26 10:07:31 +00:00
|
|
|
|
command_time = context.user_data.get("command_time", 0)
|
|
|
|
|
count = context.user_data.get("usage_count", 0)
|
|
|
|
|
restrict_since = context.user_data.get("restrict_since", 0)
|
2022-09-10 11:57:23 +00:00
|
|
|
|
|
2022-07-26 10:07:31 +00:00
|
|
|
|
# 洪水防御
|
|
|
|
|
if restrict_since:
|
|
|
|
|
if (time.time() - restrict_since) >= 60 * 5:
|
|
|
|
|
del context.user_data["restrict_since"]
|
|
|
|
|
del context.user_data["usage_count"]
|
|
|
|
|
else:
|
|
|
|
|
return return_data
|
|
|
|
|
else:
|
2022-09-10 11:57:23 +00:00
|
|
|
|
if count >= 6:
|
2022-07-26 10:07:31 +00:00
|
|
|
|
context.user_data["restrict_since"] = time.time()
|
2022-09-08 01:08:37 +00:00
|
|
|
|
await message.reply_text("你已经触发洪水防御,请等待5分钟")
|
|
|
|
|
logger.warning(f"用户 {user.full_name}[{user.id}] 触发洪水限制 已被限制5分钟")
|
2022-07-26 10:07:31 +00:00
|
|
|
|
return return_data
|
|
|
|
|
# 单次使用限制
|
|
|
|
|
if command_time:
|
2022-09-10 11:57:23 +00:00
|
|
|
|
if (time.time() - command_time) <= _restricts_time:
|
2022-07-26 10:07:31 +00:00
|
|
|
|
context.user_data["usage_count"] = count + 1
|
|
|
|
|
else:
|
|
|
|
|
if count >= 1:
|
|
|
|
|
context.user_data["usage_count"] = count - 1
|
|
|
|
|
context.user_data["command_time"] = time.time()
|
|
|
|
|
|
2022-09-10 11:57:23 +00:00
|
|
|
|
# 只需要给 without_overlapping 的代码加锁运行
|
|
|
|
|
if without_overlapping:
|
|
|
|
|
return await func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
if count > 1:
|
|
|
|
|
await asyncio.sleep(count)
|
2022-07-26 10:07:31 +00:00
|
|
|
|
return await func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
return restricts_func
|
|
|
|
|
|
|
|
|
|
return decorator
|