🔖 Update to v1.1.8

支持 命令预处理/后处理 钩子
This commit is contained in:
xtaodada 2022-07-05 14:56:00 +08:00
parent 08a3fc5a5f
commit b024004a94
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
4 changed files with 63 additions and 12 deletions

View File

@ -1,5 +1,5 @@
import contextlib import contextlib
from typing import Callable, Awaitable, Set from typing import Callable, Awaitable, Set, Dict
from coloredlogs import ColoredFormatter from coloredlogs import ColoredFormatter
from datetime import datetime, timezone from datetime import datetime, timezone
@ -14,15 +14,15 @@ import pyromod.listen
from pyrogram import Client from pyrogram import Client
import sys import sys
pgm_version = "1.1.7" pgm_version = "1.1.8"
CMD_LIST = {} CMD_LIST = {}
module_dir = __path__[0] module_dir = __path__[0]
working_dir = getcwd() working_dir = getcwd()
# solve same process # solve same process
read_context = {} read_context = {}
help_messages = {} help_messages = {}
startup_functions: Set[Callable[[], Awaitable[None]]] = set() hook_functions: Dict[str, Set[Callable[[], Awaitable[None]]]] = {
shutdown_functions: Set[Callable[[], Awaitable[None]]] = set() "startup": set(), "shutdown": set(), "command_pre": set(), "command_post": set()}
all_permissions = [] all_permissions = []
logs = getLogger(__name__) logs = getLogger(__name__)

View File

@ -1,6 +1,9 @@
import asyncio import asyncio
from pagermaid import startup_functions, shutdown_functions, logs from pyrogram import StopPropagation
from pagermaid import hook_functions, logs
from pagermaid.single_utils import Message
class Hook: class Hook:
@ -10,7 +13,7 @@ class Hook:
注册一个启动钩子 注册一个启动钩子
""" """
def decorator(function): def decorator(function):
startup_functions.add(function) hook_functions["startup"].add(function)
return function return function
return decorator return decorator
@ -20,13 +23,37 @@ class Hook:
注册一个关闭钩子 注册一个关闭钩子
""" """
def decorator(function): def decorator(function):
shutdown_functions.add(function) hook_functions["shutdown"].add(function)
return function return function
return decorator return decorator
@staticmethod
def command_preprocessor():
"""
注册一个命令预处理钩子
"""
def decorator(function):
hook_functions["command_pre"].add(function)
return function
return decorator
@staticmethod
def command_postprocessor():
"""
注册一个命令后处理钩子
"""
def decorator(function):
hook_functions["command_post"].add(function)
return function
return decorator
@staticmethod @staticmethod
async def startup(): async def startup():
if cors := [startup() for startup in startup_functions]: if cors := [startup() for startup in hook_functions["startup"]]:
try: try:
await asyncio.gather(*cors) await asyncio.gather(*cors)
except Exception as exception: except Exception as exception:
@ -34,8 +61,28 @@ class Hook:
@staticmethod @staticmethod
async def shutdown(): async def shutdown():
if cors := [shutdown() for shutdown in shutdown_functions]: if cors := [shutdown() for shutdown in hook_functions["shutdown"]]:
try: try:
await asyncio.gather(*cors) await asyncio.gather(*cors)
except Exception as exception: except Exception as exception:
logs.info(f"[shutdown]: {type(exception)}: {exception}") logs.info(f"[shutdown]: {type(exception)}: {exception}")
@staticmethod
async def command_pre(message: Message):
if cors := [pre(message) for pre in hook_functions["command_pre"]]: # noqa
try:
await asyncio.gather(*cors)
except StopPropagation as e:
raise StopPropagation from e
except Exception as exception:
logs.info(f"[command_pre]: {type(exception)}: {exception}")
@staticmethod
async def command_post(message: Message):
if cors := [post(message) for post in hook_functions["command_post"]]: # noqa
try:
await asyncio.gather(*cors)
except StopPropagation as e:
raise StopPropagation from e
except Exception as exception:
logs.info(f"[command_post]: {type(exception)}: {exception}")

View File

@ -137,10 +137,14 @@ def listener(**args):
raise ContinuePropagation raise ContinuePropagation
read_context[(message.chat.id, message.id)] = True read_context[(message.chat.id, message.id)] = True
if command:
await Hook.command_pre(message)
if function.__code__.co_argcount == 1: if function.__code__.co_argcount == 1:
await function(message) await function(message)
elif function.__code__.co_argcount == 2: elif function.__code__.co_argcount == 2:
await function(client, message) await function(client, message)
if command:
await Hook.command_post(message)
except StopPropagation as e: except StopPropagation as e:
raise StopPropagation from e raise StopPropagation from e
except KeyboardInterrupt as e: except KeyboardInterrupt as e:

View File

@ -3,7 +3,7 @@ import importlib
import pagermaid.config import pagermaid.config
import pagermaid.modules import pagermaid.modules
from pagermaid import bot, logs, help_messages, all_permissions, startup_functions, shutdown_functions from pagermaid import bot, logs, help_messages, all_permissions, hook_functions
from pagermaid.listener import listener from pagermaid.listener import listener
from pagermaid.utils import lang, Message from pagermaid.utils import lang, Message
@ -20,8 +20,8 @@ def reload_all():
importlib.reload(pagermaid.config) importlib.reload(pagermaid.config)
help_messages.clear() help_messages.clear()
all_permissions.clear() all_permissions.clear()
startup_functions.clear() for functions in hook_functions.values():
shutdown_functions.clear() functions.clear() # noqa: clear all hooks
for module_name in pagermaid.modules.module_list: for module_name in pagermaid.modules.module_list:
try: try: