🔖 Update to v1.1.8

支持 命令预处理/后处理 钩子
This commit is contained in:
xtaodada 2022-07-05 14:56:00 +08:00
parent 08a3fc5a5f
commit b024004a94
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
4 changed files with 63 additions and 12 deletions

View File

@ -1,5 +1,5 @@
import contextlib
from typing import Callable, Awaitable, Set
from typing import Callable, Awaitable, Set, Dict
from coloredlogs import ColoredFormatter
from datetime import datetime, timezone
@ -14,15 +14,15 @@ import pyromod.listen
from pyrogram import Client
import sys
pgm_version = "1.1.7"
pgm_version = "1.1.8"
CMD_LIST = {}
module_dir = __path__[0]
working_dir = getcwd()
# solve same process
read_context = {}
help_messages = {}
startup_functions: Set[Callable[[], Awaitable[None]]] = set()
shutdown_functions: Set[Callable[[], Awaitable[None]]] = set()
hook_functions: Dict[str, Set[Callable[[], Awaitable[None]]]] = {
"startup": set(), "shutdown": set(), "command_pre": set(), "command_post": set()}
all_permissions = []
logs = getLogger(__name__)

View File

@ -1,6 +1,9 @@
import asyncio
from pagermaid import startup_functions, shutdown_functions, logs
from pyrogram import StopPropagation
from pagermaid import hook_functions, logs
from pagermaid.single_utils import Message
class Hook:
@ -10,7 +13,7 @@ class Hook:
注册一个启动钩子
"""
def decorator(function):
startup_functions.add(function)
hook_functions["startup"].add(function)
return function
return decorator
@ -20,13 +23,37 @@ class Hook:
注册一个关闭钩子
"""
def decorator(function):
shutdown_functions.add(function)
hook_functions["shutdown"].add(function)
return function
return decorator
@staticmethod
def command_preprocessor():
"""
注册一个命令预处理钩子
"""
def decorator(function):
hook_functions["command_pre"].add(function)
return function
return decorator
@staticmethod
def command_postprocessor():
"""
注册一个命令后处理钩子
"""
def decorator(function):
hook_functions["command_post"].add(function)
return function
return decorator
@staticmethod
async def startup():
if cors := [startup() for startup in startup_functions]:
if cors := [startup() for startup in hook_functions["startup"]]:
try:
await asyncio.gather(*cors)
except Exception as exception:
@ -34,8 +61,28 @@ class Hook:
@staticmethod
async def shutdown():
if cors := [shutdown() for shutdown in shutdown_functions]:
if cors := [shutdown() for shutdown in hook_functions["shutdown"]]:
try:
await asyncio.gather(*cors)
except Exception as exception:
logs.info(f"[shutdown]: {type(exception)}: {exception}")
@staticmethod
async def command_pre(message: Message):
if cors := [pre(message) for pre in hook_functions["command_pre"]]: # noqa
try:
await asyncio.gather(*cors)
except StopPropagation as e:
raise StopPropagation from e
except Exception as exception:
logs.info(f"[command_pre]: {type(exception)}: {exception}")
@staticmethod
async def command_post(message: Message):
if cors := [post(message) for post in hook_functions["command_post"]]: # noqa
try:
await asyncio.gather(*cors)
except StopPropagation as e:
raise StopPropagation from e
except Exception as exception:
logs.info(f"[command_post]: {type(exception)}: {exception}")

View File

@ -137,10 +137,14 @@ def listener(**args):
raise ContinuePropagation
read_context[(message.chat.id, message.id)] = True
if command:
await Hook.command_pre(message)
if function.__code__.co_argcount == 1:
await function(message)
elif function.__code__.co_argcount == 2:
await function(client, message)
if command:
await Hook.command_post(message)
except StopPropagation as e:
raise StopPropagation from e
except KeyboardInterrupt as e:

View File

@ -3,7 +3,7 @@ import importlib
import pagermaid.config
import pagermaid.modules
from pagermaid import bot, logs, help_messages, all_permissions, startup_functions, shutdown_functions
from pagermaid import bot, logs, help_messages, all_permissions, hook_functions
from pagermaid.listener import listener
from pagermaid.utils import lang, Message
@ -20,8 +20,8 @@ def reload_all():
importlib.reload(pagermaid.config)
help_messages.clear()
all_permissions.clear()
startup_functions.clear()
shutdown_functions.clear()
for functions in hook_functions.values():
functions.clear() # noqa: clear all hooks
for module_name in pagermaid.modules.module_list:
try: