mirror of
https://github.com/TeamPGM/PagerMaid-Pyro.git
synced 2024-11-24 07:20:37 +00:00
🔖 Update to v1.1.8
支持 命令预处理/后处理 钩子
This commit is contained in:
parent
08a3fc5a5f
commit
b024004a94
@ -1,5 +1,5 @@
|
||||
import contextlib
|
||||
from typing import Callable, Awaitable, Set
|
||||
from typing import Callable, Awaitable, Set, Dict
|
||||
|
||||
from coloredlogs import ColoredFormatter
|
||||
from datetime import datetime, timezone
|
||||
@ -14,15 +14,15 @@ import pyromod.listen
|
||||
from pyrogram import Client
|
||||
import sys
|
||||
|
||||
pgm_version = "1.1.7"
|
||||
pgm_version = "1.1.8"
|
||||
CMD_LIST = {}
|
||||
module_dir = __path__[0]
|
||||
working_dir = getcwd()
|
||||
# solve same process
|
||||
read_context = {}
|
||||
help_messages = {}
|
||||
startup_functions: Set[Callable[[], Awaitable[None]]] = set()
|
||||
shutdown_functions: Set[Callable[[], Awaitable[None]]] = set()
|
||||
hook_functions: Dict[str, Set[Callable[[], Awaitable[None]]]] = {
|
||||
"startup": set(), "shutdown": set(), "command_pre": set(), "command_post": set()}
|
||||
all_permissions = []
|
||||
|
||||
logs = getLogger(__name__)
|
||||
|
@ -1,6 +1,9 @@
|
||||
import asyncio
|
||||
|
||||
from pagermaid import startup_functions, shutdown_functions, logs
|
||||
from pyrogram import StopPropagation
|
||||
|
||||
from pagermaid import hook_functions, logs
|
||||
from pagermaid.single_utils import Message
|
||||
|
||||
|
||||
class Hook:
|
||||
@ -10,7 +13,7 @@ class Hook:
|
||||
注册一个启动钩子
|
||||
"""
|
||||
def decorator(function):
|
||||
startup_functions.add(function)
|
||||
hook_functions["startup"].add(function)
|
||||
return function
|
||||
return decorator
|
||||
|
||||
@ -20,13 +23,37 @@ class Hook:
|
||||
注册一个关闭钩子
|
||||
"""
|
||||
def decorator(function):
|
||||
shutdown_functions.add(function)
|
||||
hook_functions["shutdown"].add(function)
|
||||
return function
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def command_preprocessor():
|
||||
"""
|
||||
注册一个命令预处理钩子
|
||||
"""
|
||||
|
||||
def decorator(function):
|
||||
hook_functions["command_pre"].add(function)
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def command_postprocessor():
|
||||
"""
|
||||
注册一个命令后处理钩子
|
||||
"""
|
||||
|
||||
def decorator(function):
|
||||
hook_functions["command_post"].add(function)
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
async def startup():
|
||||
if cors := [startup() for startup in startup_functions]:
|
||||
if cors := [startup() for startup in hook_functions["startup"]]:
|
||||
try:
|
||||
await asyncio.gather(*cors)
|
||||
except Exception as exception:
|
||||
@ -34,8 +61,28 @@ class Hook:
|
||||
|
||||
@staticmethod
|
||||
async def shutdown():
|
||||
if cors := [shutdown() for shutdown in shutdown_functions]:
|
||||
if cors := [shutdown() for shutdown in hook_functions["shutdown"]]:
|
||||
try:
|
||||
await asyncio.gather(*cors)
|
||||
except Exception as exception:
|
||||
logs.info(f"[shutdown]: {type(exception)}: {exception}")
|
||||
|
||||
@staticmethod
|
||||
async def command_pre(message: Message):
|
||||
if cors := [pre(message) for pre in hook_functions["command_pre"]]: # noqa
|
||||
try:
|
||||
await asyncio.gather(*cors)
|
||||
except StopPropagation as e:
|
||||
raise StopPropagation from e
|
||||
except Exception as exception:
|
||||
logs.info(f"[command_pre]: {type(exception)}: {exception}")
|
||||
|
||||
@staticmethod
|
||||
async def command_post(message: Message):
|
||||
if cors := [post(message) for post in hook_functions["command_post"]]: # noqa
|
||||
try:
|
||||
await asyncio.gather(*cors)
|
||||
except StopPropagation as e:
|
||||
raise StopPropagation from e
|
||||
except Exception as exception:
|
||||
logs.info(f"[command_post]: {type(exception)}: {exception}")
|
||||
|
@ -137,10 +137,14 @@ def listener(**args):
|
||||
raise ContinuePropagation
|
||||
read_context[(message.chat.id, message.id)] = True
|
||||
|
||||
if command:
|
||||
await Hook.command_pre(message)
|
||||
if function.__code__.co_argcount == 1:
|
||||
await function(message)
|
||||
elif function.__code__.co_argcount == 2:
|
||||
await function(client, message)
|
||||
if command:
|
||||
await Hook.command_post(message)
|
||||
except StopPropagation as e:
|
||||
raise StopPropagation from e
|
||||
except KeyboardInterrupt as e:
|
||||
|
@ -3,7 +3,7 @@ import importlib
|
||||
import pagermaid.config
|
||||
import pagermaid.modules
|
||||
|
||||
from pagermaid import bot, logs, help_messages, all_permissions, startup_functions, shutdown_functions
|
||||
from pagermaid import bot, logs, help_messages, all_permissions, hook_functions
|
||||
from pagermaid.listener import listener
|
||||
from pagermaid.utils import lang, Message
|
||||
|
||||
@ -20,8 +20,8 @@ def reload_all():
|
||||
importlib.reload(pagermaid.config)
|
||||
help_messages.clear()
|
||||
all_permissions.clear()
|
||||
startup_functions.clear()
|
||||
shutdown_functions.clear()
|
||||
for functions in hook_functions.values():
|
||||
functions.clear() # noqa: clear all hooks
|
||||
|
||||
for module_name in pagermaid.modules.module_list:
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user