PaiGram/core/bot.py

310 lines
12 KiB
Python
Raw Normal View History

import asyncio
import inspect
import os
2022-09-17 16:28:51 +00:00
from asyncio import CancelledError
from importlib import import_module
from multiprocessing import RLock as Lock
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, Iterator, List, NoReturn, Optional, TYPE_CHECKING, Type, TypeVar
import genshin
import pytz
2022-09-17 16:28:51 +00:00
from async_timeout import timeout
from telegram.error import NetworkError, TimedOut
2022-09-17 16:28:51 +00:00
from telegram.ext import (
AIORateLimiter,
Application as TgApplication,
CallbackContext,
Defaults,
JobQueue,
MessageHandler,
filters,
2022-09-17 16:28:51 +00:00
)
from telegram.ext.filters import StatusUpdate
from core.config import BotConfig, config # pylint: disable=W0611
from core.error import ServiceNotFoundError
# noinspection PyProtectedMember
from core.plugin import Plugin, _Plugin
from core.service import Service
from utils.const import PLUGIN_DIR, PROJECT_ROOT
from utils.log import logger
if TYPE_CHECKING:
from telegram import Update
# Public API of this module: only the shared `bot` object is exported.
__all__ = ["bot"]
# Generic type variable used for service types/instances and injected call results.
T = TypeVar("T")
# Type variable bounded to `_Plugin`; used to annotate instantiated plugins.
PluginType = TypeVar("PluginType", bound=_Plugin)
# Application object of the bot. `__new__` hands out a single shared instance,
# which keeps a registry of services, installs plugins and runs the telegram
# application.
class Bot:
# Lock held while the shared instance is checked/created in `__new__`.
_lock: ClassVar[Lock] = Lock()
# The shared instance; stays None until `Bot` is constructed for the first time.
_instance: ClassVar[Optional["Bot"]] = None
def __new__(cls, *args, **kwargs) -> "Bot":
    """Return the shared instance, creating it on the first call (singleton).

    ``cls._lock`` is held while ``cls._instance`` is inspected and assigned,
    so the check-and-create step is serialized. ``*args`` and ``**kwargs``
    are accepted but not used here.
    """
    with cls._lock:
        if cls._instance is not None:
            return cls._instance
        cls._instance = object.__new__(cls)
        return cls._instance
# The telegram application; None until `launch()` builds it.
app: Optional[TgApplication] = None
# Bot configuration: the `config` object imported from `core.config`.
_config: BotConfig = config
# Class-level registry mapping a service type to the registered instance.
_services: Dict[Type[T], T] = {}
# Set to True at the start of `launch()` and back to False when it finishes.
_running: bool = False
2022-09-17 16:28:51 +00:00
def _inject(self, signature: inspect.Signature, target: Callable[..., T]) -> T:
    """Call ``target`` with registered services matched by parameter annotation.

    Every parameter of ``signature`` except ``self`` that carries an
    annotation is looked up in ``self._services`` using the annotation as
    the key. Parameters with a registered service are passed by keyword;
    all others are left out so ``target`` falls back to its own defaults.

    :param signature: signature whose parameters are inspected.
    :param target: callable that is invoked with the collected keyword arguments.
    :return: whatever ``target`` returns.
    """
    kwargs = {}
    for name, parameter in signature.parameters.items():
        if name == "self" or parameter.annotation is inspect.Parameter.empty:
            continue
        service = self._services.get(parameter.annotation)
        # Compare against None instead of testing truthiness: a registered
        # service whose instance happens to be falsy (e.g. defines __len__
        # or __bool__) must still be injected.
        if service is not None:
            kwargs[name] = service
    return target(**kwargs)
2022-09-17 16:28:51 +00:00
def init_inject(self, target: Callable[..., T]) -> T:
    """Call ``target`` with its annotated parameters filled from the services.

    Used to instantiate plugins so that they receive the components they
    need (for example MySQL or Redis). When ``target`` is a class, the
    signature of its ``__init__`` is inspected; otherwise the signature of
    ``target`` itself is used.

    :param target: class or callable to invoke.
    :return: the result of calling ``target``.
    """
    inspected = target.__init__ if isinstance(target, type) else target
    return self._inject(inspect.signature(inspected), target)
async def async_inject(self, target: Callable[..., T]) -> T:
    """Inject services into an awaitable-returning callable and await it.

    :param target: callable whose call result can be awaited.
    :return: the awaited result of the injected call.
    """
    signature = inspect.signature(target)
    pending = self._inject(signature, target)
    return await pending
