mirror of
https://github.com/PaiGramTeam/PaiGram.git
synced 2024-11-16 12:51:35 +00:00
233e7ab58d
Co-authored-by: luoshuijs <luoshuijs@outlook.com> Co-authored-by: Karako <karakohear@gmail.com> Co-authored-by: xtaodada <xtao@xtaolink.cn>
304 lines
11 KiB
Python
304 lines
11 KiB
Python
"""插件"""
|
||
import asyncio
|
||
from abc import ABC
|
||
from dataclasses import asdict
|
||
from datetime import timedelta
|
||
from functools import partial, wraps
|
||
from itertools import chain
|
||
from multiprocessing import RLock as Lock
|
||
from types import MethodType
|
||
from typing import (
|
||
Any,
|
||
ClassVar,
|
||
Dict,
|
||
Iterable,
|
||
List,
|
||
Optional,
|
||
TYPE_CHECKING,
|
||
Type,
|
||
TypeVar,
|
||
Union,
|
||
)
|
||
|
||
from pydantic import BaseModel
|
||
from telegram.ext import BaseHandler, ConversationHandler, Job, TypeHandler
|
||
from typing_extensions import ParamSpec
|
||
|
||
from core.handler.adminhandler import AdminHandler
|
||
from core.plugin._funcs import ConversationFuncs, PluginFuncs
|
||
from core.plugin._handler import ConversationDataType
|
||
from utils.const import WRAPPER_ASSIGNMENTS
|
||
from utils.helpers import isabstract
|
||
from utils.log import logger
|
||
|
||
if TYPE_CHECKING:
|
||
from core.application import Application
|
||
from core.plugin._handler import ConversationData, HandlerData, ErrorHandlerData
|
||
from core.plugin._job import JobData
|
||
from multiprocessing.synchronize import RLock as LockType
|
||
|
||
# Public names exported by this module.
__all__ = ("Plugin", "PluginType", "get_all_plugins")

# Rebind ``wraps`` so that every use in this module passes ``assigned=WRAPPER_ASSIGNMENTS``.
wraps = partial(wraps, assigned=WRAPPER_ASSIGNMENTS)
P = ParamSpec("P")
T = TypeVar("T")
R = TypeVar("R")

# Type variable for handler classes derived from telegram's ``BaseHandler``.
HandlerType = TypeVar("HandlerType", bound=BaseHandler)

# Name of the attribute that stores the arguments needed to build a handler (e.g. ``block``).
_HANDLER_DATA_ATTR_NAME = "_handler_datas"
"""用于储存生成 handler 时所需要的参数(例如 block)的属性名"""

