mirror of
https://github.com/PaiGramTeam/PamGram.git
synced 2024-11-25 07:36:27 +00:00
♻️ 重写 restricts
修饰器
* 🔧 `restricts` 修饰器增加 `no_overlapping` 参数,避免同一个 user 对一个 handler 的多次调用时间重叠
* ♻ 移除 `try_delete_message` 参数 使用 `sleep` 替代
* 使用 `asyncio.Lock` 保证 `context.user_data` 的计数器 `usage_count` 线程安全
* 增加 `no_overlapping` 参数,在上一条一样的命令返回之前,忽略重复调用
Co-authored-by: 洛水居室 <luoshuijs@outlook.com>
This commit is contained in:
parent
649c647d9e
commit
87c7253e3a
@ -48,8 +48,7 @@ class Gacha(Plugin, BasePlugin):
|
||||
@handler(CommandHandler, command="gacha", block=False)
|
||||
@handler(MessageHandler, filters=filters.Regex("^深渊数据查询(.*)"), block=False)
|
||||
@handler(MessageHandler, filters=filters.Regex("^非首模拟器(.*)"), block=False)
|
||||
@restricts(filters.ChatType.GROUPS, restricts_time=20, try_delete_message=True)
|
||||
@restricts(filters.ChatType.PRIVATE)
|
||||
@restricts(restricts_time=3, restricts_time_of_groups=20)
|
||||
@error_callable
|
||||
async def command_start(self, update: Update, context: CallbackContext) -> None:
|
||||
message = update.message
|
||||
|
@ -55,8 +55,7 @@ class PlayerCards(Plugin, BasePlugin):
|
||||
|
||||
@handler(CommandHandler, command="player_card", block=False)
|
||||
@handler(MessageHandler, filters=filters.Regex("^角色卡片查询(.*)"), block=False)
|
||||
@restricts(filters.ChatType.GROUPS, restricts_time=20, try_delete_message=True)
|
||||
@restricts(filters.ChatType.PRIVATE)
|
||||
@restricts(restricts_time_of_groups=20, without_overlapping=True)
|
||||
@error_callable
|
||||
async def player_cards(self, update: Update, context: CallbackContext) -> None:
|
||||
user = update.effective_user
|
||||
@ -119,8 +118,7 @@ class PlayerCards(Plugin, BasePlugin):
|
||||
await message.reply_photo(pnd_data, filename=f"player_card_{uid}_{character_name}.png")
|
||||
|
||||
@handler(CallbackQueryHandler, pattern=r"^get_player_card\|", block=False)
|
||||
@restricts(filters.ChatType.GROUPS, restricts_time=6)
|
||||
@restricts(filters.ChatType.PRIVATE, restricts_time=6)
|
||||
@restricts(restricts_time_of_groups=20, without_overlapping=True)
|
||||
async def get_player_cards(self, update: Update, _: CallbackContext) -> None:
|
||||
callback_query = update.callback_query
|
||||
user = callback_query.from_user
|
||||
|
@ -23,7 +23,7 @@ class QuizPlugin(Plugin, BasePlugin):
|
||||
self.random = MT19937Random()
|
||||
|
||||
@handler(CommandHandler, command="quiz", block=False)
|
||||
@restricts(restricts_time=20, try_delete_message=True)
|
||||
@restricts(restricts_time_of_groups=20)
|
||||
async def command_start(self, update: Update, context: CallbackContext) -> None:
|
||||
user = update.effective_user
|
||||
message = update.effective_message
|
||||
|
@ -1,16 +1,18 @@
|
||||
import asyncio
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Callable, cast
|
||||
from typing import Callable, cast, Optional, Any
|
||||
|
||||
from telegram import Update
|
||||
from telegram.error import TelegramError
|
||||
from telegram.ext import filters, CallbackContext
|
||||
|
||||
from utils.log import logger
|
||||
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_message: bool = False,
|
||||
restricts_time: int = 5):
|
||||
|
||||
def restricts(restricts_time: int = 9, restricts_time_of_groups: Optional[int] = None, return_data: Any = None,
|
||||
without_overlapping: bool = False):
|
||||
"""用于装饰在指定函数预防洪水攻击的装饰器
|
||||
|
||||
被修饰的函数生声明必须为
|
||||
@ -28,11 +30,10 @@ def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_
|
||||
|
||||
我真™是服了某些闲着没事干的群友了
|
||||
|
||||
:param filters_chat: 要限制的群
|
||||
:param return_data:
|
||||
:param try_delete_message:
|
||||
:param restricts_time:
|
||||
:return: return_data
|
||||
:param restricts_time: 基础限制时间
|
||||
:param restricts_time_of_groups: 对群限制的时间
|
||||
:param return_data: 返回的数据对于 ConversationHandler 需要传入 ConversationHandler.END
|
||||
:param without_overlapping: 两次命令时间不覆盖,在上一条一样的命令返回之前,忽略重复调用
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
@ -50,10 +51,27 @@ def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_
|
||||
context = cast(CallbackContext, context)
|
||||
message = update.effective_message
|
||||
user = update.effective_user
|
||||
if filters_chat.filter(message):
|
||||
|
||||
_restricts_time = restricts_time
|
||||
if restricts_time_of_groups is not None:
|
||||
if filters.ChatType.GROUPS.filter(message):
|
||||
_restricts_time = restricts_time_of_groups
|
||||
|
||||
async with _lock:
|
||||
user_lock = context.user_data.get("lock")
|
||||
if user_lock is None:
|
||||
user_lock = context.user_data["lock"] = asyncio.Lock()
|
||||
|
||||
# 如果上一个命令还未完成,忽略后续重复调用
|
||||
if without_overlapping and user_lock.locked():
|
||||
logger.warning(f"用户 {user.full_name}[{user.id}] 触发 overlapping 该次命令已忽略")
|
||||
return return_data
|
||||
|
||||
async with user_lock:
|
||||
command_time = context.user_data.get("command_time", 0)
|
||||
count = context.user_data.get("usage_count", 0)
|
||||
restrict_since = context.user_data.get("restrict_since", 0)
|
||||
|
||||
# 洪水防御
|
||||
if restrict_since:
|
||||
if (time.time() - restrict_since) >= 60 * 5:
|
||||
@ -62,29 +80,26 @@ def restricts(filters_chat: filters = filters.ALL, return_data=None, try_delete_
|
||||
else:
|
||||
return return_data
|
||||
else:
|
||||
if count == 5:
|
||||
if count >= 6:
|
||||
context.user_data["restrict_since"] = time.time()
|
||||
await message.reply_text("你已经触发洪水防御,请等待5分钟")
|
||||
logger.warning(f"用户 {user.full_name}[{user.id}] 触发洪水限制 已被限制5分钟")
|
||||
return return_data
|
||||
# 单次使用限制
|
||||
if command_time:
|
||||
if (time.time() - command_time) <= restricts_time:
|
||||
if (time.time() - command_time) <= _restricts_time:
|
||||
context.user_data["usage_count"] = count + 1
|
||||
if filters.ChatType.GROUPS.filter(message):
|
||||
if try_delete_message:
|
||||
try:
|
||||
await message.delete()
|
||||
except TelegramError as exc:
|
||||
logger.warning("删除消息失败")
|
||||
logger.exception(exc)
|
||||
return return_data
|
||||
else:
|
||||
if count >= 1:
|
||||
context.user_data["usage_count"] = count - 1
|
||||
|
||||
context.user_data["command_time"] = time.time()
|
||||
|
||||
# 只需要给 without_overlapping 的代码加锁运行
|
||||
if without_overlapping:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
if count > 1:
|
||||
await asyncio.sleep(count)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return restricts_func
|
||||
|
Loading…
Reference in New Issue
Block a user