diff --git a/pagermaid/__init__.py b/pagermaid/__init__.py index 1440398..7e6b0a6 100644 --- a/pagermaid/__init__.py +++ b/pagermaid/__init__.py @@ -14,7 +14,7 @@ import pyromod.listen from pyrogram import Client import sys -pgm_version = "1.2.2" +pgm_version = "1.2.3" CMD_LIST = {} module_dir = __path__[0] working_dir = getcwd() diff --git a/pagermaid/enums/__init__.py b/pagermaid/enums/__init__.py new file mode 100644 index 0000000..82dd64c --- /dev/null +++ b/pagermaid/enums/__init__.py @@ -0,0 +1,7 @@ +from pagermaid.single_utils import Client +from pagermaid.single_utils import Message +from pagermaid.sub_utils import Sub +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from sqlitedict import SqliteDict +from httpx import AsyncClient +from logging import Logger diff --git a/pagermaid/hook.py b/pagermaid/hook.py index 9806655..089a699 100644 --- a/pagermaid/hook.py +++ b/pagermaid/hook.py @@ -3,6 +3,7 @@ import asyncio from pyrogram import StopPropagation from pagermaid import hook_functions, logs +from pagermaid.inject import inject from pagermaid.single_utils import Message @@ -65,7 +66,7 @@ class Hook: @staticmethod async def startup(): - if cors := [startup() for startup in hook_functions["startup"]]: + if cors := [startup(**inject(None, startup)) for startup in hook_functions["startup"]]: # noqa try: await asyncio.gather(*cors) except Exception as exception: @@ -73,7 +74,7 @@ class Hook: @staticmethod async def shutdown(): - if cors := [shutdown() for shutdown in hook_functions["shutdown"]]: + if cors := [shutdown(**inject(None, shutdown)) for shutdown in hook_functions["shutdown"]]: # noqa try: await asyncio.gather(*cors) except Exception as exception: @@ -81,7 +82,7 @@ class Hook: @staticmethod async def command_pre(message: Message): - if cors := [pre(message) for pre in hook_functions["command_pre"]]: # noqa + if cors := [pre(**inject(message, pre)) for pre in hook_functions["command_pre"]]: # noqa try: await asyncio.gather(*cors) except StopPropagation as e: @@ -91,7 +92,7 @@ class Hook: @staticmethod async def command_post(message: Message): - if cors := [post(message) for post in hook_functions["command_post"]]: # noqa + if cors := [post(**inject(message, post)) for post in hook_functions["command_post"]]: # noqa try: await asyncio.gather(*cors) except StopPropagation as e: @@ -101,7 +102,7 @@ class Hook: @staticmethod async def process_error_exec(message: Message, exc_info: BaseException, exc_format: str): - if cors := [error(message, exc_info, exc_format) for error in hook_functions["process_error"]]: # noqa + if cors := [error(**inject(message, error, exc_info=exc_info, exc_format=exc_format)) for error in hook_functions["process_error"]]: # noqa try: await asyncio.gather(*cors) except StopPropagation as e: diff --git a/pagermaid/inject.py b/pagermaid/inject.py new file mode 100644 index 0000000..d302b4c --- /dev/null +++ b/pagermaid/inject.py @@ -0,0 +1,21 @@ +import inspect +import pagermaid.enums as enums +import pagermaid.services as services +from typing import Dict, Optional + + +def inject(message: enums.Message, function, **data) -> Optional[Dict]: + try: + signature = inspect.signature(function) + except Exception: + return None + for parameter_name, parameter in signature.parameters.items(): + class_name = parameter.annotation.__name__ + param = message if class_name == "Message" else services.get(class_name) + if not param: + if parameter_name == "message": + param = message + else: + param = services.get(parameter_name.capitalize()) + data.setdefault(parameter_name, param) + return data diff --git a/pagermaid/listener.py b/pagermaid/listener.py index 181c702..9223334 100644 --- a/pagermaid/listener.py +++ b/pagermaid/listener.py @@ -16,6 +16,7 @@ from pyrogram.handlers import MessageHandler, EditedMessageHandler from pagermaid import help_messages, logs, Config, bot, read_context, all_permissions from pagermaid.group_manager import Permission +from pagermaid.inject import inject from pagermaid.single_utils import Message, AlreadyInConversationError, TimeoutConversationError, ListenerCanceled from pagermaid.utils import lang, attach_report, sudo_filter, alias_command, get_permission_name, process_exit from pagermaid.utils import client as httpx_client @@ -139,10 +140,15 @@ def listener(**args): if command: await Hook.command_pre(message) - if function.__code__.co_argcount == 1: - await function(message) - elif function.__code__.co_argcount == 2: - await function(client, message) + if data := inject(message, function): + await function(**data) + else: + if function.__code__.co_argcount == 0: + await function() + if function.__code__.co_argcount == 1: + await function(message) + elif function.__code__.co_argcount == 2: + await function(client, message) if command: await Hook.command_post(message) except StopPropagation as e: diff --git a/pagermaid/services/__init__.py b/pagermaid/services/__init__.py new file mode 100644 index 0000000..40b39e9 --- /dev/null +++ b/pagermaid/services/__init__.py @@ -0,0 +1,10 @@ +from pagermaid import bot +from pagermaid import logs +from pagermaid.single_utils import sqlite +from pagermaid.scheduler import scheduler +from pagermaid.utils import client + + +def get(name: str): + data = {"Client": bot, "Logger": logs, "SqliteDict": sqlite, "AsyncIOScheduler": scheduler, "AsyncClient": client} + return data.get(name, None) diff --git a/requirements.txt b/requirements.txt index 59fb598..2cc6811 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pyrogram==2.0.33 +pyrogram==2.0.35 TgCrypto>=1.2.3 Pillow>=8.4.0 pytz>=2021.3