MibooGram/core/plugin/_plugin.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

304 lines
11 KiB
Python
Raw Normal View History

"""插件"""
import asyncio
from abc import ABC
from dataclasses import asdict
from datetime import timedelta
from functools import partial, wraps
from itertools import chain
from multiprocessing import RLock as Lock
from types import MethodType
from typing import (
Any,
ClassVar,
Dict,
Iterable,
List,
Optional,
TYPE_CHECKING,
Type,
TypeVar,
Union,
)
from pydantic import BaseModel
from telegram.ext import BaseHandler, ConversationHandler, Job, TypeHandler
from typing_extensions import ParamSpec
from core.handler.adminhandler import AdminHandler
from core.plugin._funcs import ConversationFuncs, PluginFuncs
from core.plugin._handler import ConversationDataType
from utils.const import WRAPPER_ASSIGNMENTS
from utils.helpers import isabstract
from utils.log import logger
if TYPE_CHECKING:
from core.application import Application
from core.plugin._handler import ConversationData, HandlerData, ErrorHandlerData
from core.plugin._job import JobData
from multiprocessing.synchronize import RLock as LockType
# Names exported by ``from <this module> import *``.
__all__ = ("Plugin", "PluginType", "get_all_plugins")

# Rebind ``wraps`` so that any later use of it in this module passes
# ``assigned=WRAPPER_ASSIGNMENTS`` (imported from ``utils.const``) by default.
wraps = partial(wraps, assigned=WRAPPER_ASSIGNMENTS)

# Generic placeholders used in annotations.
P = ParamSpec("P")
T = TypeVar("T")
R = TypeVar("R")

# Any telegram handler type (a subclass of ``BaseHandler``).
HandlerType = TypeVar("HandlerType", bound=BaseHandler)

# Name of the attribute, looked up on bound methods, that holds the data needed
# to build handlers (for example ``block``).
# presumably attached by decorators defined elsewhere -- TODO confirm
_HANDLER_DATA_ATTR_NAME = "_handler_datas"
"""用于储存生成 handler 时所需要的参数(例如 block的属性名"""