def _gen_pkg(self, root: Path) -> Iterator[str]:
    """Yield dotted module names, usable with ``import_module``, found under ``root``.

    Entries whose name starts with ``_`` are skipped. Directories are
    walked recursively, and each ``.py`` file is converted to its path
    relative to ``PROJECT_ROOT`` with the suffix removed and the path
    separators replaced by dots.
    """
    for entry in root.iterdir():
        if entry.name.startswith("_"):
            continue
        if entry.is_dir():
            yield from self._gen_pkg(entry)
        elif entry.suffix == ".py":
            module_path = entry.relative_to(PROJECT_ROOT).with_suffix("")
            yield str(module_path).replace(os.sep, ".")
async def install_plugins(self):
    """Import all plugin modules, instantiate the plugins and register them.

    1. Every module yielded by ``_gen_pkg(PLUGIN_DIR)`` is imported; a
       failing import is logged and the remaining modules are still tried.
    2. Every subclass of ``Plugin`` and ``Plugin.Conversation`` is
       instantiated through ``init_inject``; its ``__async_init__`` (if
       any) is awaited through ``async_inject``; its handlers and error
       handlers are added to ``self.app``. A failing plugin is logged and
       does not stop the others.
    3. The "new chat members" callbacks collected from all plugins are
       dispatched by one shared ``MessageHandler``.
    4. ``StartPlugin.unknown_command`` is registered for commands sent in
       private chats.
    """
    for pkg in self._gen_pkg(PLUGIN_DIR):
        try:
            import_module(pkg)  # import the plugin module
        except Exception as e:  # pylint: disable=W0703
            # A broken plugin file must not prevent the others from loading.
            logger.exception(
                f'在导入文件 "{pkg}" 的过程中遇到了错误: \n[red bold]{type(e).__name__}: {e}[/]', extra={"markup": True}
            )

    # priority -> callbacks to run when new members join a chat
    callback_dict: Dict[int, List[Callable]] = {}
    for plugin_cls in {*Plugin.__subclasses__(), *Plugin.Conversation.__subclasses__()}:
        path = f"{plugin_cls.__module__}.{plugin_cls.__name__}"
        try:
            plugin: PluginType = self.init_inject(plugin_cls)
            if hasattr(plugin, "__async_init__"):
                await self.async_inject(plugin.__async_init__)

            handlers = plugin.handlers
            self.app.add_handlers(handlers)
            if handlers:
                logger.debug(f'插件 "{path}" 添加了 {len(handlers)} 个 handler ')

            # noinspection PyProtectedMember
            for priority, member_callback in plugin._new_chat_members_handler_funcs():  # pylint: disable=W0212
                callback_dict.setdefault(priority, []).append(member_callback)

            error_handlers = plugin.error_handlers
            for error_callback, block in error_handlers.items():
                self.app.add_error_handler(error_callback, block)
            if error_handlers:
                logger.debug(f'插件 "{path}" 添加了 {len(error_handlers)} 个 error handler')

            # Only the job count is used here; what reading `plugin.jobs`
            # does besides returning the jobs is not visible in this file.
            if jobs := plugin.jobs:
                logger.debug(f'插件 "{path}" 添加了 {len(jobs)} 个任务')
            logger.success(f'插件 "{path}" 载入成功')
        except Exception as e:  # pylint: disable=W0703
            logger.exception(
                f'在安装插件 "{path}" 的过程中遇到了错误: \n[red bold]{type(e).__name__}: {e}[/]', extra={"markup": True}
            )

    if callback_dict:
        num = sum(len(callbacks) for callbacks in callback_dict.values())

        async def _new_chat_member_callback(update: "Update", context: "CallbackContext"):
            # Plain local loop variables: the handler runs with block=False,
            # so it must not rebind a variable of the enclosing scope.
            # NOTE(review): groups run in insertion order, not sorted by
            # priority - confirm whether ordering by priority is expected.
            for callbacks in callback_dict.values():
                for registered_callback in callbacks:
                    await registered_callback(update, context)

        self.app.add_handler(
            MessageHandler(callback=_new_chat_member_callback, filters=StatusUpdate.NEW_CHAT_MEMBERS, block=False)
        )
        logger.success(
            f"成功添加了 {num} 个针对 [blue]{StatusUpdate.NEW_CHAT_MEMBERS}[/] 的 [blue]MessageHandler[/]",
            extra={"markup": True},
        )

    # special handler
    from plugins.system.start import StartPlugin

    self.app.add_handler(
        MessageHandler(
            callback=StartPlugin.unknown_command, filters=filters.COMMAND & filters.ChatType.PRIVATE, block=False
        )
    )
