♻️ 重写 restricts 修饰器

* 🔧 `restricts` 修饰器增加 `no_overlapping` 参数,避免同一个 user 对一个 handler 的多次调用时间重叠
* ♻ 移除 `try_delete_message` 参数 使用 `sleep` 替代
* 使用 `asyncio.Lock` 保证 `context.user_data` 的计数器 `usage_count` 线程安全
* 增加 `no_overlapping` 参数,在上一条一样的命令返回之前,忽略重复调用

Co-authored-by: 洛水居室 <luoshuijs@outlook.com>
This commit is contained in:
Chuangbo Li 2022-09-10 19:57:23 +08:00 committed by GitHub
parent 649c647d9e
commit 87c7253e3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 28 deletions

View File

@ -48,8 +48,7 @@ class Gacha(Plugin, BasePlugin):
@handler(CommandHandler, command="gacha", block=False) @handler(CommandHandler, command="gacha", block=False)
@handler(MessageHandler, filters=filters.Regex("^深渊数据查询(.*)"), block=False) @handler(MessageHandler, filters=filters.Regex("^深渊数据查询(.*)"), block=False)
@handler(MessageHandler, filters=filters.Regex("^非首模拟器(.*)"), block=False) @handler(MessageHandler, filters=filters.Regex("^非首模拟器(.*)"), block=False)
@restricts(filters.ChatType.GROUPS, restricts_time=20, try_delete_message=True) @restricts(restricts_time=3, restricts_time_of_groups=20)
@restricts(filters.ChatType.PRIVATE)
@error_callable @error_callable
async def command_start(self, update: Update, context: CallbackContext) -> None: async def command_start(self, update: Update, context: CallbackContext) -> None:
message = update.message message = update.message

View File

@ -55,8 +55,7 @@ class PlayerCards(Plugin, BasePlugin):
@handler(CommandHandler, command="player_card", block=False) @handler(CommandHandler, command="player_card", block=False)
@handler(MessageHandler, filters=filters.Regex("^角色卡片查询(.*)"), block=False) @handler(MessageHandler, filters=filters.Regex("^角色卡片查询(.*)"), block=False)
@restricts(filters.ChatType.GROUPS, restricts_time=20, try_delete_message=True) @restricts(restricts_time_of_groups=20, without_overlapping=True)
@restricts(filters.ChatType.PRIVATE)
@error_callable @error_callable
async def player_cards(self, update: Update, context: CallbackContext) -> None: async def player_cards(self, update: Update, context: CallbackContext) -> None:
user = update.effective_user user = update.effective_user
@ -119,8 +118,7 @@ class PlayerCards(Plugin, BasePlugin):
await message.reply_photo(pnd_data, filename=f"player_card_{uid}_{character_name}.png") await message.reply_photo(pnd_data, filename=f"player_card_{uid}_{character_name}.png")
@handler(CallbackQueryHandler, pattern=r"^get_player_card\|", block=False) @handler(CallbackQueryHandler, pattern=r"^get_player_card\|", block=False)
@restricts(filters.ChatType.GROUPS, restricts_time=6) @restricts(restricts_time_of_groups=20, without_overlapping=True)
@restricts(filters.ChatType.PRIVATE, restricts_time=6)
async def get_player_cards(self, update: Update, _: CallbackContext) -> None: async def get_player_cards(self, update: Update, _: CallbackContext) -> None:
callback_query = update.callback_query callback_query = update.callback_query
user = callback_query.from_user user = callback_query.from_user

View File

@ -23,7 +23,7 @@ class QuizPlugin(Plugin, BasePlugin):
self.random = MT19937Random() self.random = MT19937Random()
@handler(CommandHandler, command="quiz", block=False) @handler(CommandHandler, command="quiz", block=False)
@restricts(restricts_time=20, try_delete_message=True) @restricts(restricts_time_of_groups=20)
async def command_start(self, update: Update, context: CallbackContext) -> None: async def command_start(self, update: Update, context: CallbackContext) -> None:
user = update.effective_user user = update.effective_user
message = update.effective_message message = update.effective_message

View File