# Name of the attribute, looked up on bound methods, that holds the data needed
# to place a handler inside a ConversationHandler (for example ``block``).
_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block的属性名"""

# Name of the attribute, looked up on bound methods, that holds error-handler data.
_ERROR_HANDLER_ATTR_NAME = "_error_handler_data"

# Name of the attribute, looked up on bound methods, that holds job data.
_JOB_ATTR_NAME = "_job_data"

# Attribute names skipped while scanning a plugin with ``dir()``: these are the
# properties that perform the scan themselves.
_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"]
class _Plugin(PluginFuncs):
    """Base class for plugins.

    A plugin scans its own public bound methods for metadata attributes
    (handler data, error-handler data and job data) and turns them into
    telegram handlers and scheduled jobs.  ``install`` registers them on
    ``application.telegram`` and ``uninstall`` removes them again.
    """

    # Class-level locks: a single lock object of each kind is shared by all plugins.
    _lock: ClassVar["LockType"] = Lock()
    _asyncio_lock: ClassVar[asyncio.Lock] = asyncio.Lock()

    _installed: bool = False

    # Lazily built caches; ``None`` means "not built yet".
    _handlers: Optional[List[HandlerType]] = None
    _error_handlers: Optional[List["ErrorHandlerData"]] = None
    _jobs: Optional[List[Job]] = None

    _application: "Optional[Application]" = None

    def set_application(self, application: "Application") -> None:
        """Attach the application this plugin belongs to."""
        self._application = application

    @property
    def application(self) -> "Application":
        """The application set through ``set_application``.

        Raises:
            RuntimeError: if no application has been set yet.
        """
        if self._application is None:
            raise RuntimeError("No application was set for this Plugin.")
        return self._application

    def _iter_marked_methods(self, data_attr: str) -> Iterable[Any]:
        """Yield ``(bound_method, datas)`` for every public bound method carrying ``data_attr``.

        Private names and the names in ``_EXCLUDE_ATTRS`` are skipped; methods
        whose ``data_attr`` value is missing or empty are skipped as well.
        """
        for attr in dir(self):
            if attr.startswith("_") or attr in _EXCLUDE_ATTRS:
                continue
            func = getattr(self, attr)
            if not isinstance(func, MethodType):
                continue
            datas = getattr(func, data_attr, [])
            if datas:
                yield func, datas

    @property
    def handlers(self) -> List[HandlerType]:
        """All handlers of this plugin (built on first access, then cached)."""
        with self._lock:
            if self._handlers is None:
                self._handlers = []
                for func, datas in self._iter_marked_methods(_HANDLER_DATA_ATTR_NAME):
                    for data in datas:
                        data: "HandlerData"
                        handler = data.type(callback=func, **data.kwargs)
                        if data.admin:
                            # Admin-only handlers are wrapped in an AdminHandler.
                            handler = AdminHandler(handler=handler, application=self.application)
                        self._handlers.append(handler)
            return self._handlers

    @property
    def error_handlers(self) -> List["ErrorHandlerData"]:
        """All error-handler descriptors of this plugin (built on first access, then cached)."""
        with self._lock:
            if self._error_handlers is None:
                self._error_handlers = []
                for func, datas in self._iter_marked_methods(_ERROR_HANDLER_ATTR_NAME):
                    for data in datas:
                        data: "ErrorHandlerData"
                        # Bind the descriptor to this instance's method.
                        data.func = func
                        self._error_handlers.append(data)
            return self._error_handlers

    def _install_jobs(self) -> None:
        """Schedule this plugin's jobs on the application's job queue, at most once.

        Does nothing while ``_jobs`` is already populated, so that reading
        ``jobs`` before ``install()`` does not schedule every job twice.
        ``uninstall`` resets ``_jobs`` so a later install schedules them again.
        """
        if self._jobs is not None:
            return
        self._jobs = []
        for func, datas in self._iter_marked_methods(_JOB_ATTR_NAME):
            for data in datas:
                data: "JobData"
                # ``data.type`` names the job-queue method used for scheduling;
                # the remaining dataclass fields are forwarded as keyword arguments.
                schedule = getattr(self.application.telegram.job_queue, data.type)
                extra = {
                    key: value
                    for key, value in asdict(data).items()
                    if key not in ["type", "kwargs", "dispatcher"]
                }
                self._jobs.append(schedule(callback=func, **data.kwargs, **extra))

    @property
    def jobs(self) -> List[Job]:
        """The jobs scheduled by this plugin (scheduled on first access if needed)."""
        with self._lock:
            if self._jobs is None:
                self._install_jobs()
            return self._jobs

    async def initialize(self) -> None:
        """Initialize the plugin."""

    async def shutdown(self) -> None:
        """Tear the plugin down."""

    async def install(self) -> None:
        """Install the plugin: initialize it, schedule its jobs and register its handlers."""
        group = id(self)
        if self._installed:
            return
        # initialize() must run first: if it raises, no handler gets registered.
        await self.initialize()
        async with self._asyncio_lock:
            self._install_jobs()
            telegram = self.application.telegram
            for handler in self.handlers:
                # TypeHandlers go to group -1; everything else to this plugin's own group.
                telegram.add_handler(handler, -1 if isinstance(handler, TypeHandler) else group)
            for error_handler in self.error_handlers:
                telegram.add_error_handler(error_handler.func, error_handler.block)
            self._installed = True

    async def uninstall(self) -> None:
        """Uninstall the plugin: remove its handlers and its own jobs, then shut it down."""
        group = id(self)
        with self._lock:
            if not self._installed:
                return
            telegram = self.application.telegram
            # Drop this plugin's whole handler group.
            if group in telegram.handlers:
                del telegram.handlers[group]
            # TypeHandlers were registered in group -1 and must be removed one by one.
            for handler in self.handlers:
                if isinstance(handler, TypeHandler):
                    telegram.remove_handler(handler, -1)
            for error_handler in self.error_handlers:
                telegram.remove_error_handler(error_handler.func)
            # Only cancel the jobs this plugin scheduled, and only those still
            # in the queue; jobs that belong to other plugins are left alone.
            own_jobs = self._jobs or []
            for job in telegram.job_queue.jobs():
                if job in own_jobs:
                    job.schedule_removal()
            # Forget the jobs so that a later install() schedules them afresh.
            self._jobs = None
            await self.shutdown()
            self._installed = False

    async def reload(self) -> None:
        """Uninstall and then install the plugin again."""
        await self.uninstall()
        await self.install()
class _Conversation(_Plugin, ConversationFuncs, ABC):
    """Base class for plugins whose handlers form a ``ConversationHandler``.

    Methods carrying conversation data are sorted into entry points, states
    and fallbacks; when all three are present a single ``ConversationHandler``
    is built from them.  Handlers without conversation data are registered
    as ordinary handlers.
    """

    # noinspection SpellCheckingInspection
    class Config(BaseModel):
        # Every field is forwarded as a keyword argument to ``ConversationHandler``.
        allow_reentry: bool = False
        per_chat: bool = True
        per_user: bool = True
        per_message: bool = False
        conversation_timeout: Optional[Union[float, timedelta]] = None
        name: Optional[str] = None
        map_to_parent: Optional[Dict[object, object]] = None
        block: bool = False

    def __init_subclass__(cls, **kwargs):
        """Remember the keyword arguments given in the subclass's ``class`` statement."""
        cls._conversation_kwargs = kwargs
        super(_Conversation, cls).__init_subclass__()
        return cls

    @property
    def handlers(self) -> List[HandlerType]:
        """All handlers of this plugin, including the assembled ``ConversationHandler``.

        Built on first access and cached afterwards.  A warning is logged when
        entry points, states or fallbacks are missing, in which case no
        ``ConversationHandler`` is created.
        """
        with self._lock:
            if self._handlers is None:
                self._handlers = []

                entry_points: List[HandlerType] = []
                states: Dict[Any, List[HandlerType]] = {}
                fallbacks: List[HandlerType] = []

                for attr in dir(self):
                    if attr.startswith("_") or attr in _EXCLUDE_ATTRS:
                        continue
                    func = getattr(self, attr, None)
                    if func is None:
                        continue
                    datas = getattr(func, _HANDLER_DATA_ATTR_NAME, [])
                    if not datas:
                        continue

                    handlers: List[HandlerType] = []
                    for data in datas:
                        handler = data.type(callback=func, **data.kwargs)
                        if data.admin:
                            # Honour the admin flag exactly like ``_Plugin.handlers``
                            # does, so admin-only steps of a conversation stay protected.
                            handler = AdminHandler(handler=handler, application=self.application)
                        handlers.append(handler)

                    conversation_data: "ConversationData" = getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None)
                    role = conversation_data.type if conversation_data else None
                    if role is ConversationDataType.Entry:
                        entry_points.extend(handlers)
                    elif role is ConversationDataType.State:
                        states.setdefault(conversation_data.state, []).extend(handlers)
                    elif role is ConversationDataType.Fallback:
                        fallbacks.extend(handlers)
                    else:
                        # No (or unknown) conversation role: register as a plain handler.
                        self._handlers.extend(handlers)

                if entry_points and states and fallbacks:
                    # Merge into a fresh dict so the class-level keyword dict is not
                    # mutated in place.
                    # NOTE(review): Config values take precedence over the class
                    # keywords, so a class keyword that Config also defines has no
                    # effect -- confirm this precedence is intended.
                    kwargs = {**self._conversation_kwargs, **self.Config().dict()}
                    self._handlers.append(ConversationHandler(entry_points, states, fallbacks, **kwargs))
                else:
                    temp_dict = {"entry_points": entry_points, "states": states, "fallbacks": fallbacks}
                    # Names of the parts that are empty, each wrapped in quotes.
                    reason = (f"'{name}'" for name, value in temp_dict.items() if not value)
                    logger.warning(
                        "'%s' 因缺少 '%s' 而生成无法生成 ConversationHandler", self.__class__.__name__, ", ".join(reason)
                    )
            return self._handlers
class Plugin(_Plugin, ABC):
    """Plugin base class."""

    # Expose the conversation base class as ``Plugin.Conversation``.
    Conversation = _Conversation
# Type variable for any plugin class derived from ``_Plugin``.
PluginType = TypeVar("PluginType", bound=_Plugin)
def get_all_plugins() -> Iterable[Type[PluginType]]:
    """Return the loadable plugin classes.

    Looks at the direct subclasses of ``Plugin`` and of ``_Conversation`` and
    lazily filters out classes whose name starts with an underscore as well
    as classes reported abstract by ``isabstract``.
    """

    def _is_loadable(plugin_cls) -> bool:
        # Private classes (leading underscore) are never loaded.
        if plugin_cls.__name__[0] == "_":
            return False
        return not isabstract(plugin_cls)

    candidates = chain(Plugin.__subclasses__(), _Conversation.__subclasses__())
    return filter(_is_loadable, candidates)