async def _start_base_services(self):
    """Import the modules under ``core/base`` and start every ``Service`` subclass.

    Import failures are logged and skipped. Each direct subclass of
    ``Service`` is built with ``from_config(self._config)`` when it has
    that attribute, otherwise through ``init_inject``. It is then started
    and stored in ``self._services`` under its class. A service that fails
    to build or start is logged and left out of the registry.
    """
    base_dir = PROJECT_ROOT / "core/base"
    for module_name in self._gen_pkg(base_dir):
        try:
            import_module(module_name)
        except Exception as e:  # pylint: disable=W0703
            logger.exception(
                f'在导入文件 "{module_name}" 的过程中遇到了错误: \n[red bold]{type(e).__name__}: {e}[/]',
                extra={"markup": True},
            )

    for service_cls in Service.__subclasses__():
        try:
            instance = (
                service_cls.from_config(self._config)
                if hasattr(service_cls, "from_config")
                else self.init_inject(service_cls)
            )
            await instance.start()
            logger.success(f'服务 "{service_cls.__name__}" 初始化成功')
            # Register only after a successful start.
            self._services[service_cls] = instance
        except Exception as e:  # pylint: disable=W0703
            logger.exception(f'服务 "{service_cls.__name__}" 初始化失败: {e}')
async def start_services(self):
    """Start the services.

    Starts the base services first, then imports every sub-directory of
    ``core`` whose name does not start with ``_`` and is not ``base``.
    Import errors are logged and the remaining packages are still imported.
    """
    await self._start_base_services()
    for entry in (PROJECT_ROOT / "core").iterdir():
        if entry.name.startswith("_") or not entry.is_dir() or entry.name == "base":
            continue
        relative = entry.relative_to(PROJECT_ROOT).with_suffix("")
        pkg = str(relative).replace(os.sep, ".")
        try:
            import_module(pkg)
        except Exception as e:  # pylint: disable=W0703
            logger.exception(
                f'在导入文件 "{pkg}" 的过程中遇到了错误: \n[red bold]{type(e).__name__}: {e}[/]', extra={"markup": True}
            )
async def stop_services(self):
    """Stop every registered service that exposes a ``stop`` attribute.

    Entries that are ``TgApplication`` instances are skipped. Each
    remaining service gets up to 5 seconds to stop; ``stop`` may be a
    coroutine function or a plain callable. A timeout or any other error
    is logged and the next service is still processed.
    """
    if not self._services:
        return
    logger.info("正在关闭服务")
    for service in self._services.values():
        if isinstance(service, TgApplication):
            continue
        try:
            async with timeout(5):
                if hasattr(service, "stop"):
                    if inspect.iscoroutinefunction(service.stop):
                        await service.stop()
                    else:
                        service.stop()
                    logger.success(f'服务 "{service.__class__.__name__}" 关闭成功')
        # The timeout context converts its own cancellation into
        # asyncio.TimeoutError on exit. Catching that here, rather than
        # CancelledError inside the context, lets a genuine external
        # cancellation propagate instead of being reported as a timeout.
        except asyncio.TimeoutError:
            logger.warning(f'服务 "{service.__class__.__name__}" 关闭超时')
        except Exception as e:  # pylint: disable=W0703
            logger.exception(f'服务 "{service.__class__.__name__}" 关闭失败: \n{type(e).__name__}: {e}')