# Name of the attribute that stores the arguments needed to build a ConversationHandler (e.g. ``block``).
_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名"""

# Name of the attribute looked up on plugin methods to collect error handler data.
_ERROR_HANDLER_ATTR_NAME = "_error_handler_data"

# Name of the attribute looked up on plugin methods to collect job data.
_JOB_ATTR_NAME = "_job_data"

# Attribute names that are skipped when a plugin scans its own attributes with ``dir()``.
_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"]
|
||
|
||
|
||
class _Plugin(PluginFuncs):
    """Base class for plugins.

    Public bound methods that carry handler / error-handler / job data (stored
    under the module-level ``_*_ATTR_NAME`` attribute names) are collected
    lazily and registered with, or removed from, ``self.application.telegram``.
    """

    # Re-entrant lock guarding the lazily built caches below.
    _lock: ClassVar["LockType"] = Lock()
    # Async lock held by ``install`` while jobs and handlers are being registered.
    _asyncio_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
    _installed: bool = False

    # Caches; ``None`` means "not built yet".
    _handlers: Optional[List[HandlerType]] = None
    _error_handlers: Optional[List["ErrorHandlerData"]] = None
    _jobs: Optional[List[Job]] = None
    _application: "Optional[Application]" = None

    def set_application(self, application: "Application") -> None:
        """Attach the application this plugin registers its handlers and jobs with."""
        self._application = application

    @property
    def application(self) -> "Application":
        """The application set through ``set_application``.

        Raises:
            RuntimeError: if no application has been set yet.
        """
        if self._application is None:
            raise RuntimeError("No application was set for this Plugin.")
        return self._application

    def _iter_marked_methods(self, data_attr: str) -> Iterable[Any]:
        """Yield ``(bound_method, datas)`` for every public bound method carrying ``data_attr``.

        Names starting with ``_`` and names listed in ``_EXCLUDE_ATTRS`` are
        skipped before ``getattr`` is called, so the scanning properties are
        never re-entered from here. Methods whose data is missing or empty
        are skipped as well.
        """
        for attr in dir(self):
            if attr.startswith("_") or attr in _EXCLUDE_ATTRS:
                continue
            func = getattr(self, attr)
            if not isinstance(func, MethodType):
                continue
            datas = getattr(func, data_attr, [])
            if datas:
                yield func, datas

    @property
    def handlers(self) -> List[HandlerType]:
        """All handlers of this plugin, built on first access and cached afterwards."""
        with self._lock:
            if self._handlers is None:
                self._handlers = []

                for func, datas in self._iter_marked_methods(_HANDLER_DATA_ATTR_NAME):
                    for data in datas:
                        data: "HandlerData"
                        handler = data.type(callback=func, **data.kwargs)
                        if data.admin:
                            # Admin-only handlers are wrapped in ``AdminHandler``.
                            handler = AdminHandler(handler=handler, application=self.application)
                        self._handlers.append(handler)
        return self._handlers

    @property
    def error_handlers(self) -> List["ErrorHandlerData"]:
        """All error handler data of this plugin, collected on first access and cached."""
        with self._lock:
            if self._error_handlers is None:
                self._error_handlers = []
                for func, datas in self._iter_marked_methods(_ERROR_HANDLER_ATTR_NAME):
                    for data in datas:
                        data: "ErrorHandlerData"
                        # Store the bound method so install/uninstall can pass it to telegram.
                        data.func = func
                        self._error_handlers.append(data)

        return self._error_handlers

    def _install_jobs(self) -> None:
        """Register this plugin's jobs with the job queue; a no-op once ``_jobs`` is populated."""
        if self._jobs is None:
            self._jobs = []
            for func, datas in self._iter_marked_methods(_JOB_ATTR_NAME):
                for data in datas:
                    data: "JobData"
                    # ``data.type`` names the job-queue method to call; the remaining
                    # dataclass fields are forwarded to it as keyword arguments.
                    extra_kwargs = {
                        key: value
                        for key, value in asdict(data).items()
                        if key not in ["type", "kwargs", "dispatcher"]
                    }
                    self._jobs.append(
                        getattr(self.application.telegram.job_queue, data.type)(
                            callback=func,
                            **data.kwargs,
                            **extra_kwargs,
                        )
                    )

    @property
    def jobs(self) -> List[Job]:
        """All jobs of this plugin, registered on first access if that has not happened yet."""
        with self._lock:
            # ``_install_jobs`` already guards on ``_jobs is None``. Pre-seeding the
            # cache with an empty list here (as was done previously) made that guard
            # skip registration, so this property always returned ``[]``.
            self._install_jobs()
        return self._jobs

    async def initialize(self) -> None:
        """Initialise the plugin; hook for subclasses, does nothing by default."""

    async def shutdown(self) -> None:
        """Tear the plugin down; hook for subclasses, does nothing by default."""

    async def install(self) -> None:
        """Initialise the plugin and register its jobs, handlers and error handlers.

        Does nothing when the plugin is already installed.
        """
        group = id(self)
        if not self._installed:
            # ``initialize`` must run first: if it raises, nothing gets registered.
            await self.initialize()
            async with self._asyncio_lock:
                self._install_jobs()

                for h in self.handlers:
                    if not isinstance(h, TypeHandler):
                        # Regular handlers live in a group private to this plugin instance.
                        self.application.telegram.add_handler(h, group)
                    else:
                        self.application.telegram.add_handler(h, -1)

                for h in self.error_handlers:
                    self.application.telegram.add_error_handler(h.func, h.block)
                self._installed = True

    async def uninstall(self) -> None:
        """Remove this plugin's handlers, error handlers and jobs, then shut it down.

        Does nothing when the plugin is not installed.
        """
        group = id(self)

        with self._lock:
            if self._installed:
                telegram = self.application.telegram
                # Dropping the private group removes every non-TypeHandler at once.
                if group in telegram.handlers:
                    del telegram.handlers[group]

                for h in self.handlers:
                    if isinstance(h, TypeHandler):
                        telegram.remove_handler(h, -1)
                for h in self.error_handlers:
                    telegram.remove_error_handler(h.func)

                # Only cancel jobs this plugin registered; the queue is shared, and
                # cancelling everything would also kill other plugins' jobs.
                # NOTE(review): relies on ``Job`` equality between the objects returned
                # by the job queue methods and by ``job_queue.jobs()`` - confirm against
                # the installed python-telegram-bot version.
                own_jobs = self._jobs or []
                for j in telegram.job_queue.jobs():
                    if j in own_jobs:
                        j.schedule_removal()
                # Forget the cancelled jobs so that a later ``install`` registers them again.
                self._jobs = None

                await self.shutdown()
                self._installed = False

    async def reload(self) -> None:
        """Uninstall and then install the plugin again."""
        await self.uninstall()
        await self.install()