@ -1,16 +1,18 @@
import asyncio
import time import time
from functools import wraps from functools import wraps
from typing import Callable, cast from typing import Callable, cast, Optional, Any
from telegram import Update from telegram import Update
from telegram.error import TelegramError
from telegram.ext import filters, CallbackContext from telegram.ext import filters, CallbackContext
from utils.log import logger from utils.log import logger
_lock = asyncio.Lock()
def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_message: bool = False,
restricts_time: int = 5): def restricts(restricts_time: int = 9, restricts_time_of_groups: Optional[int] = None, return_data: Any = None,
without_overlapping: bool = False):
"""用于装饰在指定函数预防洪水攻击的装饰器 """用于装饰在指定函数预防洪水攻击的装饰器
被修饰的函数生声明必须为 被修饰的函数生声明必须为
@ -28,11 +30,10 @@ def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_
我真是服了某些闲着没事干的群友了 我真是服了某些闲着没事干的群友了
:param filters_chat: 要限制的群 :param restricts_time: 基础限制时间
:param return_data: :param restricts_time_of_groups: 对群限制的时间
:param try_delete_message: :param return_data: 返回的数据对于 ConversationHandler 需要传入 ConversationHandler.END
:param restricts_time: :param without_overlapping: 两次命令时间不覆盖在上一条一样的命令返回之前忽略重复调用
:return: return_data
""" """
def decorator(func: Callable): def decorator(func: Callable):
@ -50,10 +51,27 @@ def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_
context = cast(CallbackContext, context) context = cast(CallbackContext, context)
message = update.effective_message message = update.effective_message
user = update.effective_user user = update.effective_user
if filters_chat.filter(message):
_restricts_time = restricts_time
if restricts_time_of_groups is not None:
if filters.ChatType.GROUPS.filter(message):
_restricts_time = restricts_time_of_groups
async with _lock:
user_lock = context.user_data.get("lock")
if user_lock is None:
user_lock = context.user_data["lock"] = asyncio.Lock()
# 如果上一个命令还未完成,忽略后续重复调用
if without_overlapping and user_lock.locked():
logger.warning(f"用户 {user.full_name}[{user.id}] 触发 overlapping 该次命令已忽略")
return return_data
async with user_lock:
command_time = context.user_data.get("command_time", 0) command_time = context.user_data.get("command_time", 0)
count = context.user_data.get("usage_count", 0) count = context.user_data.get("usage_count", 0)
restrict_since = context.user_data.get("restrict_since", 0) restrict_since = context.user_data.get("restrict_since", 0)
# 洪水防御 # 洪水防御
if restrict_since: if restrict_since:
if (time.time() - restrict_since) >= 60 * 5: if (time.time() - restrict_since) >= 60 * 5:
@ -62,29 +80,26 @@ def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_
else: else:
return return_data return return_data
else: else:
if count == 5: if count >= 6:
context.user_data["restrict_since"] = time.time() context.user_data["restrict_since"] = time.time()
await message.reply_text("你已经触发洪水防御请等待5分钟") await message.reply_text("你已经触发洪水防御请等待5分钟")
logger.warning(f"用户 {user.full_name}[{user.id}] 触发洪水限制 已被限制5分钟") logger.warning(f"用户 {user.full_name}[{user.id}] 触发洪水限制 已被限制5分钟")
return return_data return return_data
# 单次使用限制 # 单次使用限制
if command_time: if command_time:
if (time.time() - command_time) <= restricts_time: if (time.time() - command_time) <= _restricts_time:
context.user_data["usage_count"] = count + 1 context.user_data["usage_count"] = count + 1
if filters.ChatType.GROUPS.filter(message):
if try_delete_message:
try:
await message.delete()
except TelegramError as exc:
logger.warning("删除消息失败")
logger.exception(exc)
return return_data
else: else:
if count >= 1: if count >= 1:
context.user_data["usage_count"] = count - 1 context.user_data["usage_count"] = count - 1
context.user_data["command_time"] = time.time() context.user_data["command_time"] = time.time()
# 只需要给 without_overlapping 的代码加锁运行
if without_overlapping:
return await func(*args, **kwargs)
if count > 1:
await asyncio.sleep(count)
return await func(*args, **kwargs) return await func(*args, **kwargs)
return restricts_func return restricts_func