diff --git a/pagermaid/__init__.py b/pagermaid/__init__.py index f337e80..6aa03dc 100644 --- a/pagermaid/__init__.py +++ b/pagermaid/__init__.py @@ -1,4 +1,6 @@ import contextlib +from typing import Callable, Awaitable, Set + from coloredlogs import ColoredFormatter from datetime import datetime, timezone from logging import getLogger, StreamHandler, CRITICAL, INFO, basicConfig, DEBUG @@ -12,13 +14,15 @@ import pyromod.listen from pyrogram import Client import sys -pgm_version = "1.1.2" +pgm_version = "1.1.3" CMD_LIST = {} module_dir = __path__[0] working_dir = getcwd() # solve same process read_context = {} help_messages = {} +startup_functions: Set[Callable[[], Awaitable[None]]] = set() +shutdown_functions: Set[Callable[[], Awaitable[None]]] = set() all_permissions = [] logs = getLogger(__name__) diff --git a/pagermaid/__main__.py b/pagermaid/__main__.py index 9b9416e..22196d6 100644 --- a/pagermaid/__main__.py +++ b/pagermaid/__main__.py @@ -5,6 +5,7 @@ from importlib import import_module from pyrogram import idle from pagermaid import bot, logs, working_dir +from pagermaid.hook import Hook from pagermaid.modules import module_list, plugin_list from pagermaid.utils import lang, process_exit @@ -29,8 +30,9 @@ async def main(): plugin_list.remove(plugin_name) await process_exit(start=True, _client=bot) - logs.info(lang('start')) + await Hook.startup() + await idle() await bot.stop() diff --git a/pagermaid/group_manager.py b/pagermaid/group_manager.py index a99a40b..815e2a5 100644 --- a/pagermaid/group_manager.py +++ b/pagermaid/group_manager.py @@ -1,4 +1,5 @@ import casbin +from logging import CRITICAL from shutil import copyfile from os import path as os_path from re import findall @@ -10,6 +11,7 @@ from pagermaid import all_permissions, module_dir if not os_path.exists(f"data{os_path.sep}gm_policy.csv"): copyfile(f"{module_dir}{os_path.sep}assets{os_path.sep}gm_policy.csv", f"data{os_path.sep}gm_policy.csv") permissions = casbin.Enforcer(f"pagermaid{sep}assets{sep}gm_model.conf", f"data{sep}gm_policy.csv") +permissions.logger.setLevel(CRITICAL) class Permission: diff --git a/pagermaid/hook.py b/pagermaid/hook.py new file mode 100644 index 0000000..eae9515 --- /dev/null +++ b/pagermaid/hook.py @@ -0,0 +1,41 @@ +import asyncio + +from pagermaid import startup_functions, shutdown_functions, logs + + +class Hook: + @staticmethod + def on_startup(): + """ + 注册一个启动钩子 + """ + def decorator(function): + startup_functions.add(function) + return function + return decorator + + @staticmethod + def on_shutdown(): + """ + 注册一个关闭钩子 + """ + def decorator(function): + shutdown_functions.add(function) + return function + return decorator + + @staticmethod + async def startup(): + if cors := [startup() for startup in startup_functions]: + try: + await asyncio.gather(*cors) + except Exception as exception: + logs.info(f"[startup]: {type(exception)}: {exception}") + + @staticmethod + async def shutdown(): + if cors := [shutdown() for shutdown in shutdown_functions]: + try: + await asyncio.gather(*cors) + except Exception as exception: + logs.info(f"[shutdown]: {type(exception)}: {exception}") diff --git a/pagermaid/listener.py b/pagermaid/listener.py index be8e6f1..dc156e3 100644 --- a/pagermaid/listener.py +++ b/pagermaid/listener.py @@ -19,6 +19,7 @@ from pagermaid.group_manager import Permission from pagermaid.single_utils import Message, AlreadyInConversationError, TimeoutConversationError from pagermaid.utils import lang, attach_report, sudo_filter, alias_command, get_permission_name, process_exit from pagermaid.utils import client as httpx_client +from pagermaid.hook import Hook secret_generator = secrets.SystemRandom() @@ -137,18 +138,21 @@ def listener(**args): logs.warning( "Please Don't Send Commands In The Same Conversation.." ) - await message.edit(lang("conversation_already_in_error")) + with contextlib.suppress(BaseException): + await message.edit(lang("conversation_already_in_error")) except TimeoutConversationError: logs.warning( "Conversation Timed out while processing commands.." ) - await message.edit(lang("conversation_timed_out_error")) + with contextlib.suppress(BaseException): + await message.edit(lang("conversation_timed_out_error")) except UserNotParticipant: pass except ContinuePropagation as e: raise ContinuePropagation from e except SystemExit: await process_exit(start=False, _client=client, message=message) + await Hook.shutdown() sys.exit(0) except BaseException: exc_info = sys.exc_info()[1] @@ -213,14 +217,17 @@ def raw_listener(filter_s): logs.warning( "Please Don't Send Commands In The Same Conversation.." ) - await message.edit(lang("conversation_already_in_error")) + with contextlib.suppress(BaseException): + await message.edit(lang("conversation_already_in_error")) except TimeoutConversationError: logs.warning( "Conversation Timed out while processing commands.." ) - await message.edit(lang("conversation_timed_out_error")) + with contextlib.suppress(BaseException): + await message.edit(lang("conversation_timed_out_error")) except SystemExit: await process_exit(start=False, _client=client, message=message) + await Hook.shutdown() sys.exit(0) except UserNotParticipant: pass diff --git a/pagermaid/modules/reload.py b/pagermaid/modules/reload.py index 9ffc156..8a4ed9c 100644 --- a/pagermaid/modules/reload.py +++ b/pagermaid/modules/reload.py @@ -2,7 +2,7 @@ import importlib import pagermaid.config import pagermaid.modules -from pagermaid import bot, logs, help_messages, all_permissions +from pagermaid import bot, logs, help_messages, all_permissions, startup_functions, shutdown_functions from pagermaid.listener import listener from pagermaid.utils import lang, Message @@ -17,8 +17,9 @@ def reload_all(): importlib.reload(pagermaid.config) help_messages.clear() all_permissions.clear() + startup_functions.clear() + shutdown_functions.clear() - importlib.reload(pagermaid.modules) for module_name in pagermaid.modules.module_list: try: module = importlib.import_module(f"pagermaid.modules.{module_name}")