2022-09-17 16:28:51 +00:00
async def _post_init(self, context: CallbackContext) -> None:
    """Initialisation hook passed to the application builder via ``post_init``.

    Rewrites the ``*_URL*`` string attributes of ``genshin.utility.extdb``
    to the fastgit mirror and refreshes the enka character data. A failure
    there is logged and does not abort start-up. It then registers
    ``context`` as a service, starts the services and installs the plugins.

    :param context: object handed to the hook by the application.
        NOTE(review): python-telegram-bot documents that ``post_init``
        callbacks receive the ``Application`` - confirm that registering it
        under the ``CallbackContext`` key is intended.
    """
    logger.info("开始初始化 genshin.py 相关资源")
    try:
        # Switch to the fastgit mirror.
        extdb = genshin.utility.extdb
        for attr_name in dir(extdb):
            if "_URL" not in attr_name:
                continue
            url = getattr(extdb, attr_name)
            # Only strings can be rewritten; a non-string attribute must not
            # abort the whole resource initialisation.
            if isinstance(url, str):
                setattr(extdb, attr_name, url.replace("githubusercontent.com", "fastgit.org"))
        await genshin.utility.update_characters_enka()
    except Exception as exc:  # pylint: disable=W0703
        logger.error("初始化 genshin.py 相关资源失败")
        logger.exception(exc)
    else:
        logger.success("初始化 genshin.py 相关资源成功")

    self._services.update({CallbackContext: context})
    logger.info("开始初始化服务")
    await self.start_services()
    logger.info("开始安装插件")
    await self.install_plugins()
    logger.info("BOT 初始化成功")
def launch(self) -> None:
    """Start the bot.

    Builds the telegram application, then runs polling. A ``TimedOut``
    error is retried, for at most 5 attempts in total; any other
    ``NetworkError`` is logged and ends the attempts. ``SystemExit`` and
    ``KeyboardInterrupt`` are treated as a normal shutdown request.
    Whatever happens, the registered services are stopped and the event
    loop is closed before the method returns.
    """
    self._running = True
    logger.info("正在初始化BOT")
    self.app = (
        TgApplication.builder()
        .rate_limiter(AIORateLimiter())
        .defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai")))
        .token(self._config.bot_token)
        .post_init(self._post_init)
        .build()
    )
    try:
        for _ in range(5):
            try:
                # close_loop=False: the loop is still needed below to stop the services.
                self.app.run_polling(close_loop=False, write_timeout=10)
                break
            except TimedOut:
                logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
                continue
            except NetworkError as e:
                # Pass the exception explicitly, as every other
                # logger.exception call in this file passes a message.
                logger.exception(e)
                if "SSLZeroReturnError" in str(e):
                    logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
                else:
                    logger.error("网络连接出现问题, 请检查您的网络状况.")
                break
    except (SystemExit, KeyboardInterrupt):
        pass
    except Exception as e:  # pylint: disable=W0703
        logger.exception(f"BOT 执行过程中出现错误: {e}")
    finally:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.stop_services())
        loop.close()
    logger.info("BOT 已经关闭")
    self._running = False
def find_service(self, target: Type[T]) -> T:
    """Look up a registered service by its type.

    :param target: the type the service was registered under.
    :return: the registered service instance.
    :raises ServiceNotFoundError: if nothing is registered for ``target``.
    """
    # Fetch first, then compare. Written as `result := get(...) is None`,
    # the walrus binds the boolean result of the comparison, so callers
    # would receive False instead of the service.
    result = self._services.get(target)
    if result is None:
        raise ServiceNotFoundError(target)
    return result
def add_service(self, service: T) -> None:
    """Register a new service under its own type.

    The annotation is ``None`` rather than ``NoReturn`` because the method
    returns normally.

    :param service: the service instance to register.
    :raises ValueError: if a service of the same type is already registered.
    """
    if type(service) in self._services:
        raise ValueError(f'Service "{type(service)}" is already existed.')
    self.update_service(service)
def update_service(self, service: T):
    """Register ``service`` under its type, replacing any existing entry.

    :param service: the service instance to add or replace.
    """
    service_type = type(service)
    self._services[service_type] = service
def contain_service(self, service: Any) -> bool:
    """Tell whether a service is registered.

    :param service: either a type, which is checked against the registered
        types, or an instance, which is checked against the registered instances.
    :return: True if the service is present, otherwise False.
    """
    if not isinstance(service, type):
        return service in self._services.values()
    return service in self._services
@property
def job_queue(self) -> JobQueue:
    """The ``job_queue`` attribute of the telegram application held in ``self.app``."""
    application = self.app
    return application.job_queue
@property
def services(self) -> Dict[Type[T], T]:
    """The service registry itself (not a copy), mapping each type to its instance."""
    registry = self._services
    return registry
@property
def config(self) -> BotConfig:
    """The bot configuration object stored in ``self._config``."""
    current_config = self._config
    return current_config
@property
def is_running(self) -> bool:
    """Whether ``launch()`` is currently in progress (the ``_running`` flag)."""
    running = self._running
    return running
# Module-level shared Bot object; this is the only name exported through `__all__`.
bot = Bot()