diff --git a/pagermaid/__init__.py b/pagermaid/__init__.py index bac8f9c..5dfdcc9 100644 --- a/pagermaid/__init__.py +++ b/pagermaid/__init__.py @@ -1,5 +1,5 @@ import contextlib -from typing import Callable, Awaitable, Set +from typing import Callable, Awaitable, Set, Dict from coloredlogs import ColoredFormatter from datetime import datetime, timezone @@ -14,15 +14,15 @@ import pyromod.listen from pyrogram import Client import sys -pgm_version = "1.1.7" +pgm_version = "1.1.8" CMD_LIST = {} module_dir = __path__[0] working_dir = getcwd() # solve same process read_context = {} help_messages = {} -startup_functions: Set[Callable[[], Awaitable[None]]] = set() -shutdown_functions: Set[Callable[[], Awaitable[None]]] = set() +hook_functions: Dict[str, Set[Callable[[], Awaitable[None]]]] = { + "startup": set(), "shutdown": set(), "command_pre": set(), "command_post": set()} all_permissions = [] logs = getLogger(__name__) diff --git a/pagermaid/hook.py b/pagermaid/hook.py index eae9515..fe46206 100644 --- a/pagermaid/hook.py +++ b/pagermaid/hook.py @@ -1,6 +1,9 @@ import asyncio -from pagermaid import startup_functions, shutdown_functions, logs +from pyrogram import StopPropagation + +from pagermaid import hook_functions, logs +from pagermaid.single_utils import Message class Hook: @@ -10,7 +13,7 @@ class Hook: 注册一个启动钩子 """ def decorator(function): - startup_functions.add(function) + hook_functions["startup"].add(function) return function return decorator @@ -20,13 +23,37 @@ class Hook: 注册一个关闭钩子 """ def decorator(function): - shutdown_functions.add(function) + hook_functions["shutdown"].add(function) return function return decorator + @staticmethod + def command_preprocessor(): + """ + 注册一个命令预处理钩子 + """ + + def decorator(function): + hook_functions["command_pre"].add(function) + return function + + return decorator + + @staticmethod + def command_postprocessor(): + """ + 注册一个命令后处理钩子 + """ + + def decorator(function): + hook_functions["command_post"].add(function) + return function + + return decorator + @staticmethod async def startup(): - if cors := [startup() for startup in startup_functions]: + if cors := [startup() for startup in hook_functions["startup"]]: try: await asyncio.gather(*cors) except Exception as exception: @@ -34,8 +61,28 @@ class Hook: @staticmethod async def shutdown(): - if cors := [shutdown() for shutdown in shutdown_functions]: + if cors := [shutdown() for shutdown in hook_functions["shutdown"]]: try: await asyncio.gather(*cors) except Exception as exception: logs.info(f"[shutdown]: {type(exception)}: {exception}") + + @staticmethod + async def command_pre(message: Message): + if cors := [pre(message) for pre in hook_functions["command_pre"]]: # noqa + try: + await asyncio.gather(*cors) + except StopPropagation as e: + raise StopPropagation from e + except Exception as exception: + logs.info(f"[command_pre]: {type(exception)}: {exception}") + + @staticmethod + async def command_post(message: Message): + if cors := [post(message) for post in hook_functions["command_post"]]: # noqa + try: + await asyncio.gather(*cors) + except StopPropagation as e: + raise StopPropagation from e + except Exception as exception: + logs.info(f"[command_post]: {type(exception)}: {exception}") diff --git a/pagermaid/listener.py b/pagermaid/listener.py index 2ed92f9..6736735 100644 --- a/pagermaid/listener.py +++ b/pagermaid/listener.py @@ -137,10 +137,14 @@ def listener(**args): raise ContinuePropagation read_context[(message.chat.id, message.id)] = True + if command: + await Hook.command_pre(message) if function.__code__.co_argcount == 1: await function(message) elif function.__code__.co_argcount == 2: await function(client, message) + if command: + await Hook.command_post(message) except StopPropagation as e: raise StopPropagation from e except KeyboardInterrupt as e: diff --git a/pagermaid/modules/reload.py b/pagermaid/modules/reload.py index fe221f2..6ed47ff 100644 --- a/pagermaid/modules/reload.py +++ b/pagermaid/modules/reload.py @@ -3,7 +3,7 @@ import importlib import pagermaid.config import pagermaid.modules -from pagermaid import bot, logs, help_messages, all_permissions, startup_functions, shutdown_functions +from pagermaid import bot, logs, help_messages, all_permissions, hook_functions from pagermaid.listener import listener from pagermaid.utils import lang, Message @@ -20,8 +20,8 @@ def reload_all(): importlib.reload(pagermaid.config) help_messages.clear() all_permissions.clear() - startup_functions.clear() - shutdown_functions.clear() + for functions in hook_functions.values(): + functions.clear() # noqa: clear all hooks for module_name in pagermaid.modules.module_list: try: