mirror of
https://github.com/PaiGramTeam/PaiGram.git
synced 2024-11-23 00:01:45 +00:00
233e7ab58d
Co-authored-by: luoshuijs <luoshuijs@outlook.com> Co-authored-by: Karako <karakohear@gmail.com> Co-authored-by: xtaodada <xtao@xtaolink.cn>
310 lines
9.8 KiB
Python
310 lines
9.8 KiB
Python
"""参数分发器"""
|
|
import asyncio
|
|
import inspect
|
|
from abc import ABC, abstractmethod
|
|
from asyncio import AbstractEventLoop
|
|
from functools import cached_property, lru_cache, partial, wraps
|
|
from inspect import Parameter, Signature
|
|
from itertools import chain
|
|
from types import GenericAlias, MethodType
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
Sequence,
|
|
Type,
|
|
Union,
|
|
)
|
|
|
|
from arkowrapper import ArkoWrapper
|
|
from fastapi import FastAPI
|
|
from telegram import Bot as TelegramBot, Chat, Message, Update, User
|
|
from telegram.ext import Application as TelegramApplication, CallbackContext, Job
|
|
from typing_extensions import ParamSpec
|
|
from uvicorn import Server
|
|
|
|
from core.application import Application
|
|
from utils.const import WRAPPER_ASSIGNMENTS
|
|
from utils.typedefs import R, T
|
|
|
|
__all__ = (
|
|
"catch",
|
|
"AbstractDispatcher",
|
|
"BaseDispatcher",
|
|
"HandlerDispatcher",
|
|
"JobDispatcher",
|
|
"dispatched",
|
|
)
|
|
|
|
P = ParamSpec("P")
|
|
|
|
TargetType = Union[Type, str, Callable[[Any], bool]]
|
|
|
|
_CATCH_TARGET_ATTR = "_catch_targets"
|
|
|
|
|
|
def catch(*targets: Union[str, Type]) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
|
def decorate(func: Callable[P, R]) -> Callable[P, R]:
|
|
setattr(func, _CATCH_TARGET_ATTR, targets)
|
|
|
|
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
return func(*args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
return decorate
|
|
|
|
|
|
@lru_cache(64)
|
|
def get_signature(func: Union[type, Callable]) -> Signature:
|
|
if isinstance(func, type):
|
|
return inspect.signature(func.__init__)
|
|
return inspect.signature(func)
|
|
|
|
|
|
class AbstractDispatcher(ABC):
|
|
"""参数分发器"""
|
|
|
|
IGNORED_ATTRS = []
|
|
|
|
_args: List[Any] = []
|
|
_kwargs: Dict[Union[str, Type], Any] = {}
|
|
_application: "Optional[Application]" = None
|
|
|
|
def set_application(self, application: "Application") -> None:
|
|
self._application = application
|
|
|
|
@property
|
|
def application(self) -> "Application":
|
|
if self._application is None:
|
|
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
|
return self._application
|
|
|
|
def __init__(self, *args, **kwargs) -> None:
|
|
self._args = list(args)
|
|
self._kwargs = dict(kwargs)
|
|
|
|
for _, value in kwargs.items():
|
|
type_arg = type(value)
|
|
if type_arg != str:
|
|
self._kwargs[type_arg] = value
|
|
|
|
for arg in args:
|
|
type_arg = type(arg)
|
|
if type_arg != str:
|
|
self._kwargs[type_arg] = arg
|
|
|
|
@cached_property
|
|
def catch_funcs(self) -> List[MethodType]:
|
|
# noinspection PyTypeChecker
|
|
return list(
|
|
ArkoWrapper(dir(self))
|
|
.filter(lambda x: not x.startswith("_"))
|
|
.filter(
|
|
lambda x: x not in self.IGNORED_ATTRS + ["dispatch", "catch_funcs", "catch_func_map", "dispatch_funcs"]
|
|
)
|
|
.map(lambda x: getattr(self, x))
|
|
.filter(lambda x: isinstance(x, MethodType))
|
|
.filter(lambda x: hasattr(x, "_catch_targets"))
|
|
)
|
|
|
|
@cached_property
|
|
def catch_func_map(self) -> Dict[Union[str, Type[T]], Callable[..., T]]:
|
|
result = {}
|
|
for catch_func in self.catch_funcs:
|
|
catch_targets = getattr(catch_func, _CATCH_TARGET_ATTR)
|
|
for catch_target in catch_targets:
|
|
result[catch_target] = catch_func
|
|
return result
|
|
|
|
@cached_property
|
|
def dispatch_funcs(self) -> List[MethodType]:
|
|
return list(
|
|
ArkoWrapper(dir(self))
|
|
.filter(lambda x: x.startswith("dispatch_by_"))
|
|
.map(lambda x: getattr(self, x))
|
|
.filter(lambda x: isinstance(x, MethodType))
|
|
)
|
|
|
|
@abstractmethod
|
|
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
|
"""默认的 dispatch 方法"""
|
|
|
|
@abstractmethod
|
|
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
|
|
"""使用 catch_func 获取并分配参数"""
|
|
|
|
def dispatch(self, func: Callable[P, R]) -> Callable[..., R]:
|
|
"""将参数分配给函数,从而合成一个无需参数即可执行的函数"""
|
|
params = {}
|
|
signature = get_signature(func)
|
|
parameters: Dict[str, Parameter] = dict(signature.parameters)
|
|
|
|
for name, parameter in list(parameters.items()):
|
|
parameter: Parameter
|
|
if any(
|
|
[
|
|
name == "self" and isinstance(func, (type, MethodType)),
|
|
parameter.kind in [Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL],
|
|
]
|
|
):
|
|
del parameters[name]
|
|
continue
|
|
|
|
for dispatch_func in self.dispatch_funcs:
|
|
parameters[name] = dispatch_func(parameter)
|
|
|
|
for name, parameter in parameters.items():
|
|
if parameter.default != Parameter.empty:
|
|
params[name] = parameter.default
|
|
else:
|
|
params[name] = None
|
|
|
|
return partial(func, **params)
|
|
|
|
@catch(Application)
|
|
def catch_application(self) -> Application:
|
|
return self.application
|
|
|
|
|
|
class BaseDispatcher(AbstractDispatcher):
|
|
"""默认参数分发器"""
|
|
|
|
_instances: Sequence[Any]
|
|
|
|
def _get_kwargs(self) -> Dict[Type[T], T]:
|
|
result = self._get_default_kwargs()
|
|
result[AbstractDispatcher] = self
|
|
result.update(self._kwargs)
|
|
return result
|
|
|
|
def _get_default_kwargs(self) -> Dict[Type[T], T]:
|
|
application = self.application
|
|
_default_kwargs = {
|
|
FastAPI: application.web_app,
|
|
Server: application.web_server,
|
|
TelegramApplication: application.telegram,
|
|
TelegramBot: application.telegram.bot,
|
|
}
|
|
if not application.running:
|
|
for obj in chain(
|
|
application.managers.dependency,
|
|
application.managers.components,
|
|
application.managers.services,
|
|
application.managers.plugins,
|
|
):
|
|
_default_kwargs[type(obj)] = obj
|
|
return {k: v for k, v in _default_kwargs.items() if v is not None}
|
|
|
|
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
|
annotation = parameter.annotation
|
|
# noinspection PyTypeChecker
|
|
if isinstance(annotation, type) and (value := self._get_kwargs().get(annotation, None)) is not None:
|
|
parameter._default = value # pylint: disable=W0212
|
|
return parameter
|
|
|
|
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
|
|
annotation = parameter.annotation
|
|
if annotation != Any and isinstance(annotation, GenericAlias):
|
|
return parameter
|
|
|
|
catch_func = self.catch_func_map.get(annotation) or self.catch_func_map.get(parameter.name)
|
|
if catch_func is not None:
|
|
# noinspection PyUnresolvedReferences,PyProtectedMember
|
|
parameter._default = catch_func() # pylint: disable=W0212
|
|
return parameter
|
|
|
|
@catch(AbstractEventLoop)
|
|
def catch_loop(self) -> AbstractEventLoop:
|
|
return asyncio.get_event_loop()
|
|
|
|
|
|
class HandlerDispatcher(BaseDispatcher):
|
|
"""Handler 参数分发器"""
|
|
|
|
def __init__(self, update: Optional[Update] = None, context: Optional[CallbackContext] = None, **kwargs) -> None:
|
|
super().__init__(update=update, context=context, **kwargs)
|
|
self._update = update
|
|
self._context = context
|
|
|
|
def dispatch(
|
|
self, func: Callable[P, R], *, update: Optional[Update] = None, context: Optional[CallbackContext] = None
|
|
) -> Callable[..., R]:
|
|
self._update = update or self._update
|
|
self._context = context or self._context
|
|
if self._update is None:
|
|
from core.builtins.contexts import UpdateCV
|
|
|
|
self._update = UpdateCV.get()
|
|
if self._context is None:
|
|
from core.builtins.contexts import CallbackContextCV
|
|
|
|
self._context = CallbackContextCV.get()
|
|
return super().dispatch(func)
|
|
|
|
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
|
"""HandlerDispatcher 默认不使用 dispatch_by_default"""
|
|
return parameter
|
|
|
|
@catch(Update)
|
|
def catch_update(self) -> Update:
|
|
return self._update
|
|
|
|
@catch(CallbackContext)
|
|
def catch_context(self) -> CallbackContext:
|
|
return self._context
|
|
|
|
@catch(Message)
|
|
def catch_message(self) -> Message:
|
|
return self._update.effective_message
|
|
|
|
@catch(User)
|
|
def catch_user(self) -> User:
|
|
return self._update.effective_user
|
|
|
|
@catch(Chat)
|
|
def catch_chat(self) -> Chat:
|
|
return self._update.effective_chat
|
|
|
|
|
|
class JobDispatcher(BaseDispatcher):
|
|
"""Job 参数分发器"""
|
|
|
|
def __init__(self, context: Optional[CallbackContext] = None, **kwargs) -> None:
|
|
super().__init__(context=context, **kwargs)
|
|
self._context = context
|
|
|
|
def dispatch(self, func: Callable[P, R], *, context: Optional[CallbackContext] = None) -> Callable[..., R]:
|
|
self._context = context or self._context
|
|
if self._context is None:
|
|
from core.builtins.contexts import CallbackContextCV
|
|
|
|
self._context = CallbackContextCV.get()
|
|
return super().dispatch(func)
|
|
|
|
@catch("data")
|
|
def catch_data(self) -> Any:
|
|
return self._context.job.data
|
|
|
|
@catch(Job)
|
|
def catch_job(self) -> Job:
|
|
return self._context.job
|
|
|
|
@catch(CallbackContext)
|
|
def catch_context(self) -> CallbackContext:
|
|
return self._context
|
|
|
|
|
|
def dispatched(dispatcher: Type[AbstractDispatcher] = BaseDispatcher):
|
|
def decorate(func: Callable[P, R]) -> Callable[P, R]:
|
|
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
return dispatcher().dispatch(func)(*args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
return decorate
|