PaiGram/core/builtins/dispatcher.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

310 lines
9.8 KiB
Python
Raw Normal View History

"""参数分发器"""
import asyncio
import inspect
from abc import ABC, abstractmethod
from asyncio import AbstractEventLoop
from functools import cached_property, lru_cache, partial, wraps
from inspect import Parameter, Signature
from itertools import chain
from types import GenericAlias, MethodType
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
Union,
)
from arkowrapper import ArkoWrapper
from fastapi import FastAPI
from telegram import Bot as TelegramBot, Chat, Message, Update, User
from telegram.ext import Application as TelegramApplication, CallbackContext, Job
from typing_extensions import ParamSpec
from uvicorn import Server
from core.application import Application
from utils.const import WRAPPER_ASSIGNMENTS
from utils.typedefs import R, T
# Public names exported by ``from <this module> import *``.
__all__ = (
"catch",
"AbstractDispatcher",
"BaseDispatcher",
"HandlerDispatcher",
"JobDispatcher",
"dispatched",
)
# Parameter specification used to annotate the decorators and ``dispatch`` methods below.
P = ParamSpec("P")
# Alias covering a type, a string, or a one-argument predicate returning ``bool``.
# NOTE(review): not referenced by any definition visible in this file - confirm external use.
TargetType = Union[Type, str, Callable[[Any], bool]]
# Attribute name under which ``catch`` stores its targets on the decorated callable.
_CATCH_TARGET_ATTR = "_catch_targets"
def catch(*targets: Union[str, Type]) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Build a decorator that tags a callable with the targets it can supply.

    Args:
        *targets: Types and/or parameter names. They are stored, as a tuple,
            on the decorated callable under ``_CATCH_TARGET_ATTR``.

    Returns:
        A decorator producing a pass-through wrapper around the callable.
    """

    def decorate(func: Callable[P, R]) -> Callable[P, R]:
        # Tag the original callable before wrapping it, so the marker is
        # already in ``func.__dict__`` when ``wraps`` copies metadata over.
        setattr(func, _CATCH_TARGET_ATTR, targets)

        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            return func(*args, **kwargs)

        return wraps(func, assigned=WRAPPER_ASSIGNMENTS)(wrapper)

    return decorate
@lru_cache(64)
def get_signature(func: Union[type, Callable]) -> Signature:
    """Return the (cached) signature of ``func``.

    For a class the signature of its ``__init__`` is returned, which still
    contains the leading ``self`` parameter. Up to 64 results are cached, so
    repeated calls with the same argument return the same ``Signature`` object.
    """
    target = func.__init__ if isinstance(func, type) else func
    return inspect.signature(target)
class AbstractDispatcher(ABC):
    """Parameter dispatcher.

    Holds candidate values (indexed by keyword name and by concrete type) and
    binds them to the parameters of a callable through :meth:`dispatch`.
    Subclasses provide the ``dispatch_by_*`` strategies and may expose value
    providers by decorating methods with ``catch``.
    """

    # Public attribute names that ``catch_funcs`` must not inspect.
    IGNORED_ATTRS = []

    # Class-level fallbacks; ``__init__`` replaces both with per-instance objects.
    _args: List[Any] = []
    _kwargs: Dict[Union[str, Type], Any] = {}
    _application: "Optional[Application]" = None

    def set_application(self, application: "Application") -> None:
        """Attach the application returned by :attr:`application`."""
        self._application = application

    @property
    def application(self) -> "Application":
        """Return the attached application.

        Raises:
            RuntimeError: If no application has been set.
        """
        if self._application is None:
            raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
        return self._application

    def __init__(self, *args, **kwargs) -> None:
        """Store the given values and index every non-``str`` value by its type.

        Keyword values are indexed first and positional values afterwards, so
        a positional value wins when both share the same concrete type.
        """
        self._args = list(args)
        self._kwargs = dict(kwargs)
        for value in chain(kwargs.values(), args):
            type_arg = type(value)
            if type_arg is not str:
                self._kwargs[type_arg] = value

    @cached_property
    def catch_funcs(self) -> List[MethodType]:
        """Bound public methods of this instance that were tagged by ``catch``."""
        # Built once instead of once per inspected name.
        excluded = set(self.IGNORED_ATTRS) | {"dispatch", "catch_funcs", "catch_func_map", "dispatch_funcs"}

        def is_not_property(name: str) -> bool:
            # Look the attribute up statically so that properties are skipped
            # without being evaluated; ``application`` raises RuntimeError
            # while no application is set, which would abort the whole scan.
            static = inspect.getattr_static(self, name, None)
            return not isinstance(static, (property, cached_property))

        # noinspection PyTypeChecker
        return list(
            ArkoWrapper(dir(self))
            .filter(lambda x: not x.startswith("_"))
            .filter(lambda x: x not in excluded)
            .filter(is_not_property)
            .map(lambda x: getattr(self, x))
            .filter(lambda x: isinstance(x, MethodType))
            .filter(lambda x: hasattr(x, _CATCH_TARGET_ATTR))
        )

    @cached_property
    def catch_func_map(self) -> Dict[Union[str, Type[T]], Callable[..., T]]:
        """Map every catch target (type or name) to the method providing it.

        When several methods declare the same target, the one appearing last
        in ``catch_funcs`` wins.
        """
        result = {}
        for catch_func in self.catch_funcs:
            for catch_target in getattr(catch_func, _CATCH_TARGET_ATTR):
                result[catch_target] = catch_func
        return result

    @cached_property
    def dispatch_funcs(self) -> List[MethodType]:
        """Bound methods whose name starts with ``dispatch_by_``, in ``dir()`` order."""
        return list(
            ArkoWrapper(dir(self))
            .filter(lambda x: x.startswith("dispatch_by_"))
            .map(lambda x: getattr(self, x))
            .filter(lambda x: isinstance(x, MethodType))
        )

    @abstractmethod
    def dispatch_by_default(self, parameter: Parameter) -> Parameter:
        """Default dispatch method."""

    @abstractmethod
    def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
        """Use the catch functions to obtain and assign the parameter value."""

    def dispatch(self, func: Callable[P, R]) -> Callable[..., R]:
        """Bind values to ``func`` so that it can be called without arguments.

        Every parameter except ``*args`` / ``**kwargs`` (and ``self`` for
        classes and bound methods) is run through all ``dispatch_by_*``
        methods. A parameter that ends up without a default is bound to
        ``None``.

        Args:
            func: The callable (or class) whose parameters should be filled.

        Returns:
            A ``functools.partial`` of ``func`` with every handled parameter
            bound by keyword.
        """
        params = {}
        for name, parameter in get_signature(func).parameters.items():
            if name == "self" and isinstance(func, (type, MethodType)):
                continue
            if parameter.kind in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL):
                continue

            # ``get_signature`` is cached, so its Parameter objects are shared
            # between calls. Dispatch functions may modify the parameter they
            # receive; give them a private copy so a value injected here never
            # leaks into a later dispatch of the same callable.
            resolved = parameter.replace()
            for dispatch_func in self.dispatch_funcs:
                resolved = dispatch_func(resolved)

            # Identity check: ``!=`` would invoke the injected value's own
            # comparison, which need not return a plain bool.
            params[name] = resolved.default if resolved.default is not Parameter.empty else None
        return partial(func, **params)

    @catch(Application)
    def catch_application(self) -> Application:
        """Provide the attached application for ``Application`` parameters."""
        return self.application
class BaseDispatcher(AbstractDispatcher):
    """Default parameter dispatcher."""

    _instances: Sequence[Any]

    def _get_kwargs(self) -> Dict[Type[T], T]:
        """Return all candidate values.

        Later sources override earlier ones: application defaults, then this
        dispatcher itself (keyed by ``AbstractDispatcher``), then the values
        passed to the constructor.
        """
        return {**self._get_default_kwargs(), AbstractDispatcher: self, **self._kwargs}

    def _get_default_kwargs(self) -> Dict[Type[T], T]:
        """Collect values offered by the application, dropping ``None`` entries."""
        app = self.application
        candidates = {
            FastAPI: app.web_app,
            Server: app.web_server,
            TelegramApplication: app.telegram,
            TelegramBot: app.telegram.bot,
        }
        if not app.running:
            # Every managed object is offered under its concrete type.
            candidates.update(
                (type(managed), managed)
                for managed in chain(
                    app.managers.dependency,
                    app.managers.components,
                    app.managers.services,
                    app.managers.plugins,
                )
            )
        return {key: value for key, value in candidates.items() if value is not None}

    def dispatch_by_default(self, parameter: Parameter) -> Parameter:
        """Set the default of ``parameter`` from the value registered for its annotation.

        Only annotations that are classes are looked up. The passed parameter
        object is modified in place and returned.
        """
        annotation = parameter.annotation
        if not isinstance(annotation, type):
            return parameter
        # noinspection PyTypeChecker
        value = self._get_kwargs().get(annotation, None)
        if value is not None:
            parameter._default = value  # pylint: disable=W0212
        return parameter

    def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
        """Set the default of ``parameter`` from a matching catch function.

        The annotation is tried first and the parameter name second.
        ``GenericAlias`` annotations are left untouched. The passed parameter
        object is modified in place and returned.
        """
        annotation = parameter.annotation
        if annotation != Any and isinstance(annotation, GenericAlias):
            return parameter
        providers = self.catch_func_map
        catch_func = providers.get(annotation) or providers.get(parameter.name)
        if catch_func is None:
            return parameter
        # noinspection PyUnresolvedReferences,PyProtectedMember
        parameter._default = catch_func()  # pylint: disable=W0212
        return parameter

    @catch(AbstractEventLoop)
    def catch_loop(self) -> AbstractEventLoop:
        """Provide the event loop returned by ``asyncio.get_event_loop()``."""
        return asyncio.get_event_loop()
class HandlerDispatcher(BaseDispatcher):
    """Handler parameter dispatcher."""

    def __init__(self, update: Optional[Update] = None, context: Optional[CallbackContext] = None, **kwargs) -> None:
        """Remember the update/context and forward them as keyword values."""
        super().__init__(update=update, context=context, **kwargs)
        self._update = update
        self._context = context

    def dispatch(
        self, func: Callable[P, R], *, update: Optional[Update] = None, context: Optional[CallbackContext] = None
    ) -> Callable[..., R]:
        """Bind handler values to ``func``.

        A truthy ``update`` / ``context`` replaces the stored one. Whatever is
        still ``None`` afterwards is read from the corresponding context
        variable in ``core.builtins.contexts``.
        """
        if update:
            self._update = update
        if context:
            self._context = context

        # Imported on demand, only when the stored value is actually missing.
        if self._update is None:
            from core.builtins.contexts import UpdateCV

            self._update = UpdateCV.get()
        if self._context is None:
            from core.builtins.contexts import CallbackContextCV

            self._context = CallbackContextCV.get()
        return super().dispatch(func)

    def dispatch_by_default(self, parameter: Parameter) -> Parameter:
        """HandlerDispatcher does not use dispatch_by_default by default."""
        return parameter

    @catch(Update)
    def catch_update(self) -> Update:
        """Provide the current update."""
        return self._update

    @catch(CallbackContext)
    def catch_context(self) -> CallbackContext:
        """Provide the current callback context."""
        return self._context

    @catch(Message)
    def catch_message(self) -> Message:
        """Provide ``effective_message`` of the current update."""
        return self._update.effective_message

    @catch(User)
    def catch_user(self) -> User:
        """Provide ``effective_user`` of the current update."""
        return self._update.effective_user

    @catch(Chat)
    def catch_chat(self) -> Chat:
        """Provide ``effective_chat`` of the current update."""
        return self._update.effective_chat
class JobDispatcher(BaseDispatcher):
    """Job parameter dispatcher."""

    def __init__(self, context: Optional[CallbackContext] = None, **kwargs) -> None:
        """Remember the context and forward it as a keyword value."""
        super().__init__(context=context, **kwargs)
        self._context = context

    def dispatch(self, func: Callable[P, R], *, context: Optional[CallbackContext] = None) -> Callable[..., R]:
        """Bind job values to ``func``.

        A truthy ``context`` replaces the stored one. If the stored context is
        still ``None`` it is read from ``CallbackContextCV``.
        """
        if context:
            self._context = context
        if self._context is None:
            # Imported on demand, only when the stored value is actually missing.
            from core.builtins.contexts import CallbackContextCV

            self._context = CallbackContextCV.get()
        return super().dispatch(func)

    @catch("data")
    def catch_data(self) -> Any:
        """Provide ``job.data`` of the current context for parameters named ``data``."""
        return self._context.job.data

    @catch(Job)
    def catch_job(self) -> Job:
        """Provide the job of the current context."""
        return self._context.job

    @catch(CallbackContext)
    def catch_context(self) -> CallbackContext:
        """Provide the current callback context."""
        return self._context
def dispatched(dispatcher: Type[AbstractDispatcher] = BaseDispatcher):
    """Build a decorator that fills a callable's parameters on every call.

    Args:
        dispatcher: Dispatcher class. It is instantiated without arguments
            each time the decorated callable is invoked.

    Returns:
        A decorator whose wrapper dispatches ``func`` and then calls the
        result with the arguments it received.
    """

    def decorate(func: Callable[P, R]) -> Callable[P, R]:
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            bound = dispatcher().dispatch(func)
            return bound(*args, **kwargs)

        return wraps(func, assigned=WRAPPER_ASSIGNMENTS)(wrapper)

    return decorate