diff --git a/plugins/genshin/gacha/gacha.py b/plugins/genshin/gacha/gacha.py index ae7a7c6..4ae7987 100644 --- a/plugins/genshin/gacha/gacha.py +++ b/plugins/genshin/gacha/gacha.py @@ -48,8 +48,7 @@ class Gacha(Plugin, BasePlugin): @handler(CommandHandler, command="gacha", block=False) @handler(MessageHandler, filters=filters.Regex("^深渊数据查询(.*)"), block=False) @handler(MessageHandler, filters=filters.Regex("^非首模拟器(.*)"), block=False) - @restricts(filters.ChatType.GROUPS, restricts_time=20, try_delete_message=True) - @restricts(filters.ChatType.PRIVATE) + @restricts(restricts_time=3, restricts_time_of_groups=20) @error_callable async def command_start(self, update: Update, context: CallbackContext) -> None: message = update.message diff --git a/plugins/genshin/player_cards.py b/plugins/genshin/player_cards.py index ba137f0..73dbcde 100644 --- a/plugins/genshin/player_cards.py +++ b/plugins/genshin/player_cards.py @@ -55,8 +55,7 @@ class PlayerCards(Plugin, BasePlugin): @handler(CommandHandler, command="player_card", block=False) @handler(MessageHandler, filters=filters.Regex("^角色卡片查询(.*)"), block=False) - @restricts(filters.ChatType.GROUPS, restricts_time=20, try_delete_message=True) - @restricts(filters.ChatType.PRIVATE) + @restricts(restricts_time_of_groups=20, without_overlapping=True) @error_callable async def player_cards(self, update: Update, context: CallbackContext) -> None: user = update.effective_user @@ -119,8 +118,7 @@ class PlayerCards(Plugin, BasePlugin): await message.reply_photo(pnd_data, filename=f"player_card_{uid}_{character_name}.png") @handler(CallbackQueryHandler, pattern=r"^get_player_card\|", block=False) - @restricts(filters.ChatType.GROUPS, restricts_time=6) - @restricts(filters.ChatType.PRIVATE, restricts_time=6) + @restricts(restricts_time_of_groups=20, without_overlapping=True) async def get_player_cards(self, update: Update, _: CallbackContext) -> None: callback_query = update.callback_query user = callback_query.from_user diff --git a/plugins/genshin/quiz.py b/plugins/genshin/quiz.py index 35d58db..05ba1f0 100644 --- a/plugins/genshin/quiz.py +++ b/plugins/genshin/quiz.py @@ -23,7 +23,7 @@ class QuizPlugin(Plugin, BasePlugin): self.random = MT19937Random() @handler(CommandHandler, command="quiz", block=False) - @restricts(restricts_time=20, try_delete_message=True) + @restricts(restricts_time_of_groups=20) async def command_start(self, update: Update, context: CallbackContext) -> None: user = update.effective_user message = update.effective_message diff --git a/utils/decorators/restricts.py b/utils/decorators/restricts.py index c09af60..9fccfce 100644 --- a/utils/decorators/restricts.py +++ b/utils/decorators/restricts.py @@ -1,16 +1,18 @@ +import asyncio import time from functools import wraps -from typing import Callable, cast +from typing import Callable, cast, Optional, Any from telegram import Update -from telegram.error import TelegramError from telegram.ext import filters, CallbackContext from utils.log import logger +_lock = asyncio.Lock() -def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_message: bool = False, - restricts_time: int = 5): + +def restricts(restricts_time: int = 9, restricts_time_of_groups: Optional[int] = None, return_data: Any = None, + without_overlapping: bool = False): """用于装饰在指定函数预防洪水攻击的装饰器 被修饰的函数生声明必须为 @@ -28,11 +30,10 @@ def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_ 我真™是服了某些闲着没事干的群友了 - :param filters_chat: 要限制的群 - :param return_data: - :param try_delete_message: - :param restricts_time: - :return: return_data + :param restricts_time: 基础限制时间 + :param restricts_time_of_groups: 对群限制的时间 + :param return_data: 返回的数据对于 ConversationHandler 需要传入 ConversationHandler.END + :param without_overlapping: 两次命令时间不覆盖,在上一条一样的命令返回之前,忽略重复调用 """ def decorator(func: Callable): @@ -50,10 +51,27 @@ def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_ context = cast(CallbackContext, context) message = update.effective_message user = update.effective_user - if filters_chat.filter(message): + + _restricts_time = restricts_time + if restricts_time_of_groups is not None: + if filters.ChatType.GROUPS.filter(message): + _restricts_time = restricts_time_of_groups + + async with _lock: + user_lock = context.user_data.get("lock") + if user_lock is None: + user_lock = context.user_data["lock"] = asyncio.Lock() + + # 如果上一个命令还未完成,忽略后续重复调用 + if without_overlapping and user_lock.locked(): + logger.warning(f"用户 {user.full_name}[{user.id}] 触发 overlapping 该次命令已忽略") + return return_data + + async with user_lock: command_time = context.user_data.get("command_time", 0) count = context.user_data.get("usage_count", 0) restrict_since = context.user_data.get("restrict_since", 0) + # 洪水防御 if restrict_since: if (time.time() - restrict_since) >= 60 * 5: @@ -62,29 +80,26 @@ def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_ else: return return_data else: - if count == 5: + if count >= 6: context.user_data["restrict_since"] = time.time() await message.reply_text("你已经触发洪水防御,请等待5分钟") logger.warning(f"用户 {user.full_name}[{user.id}] 触发洪水限制 已被限制5分钟") return return_data # 单次使用限制 if command_time: - if (time.time() - command_time) <= restricts_time: + if (time.time() - command_time) <= _restricts_time: context.user_data["usage_count"] = count + 1 - if filters.ChatType.GROUPS.filter(message): - if try_delete_message: - try: - await message.delete() - except TelegramError as exc: - logger.warning("删除消息失败") - logger.exception(exc) - return return_data else: if count >= 1: context.user_data["usage_count"] = count - 1 - context.user_data["command_time"] = time.time() + # 只需要给 without_overlapping 的代码加锁运行 + if without_overlapping: + return await func(*args, **kwargs) + + if count > 1: + await asyncio.sleep(count) return await func(*args, **kwargs) return restricts_func