|
||
|
||
|
||
class _Conversation(_Plugin, ConversationFuncs, ABC):
    """Plugin whose marked methods are assembled into a single ``ConversationHandler``."""

    # Keyword arguments passed to ``ConversationHandler``; subclasses may override this class.
    # noinspection SpellCheckingInspection
    class Config(BaseModel):
        allow_reentry: bool = False
        per_chat: bool = True
        per_user: bool = True
        per_message: bool = False
        conversation_timeout: Optional[Union[float, timedelta]] = None
        name: Optional[str] = None
        map_to_parent: Optional[Dict[object, object]] = None
        block: bool = False

    def __init_subclass__(cls, **kwargs):
        """Store the class keyword arguments so ``handlers`` can forward them later."""
        cls._conversation_kwargs = kwargs
        # The keyword arguments are consumed here and deliberately not passed up the MRO.
        super(_Conversation, cls).__init_subclass__()
        return cls

    @property
    def handlers(self) -> List[HandlerType]:
        """All handlers of this plugin, built on first access and cached afterwards.

        Methods carrying conversation data are sorted into entry points, states
        and fallbacks. When all three are non-empty they are combined into one
        ``ConversationHandler``; otherwise a warning naming the missing parts is
        logged. Methods without conversation data (or with an unrecognised
        conversation type) contribute their handlers directly.
        """
        with self._lock:
            if self._handlers is None:
                self._handlers = []

                entry_points: List[HandlerType] = []
                states: Dict[Any, List[HandlerType]] = {}
                fallbacks: List[HandlerType] = []
                for attr in dir(self):
                    if attr.startswith("_") or attr in _EXCLUDE_ATTRS:
                        continue
                    func = getattr(self, attr, None)
                    if func is None:
                        continue
                    datas = getattr(func, _HANDLER_DATA_ATTR_NAME, [])
                    if not datas:
                        continue

                    built: List[HandlerType] = [data.type(callback=func, **data.kwargs) for data in datas]

                    conversation_data: "ConversationData" = getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None)
                    if not conversation_data:
                        self._handlers.extend(built)
                        continue

                    _type = conversation_data.type
                    if _type is ConversationDataType.Entry:
                        entry_points.extend(built)
                    elif _type is ConversationDataType.State:
                        states.setdefault(conversation_data.state, []).extend(built)
                    elif _type is ConversationDataType.Fallback:
                        fallbacks.extend(built)
                    else:
                        self._handlers.extend(built)

                if entry_points and states and fallbacks:
                    # Merge into a fresh dict instead of updating the class-level
                    # ``_conversation_kwargs`` in place. Precedence is unchanged:
                    # ``Config`` values override the class keyword arguments.
                    kwargs = {**self._conversation_kwargs, **self.Config().dict()}
                    self._handlers.append(ConversationHandler(entry_points, states, fallbacks, **kwargs))
                else:
                    parts = {"entry_points": entry_points, "states": states, "fallbacks": fallbacks}
                    missing = ", ".join(f"'{name}'" for name, value in parts.items() if not value)
                    logger.warning(
                        "'%s' 因缺少 '%s' 而生成无法生成 ConversationHandler", self.__class__.__name__, missing
                    )
        return self._handlers
|
||
|
||
|
||
class Plugin(_Plugin, ABC):
    """Plugin base class; the conversation variant is exposed as ``Plugin.Conversation``."""

    Conversation = _Conversation
|
||
|
||
|
||
# Type variable bound to ``_Plugin``; used in the return annotation of ``get_all_plugins``.
PluginType = TypeVar("PluginType", bound=_Plugin)
|
||
|
||
|
||
def get_all_plugins() -> Iterable[Type[PluginType]]:
    """Return an iterator over the usable direct subclasses of ``Plugin`` and ``_Conversation``.

    Classes whose name starts with an underscore, and classes that
    ``isabstract`` reports as abstract, are filtered out.
    """

    def _is_public_and_concrete(klass) -> bool:
        return klass.__name__[0] != "_" and not isabstract(klass)

    candidates = chain(Plugin.__subclasses__(), _Conversation.__subclasses__())
    return filter(_is_public_and_concrete, candidates)
|