From 1fe8a5efeeb0c5e15c78717796b5ce251b0688af Mon Sep 17 00:00:00 2001 From: xtaodada Date: Mon, 31 Jul 2023 22:10:37 +0800 Subject: [PATCH] :sparkles: separate core code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 洛水居室 Co-authored-by: Karako Co-authored-by: LittleMengBot <79637933+LittleMengBot@users.noreply.github.com> Co-authored-by: Nahida Co-authored-by: SiHuan Co-authored-by: Chuangbo Li Co-authored-by: zhxy-CN Co-authored-by: =?UTF-8?q?=E8=89=BE=E8=BF=AA?= <62269186+AnotiaWang@users.noreply.github.com> --- .gitignore | 130 ++++--- README.md | 15 + gram_core/README.md | 22 ++ gram_core/__init__.py | 0 gram_core/application.py | 289 ++++++++++++++++ gram_core/base_service.py | 60 ++++ gram_core/basemodel.py | 29 ++ gram_core/builtins/__init__.py | 1 + gram_core/builtins/contexts.py | 38 +++ gram_core/builtins/dispatcher.py | 309 +++++++++++++++++ gram_core/builtins/executor.py | 131 +++++++ gram_core/builtins/reloader.py | 185 ++++++++++ gram_core/config.py | 161 +++++++++ gram_core/dependence/__init__.py | 1 + gram_core/dependence/aiobrowser.py | 56 +++ gram_core/dependence/aiobrowser.pyi | 16 + gram_core/dependence/database.py | 51 +++ gram_core/dependence/mtproto.py | 67 ++++ gram_core/dependence/mtproto.pyi | 31 ++ gram_core/dependence/redisdb.py | 50 +++ gram_core/error.py | 7 + gram_core/handler/__init__.py | 0 gram_core/handler/adminhandler.py | 59 ++++ gram_core/handler/callbackqueryhandler.py | 62 ++++ gram_core/handler/limiterhandler.py | 71 ++++ gram_core/manager.py | 286 ++++++++++++++++ gram_core/override/__init__.py | 0 gram_core/override/telegram.py | 117 +++++++ gram_core/plugin/__init__.py | 16 + gram_core/plugin/_funcs.py | 178 ++++++++++ gram_core/plugin/_handler.py | 380 +++++++++++++++++++++ gram_core/plugin/_job.py | 173 ++++++++++ gram_core/plugin/_plugin.py | 314 +++++++++++++++++ gram_core/ratelimiter.py | 67 ++++ gram_core/services/__init__.py | 0 gram_core/services/cookies/__init__.py | 5 + gram_core/services/cookies/cache.py | 97 ++++++ gram_core/services/cookies/error.py | 12 + gram_core/services/cookies/models.py | 39 +++ gram_core/services/cookies/repositories.py | 55 +++ gram_core/services/cookies/services.py | 159 +++++++++ gram_core/services/devices/__init__.py | 5 + gram_core/services/devices/models.py | 23 ++ gram_core/services/devices/repositories.py | 41 +++ gram_core/services/devices/services.py | 25 ++ gram_core/services/players/__init__.py | 3 + gram_core/services/players/error.py | 2 + gram_core/services/players/models.py | 96 ++++++ gram_core/services/players/repositories.py | 110 ++++++ gram_core/services/players/services.py | 43 +++ gram_core/services/task/__init__.py | 1 + gram_core/services/task/models.py | 44 +++ gram_core/services/task/repositories.py | 50 +++ gram_core/services/task/services.py | 63 ++++ gram_core/services/template/README.md | 11 + gram_core/services/template/__init__.py | 1 + gram_core/services/template/cache.py | 58 ++++ gram_core/services/template/error.py | 14 + gram_core/services/template/models.py | 146 ++++++++ gram_core/services/template/services.py | 207 +++++++++++ gram_core/services/users/__init__.py | 0 gram_core/services/users/cache.py | 24 ++ gram_core/services/users/models.py | 34 ++ gram_core/services/users/repositories.py | 44 +++ gram_core/services/users/services.py | 83 +++++ gram_core/sqlmodel/__init__.py | 0 gram_core/sqlmodel/session.py | 118 +++++++ gram_core/version.py | 1 + setup.py | 34 ++ 69 files changed, 4966 insertions(+), 54 
deletions(-) create mode 100644 gram_core/README.md create mode 100644 gram_core/__init__.py create mode 100644 gram_core/application.py create mode 100644 gram_core/base_service.py create mode 100644 gram_core/basemodel.py create mode 100644 gram_core/builtins/__init__.py create mode 100644 gram_core/builtins/contexts.py create mode 100644 gram_core/builtins/dispatcher.py create mode 100644 gram_core/builtins/executor.py create mode 100644 gram_core/builtins/reloader.py create mode 100644 gram_core/config.py create mode 100644 gram_core/dependence/__init__.py create mode 100644 gram_core/dependence/aiobrowser.py create mode 100644 gram_core/dependence/aiobrowser.pyi create mode 100644 gram_core/dependence/database.py create mode 100644 gram_core/dependence/mtproto.py create mode 100644 gram_core/dependence/mtproto.pyi create mode 100644 gram_core/dependence/redisdb.py create mode 100644 gram_core/error.py create mode 100644 gram_core/handler/__init__.py create mode 100644 gram_core/handler/adminhandler.py create mode 100644 gram_core/handler/callbackqueryhandler.py create mode 100644 gram_core/handler/limiterhandler.py create mode 100644 gram_core/manager.py create mode 100644 gram_core/override/__init__.py create mode 100644 gram_core/override/telegram.py create mode 100644 gram_core/plugin/__init__.py create mode 100644 gram_core/plugin/_funcs.py create mode 100644 gram_core/plugin/_handler.py create mode 100644 gram_core/plugin/_job.py create mode 100644 gram_core/plugin/_plugin.py create mode 100644 gram_core/ratelimiter.py create mode 100644 gram_core/services/__init__.py create mode 100644 gram_core/services/cookies/__init__.py create mode 100644 gram_core/services/cookies/cache.py create mode 100644 gram_core/services/cookies/error.py create mode 100644 gram_core/services/cookies/models.py create mode 100644 gram_core/services/cookies/repositories.py create mode 100644 gram_core/services/cookies/services.py create mode 100644 gram_core/services/devices/__init__.py create mode 100644 gram_core/services/devices/models.py create mode 100644 gram_core/services/devices/repositories.py create mode 100644 gram_core/services/devices/services.py create mode 100644 gram_core/services/players/__init__.py create mode 100644 gram_core/services/players/error.py create mode 100644 gram_core/services/players/models.py create mode 100644 gram_core/services/players/repositories.py create mode 100644 gram_core/services/players/services.py create mode 100644 gram_core/services/task/__init__.py create mode 100644 gram_core/services/task/models.py create mode 100644 gram_core/services/task/repositories.py create mode 100644 gram_core/services/task/services.py create mode 100644 gram_core/services/template/README.md create mode 100644 gram_core/services/template/__init__.py create mode 100644 gram_core/services/template/cache.py create mode 100644 gram_core/services/template/error.py create mode 100644 gram_core/services/template/models.py create mode 100644 gram_core/services/template/services.py create mode 100644 gram_core/services/users/__init__.py create mode 100644 gram_core/services/users/cache.py create mode 100644 gram_core/services/users/models.py create mode 100644 gram_core/services/users/repositories.py create mode 100644 gram_core/services/users/services.py create mode 100644 gram_core/sqlmodel/__init__.py create mode 100644 gram_core/sqlmodel/session.py create mode 100644 gram_core/version.py create mode 100644 setup.py diff --git a/.gitignore b/.gitignore index 68bc17f..aab2292 100644 --- 
a/.gitignore +++ b/.gitignore @@ -1,3 +1,57 @@ +# Created by .ignore support plugin (hsz.mobi) + +### Windows template +# Windows thumbnail cache files +Thumbs.db +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk +### macOS template +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk +### Python template # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -20,7 +74,6 @@ parts/ sdist/ var/ wheels/ -share/python-wheels/ *.egg-info/ .installed.cfg *.egg @@ -39,17 +92,14 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ -.nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover -*.py,cover .hypothesis/ .pytest_cache/ -cover/ # Translations *.mo @@ -59,7 +109,6 @@ cover/ *.log local_settings.py db.sqlite3 -db.sqlite3-journal # Flask stuff: instance/ @@ -72,49 +121,16 @@ instance/ docs/_build/ # PyBuilder -.pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints -# IPython -profile_default/ -ipython_config.py - # pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version +.python-version -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff +# celery beat schedule file celerybeat-schedule -celerybeat.pid # SageMath parsed files *.sage.py @@ -140,21 +156,27 @@ venv.bak/ # mypy .mypy_cache/ -.dmypy.json -dmypy.json +### Linux template +*~ -# Pyre type checker -.pyre/ +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* -# pytype static type analyzer -.pytype/ +# KDE directory preferences +.directory -# Cython debug symbols -cython_debug/ +# Linux trash folder which might appear on any partition or disk +.Trash-* -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### VisualStudioCode template +.vscode + +### Jebrains template +.idea +### Dynaconf config +**/*.local.yml +**/.secrets.yml +utils/ diff --git a/README.md b/README.md index 0f50a62..ec7acad 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,17 @@ # GramCore + +## 1. Overview + telegram robot base core. + +## 2. Usage + +### 2.1 init project + +```bash +poetry install -v +``` + +### 2.2 usage + +TODO diff --git a/gram_core/README.md b/gram_core/README.md new file mode 100644 index 0000000..0f1200d --- /dev/null +++ b/gram_core/README.md @@ -0,0 +1,22 @@ +# core 目录说明 + +## 关于 `Service` + +服务 `Service` 需定义在 `services` 文件夹下, 并继承 `core.service.Service` + +每个 `Service` 都应包含 `start` 和 `stop` 方法, 且这两个方法都为异步方法 + +```python +from core.service import Service + + +class TestService(Service): + def __init__(self): + """do something""" + + async def start(self, *args, **kwargs): + """do something""" + + async def stop(self, *args, **kwargs): + """do something""" +``` \ No newline at end of file diff --git a/gram_core/__init__.py b/gram_core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gram_core/application.py b/gram_core/application.py new file mode 100644 index 0000000..3568330 --- /dev/null +++ b/gram_core/application.py @@ -0,0 +1,289 @@ +"""BOT""" +import asyncio +import signal +from functools import wraps +from signal import SIGABRT, SIGINT, SIGTERM, signal as signal_func +from ssl import SSLZeroReturnError +from typing import Callable, List, Optional, TYPE_CHECKING, TypeVar + +import pytz +import uvicorn +from fastapi import FastAPI +from telegram import Bot, Update +from telegram.error import NetworkError, TelegramError, TimedOut +from telegram.ext import ( + Application as TelegramApplication, + ApplicationBuilder as TelegramApplicationBuilder, + Defaults, + JobQueue, +) +from typing_extensions import ParamSpec +from uvicorn import Server + +from gram_core.config import config as application_config +from gram_core.handler.limiterhandler import LimiterHandler +from gram_core.manager import Managers +from gram_core.override.telegram import HTTPXRequest +from gram_core.ratelimiter import RateLimiter +from utils.const import WRAPPER_ASSIGNMENTS +from utils.log import logger +from utils.models.signal import Singleton + +if TYPE_CHECKING: + from asyncio import Task + from types import FrameType + +__all__ = ("Application",) + +R = TypeVar("R") +T = TypeVar("T") +P = ParamSpec("P") + + +class 
Application(Singleton): + """Application""" + + _web_server_task: Optional["Task"] = None + + _startup_funcs: List[Callable] = [] + _shutdown_funcs: List[Callable] = [] + + def __init__(self, managers: "Managers", telegram: "TelegramApplication", web_server: "Server") -> None: + self._running = False + self.managers = managers + self.telegram = telegram + self.web_server = web_server + self.managers.set_application(application=self) # 给 managers 设置 application + self.managers.build_executor("Application") + + @classmethod + def build(cls): + managers = Managers() + telegram = ( + TelegramApplicationBuilder() + .get_updates_read_timeout(application_config.update_read_timeout) + .get_updates_write_timeout(application_config.update_write_timeout) + .get_updates_connect_timeout(application_config.update_connect_timeout) + .get_updates_pool_timeout(application_config.update_pool_timeout) + .defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai"))) + .token(application_config.bot_token) + .request( + HTTPXRequest( + connection_pool_size=application_config.connection_pool_size, + proxy_url=application_config.proxy_url, + read_timeout=application_config.read_timeout, + write_timeout=application_config.write_timeout, + connect_timeout=application_config.connect_timeout, + pool_timeout=application_config.pool_timeout, + ) + ) + .rate_limiter(RateLimiter()) + .build() + ) + web_server = Server( + uvicorn.Config( + app=FastAPI(debug=application_config.debug), + port=application_config.webserver.port, + host=application_config.webserver.host, + log_config=None, + ) + ) + return cls(managers, telegram, web_server) + + @property + def running(self) -> bool: + """bot 是否正在运行""" + with self._lock: + return self._running + + @property + def web_app(self) -> FastAPI: + """fastapi app""" + return self.web_server.config.app + + @property + def bot(self) -> Optional[Bot]: + return self.telegram.bot + + @property + def job_queue(self) -> Optional[JobQueue]: + return self.telegram.job_queue + + async def _on_startup(self) -> None: + for func in self._startup_funcs: + await self.managers.executor(func, block=getattr(func, "block", False)) + + async def _on_shutdown(self) -> None: + for func in self._shutdown_funcs: + await self.managers.executor(func, block=getattr(func, "block", False)) + + async def initialize(self): + """BOT 初始化""" + self.telegram.add_handler(LimiterHandler(limit_time=10), group=-1) # 启用入口洪水限制 + await self.managers.start_dependency() # 启动基础服务 + await self.managers.init_components() # 实例化组件 + await self.managers.start_services() # 启动其他服务 + await self.managers.install_plugins() # 安装插件 + + async def shutdown(self): + """BOT 关闭""" + await self.managers.uninstall_plugins() # 卸载插件 + await self.managers.stop_services() # 终止其他服务 + await self.managers.stop_dependency() # 终止基础服务 + + async def start(self) -> None: + """启动 BOT""" + logger.info("正在启动 BOT 中...") + + def error_callback(exc: TelegramError) -> None: + """错误信息回调""" + self.telegram.create_task(self.telegram.process_error(error=exc, update=None)) + + await self.telegram.initialize() + logger.info("[blue]Telegram[/] 初始化成功", extra={"markup": True}) + + if application_config.webserver.enable: # 如果使用 web app + server_config = self.web_server.config + server_config.setup_event_loop() + if not server_config.loaded: + server_config.load() + self.web_server.lifespan = server_config.lifespan_class(server_config) + try: + await self.web_server.startup() + except OSError as e: + if e.errno == 10048: + logger.error("Web Server 端口被占用:%s", e) + logger.error("Web 
Server 启动失败,正在退出") + raise SystemExit from None + + if self.web_server.should_exit: + logger.error("Web Server 启动失败,正在退出") + raise SystemExit from None + logger.success("Web Server 启动成功") + + self._web_server_task = asyncio.create_task(self.web_server.main_loop()) + + for _ in range(5): # 连接至 telegram 服务器 + try: + await self.telegram.updater.start_polling( + error_callback=error_callback, allowed_updates=Update.ALL_TYPES + ) + break + except TimedOut: + logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True}) + continue + except NetworkError as e: + logger.exception() + if isinstance(e, SSLZeroReturnError): + logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.") + else: + logger.error("网络连接出现问题, 请检查您的网络状况.") + raise SystemExit from e + + await self.initialize() + logger.success("BOT 初始化成功") + logger.debug("BOT 开始启动") + + await self._on_startup() + await self.telegram.start() + self._running = True + logger.success("BOT 启动成功") + + def stop_signal_handler(self, signum: int): + """终止信号处理""" + signals = {k: v for v, k in signal.__dict__.items() if v.startswith("SIG") and not v.startswith("SIG_")} + logger.debug("接收到了终止信号 %s 正在退出...", signals[signum]) + if self._web_server_task: + self._web_server_task.cancel() + + async def idle(self) -> None: + """在接收到中止信号之前,堵塞loop""" + + task = None + + def stop_handler(signum: int, _: "FrameType") -> None: + self.stop_signal_handler(signum) + task.cancel() + + for s in (SIGINT, SIGTERM, SIGABRT): + signal_func(s, stop_handler) + + while True: + task = asyncio.create_task(asyncio.sleep(600)) + + try: + await task + except asyncio.CancelledError: + break + + async def stop(self) -> None: + """关闭""" + logger.info("BOT 正在关闭") + self._running = False + + await self._on_shutdown() + + if self.telegram.updater.running: + await self.telegram.updater.stop() + + await self.shutdown() + + if self.telegram.running: + await self.telegram.stop() + + await self.telegram.shutdown() + if self.web_server is not None: + try: + await self.web_server.shutdown() + logger.info("Web Server 已经关闭") + except AttributeError: + pass + + logger.success("BOT 关闭成功") + + def launch(self) -> None: + """启动""" + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(self.start()) + loop.run_until_complete(self.idle()) + except (SystemExit, KeyboardInterrupt) as exc: + logger.debug("接收到了终止信号,BOT 即将关闭", exc_info=exc) # 接收到了终止信号 + except NetworkError as e: + if isinstance(e, SSLZeroReturnError): + logger.critical("代理服务出现异常, 请检查您的代理服务是否配置成功.") + else: + logger.critical("网络连接出现问题, 请检查您的网络状况.") + except Exception as e: + logger.critical("遇到了未知错误: %s", {type(e)}, exc_info=e) + finally: + loop.run_until_complete(self.stop()) + + if application_config.reload: + raise SystemExit from None + + def on_startup(self, func: Callable[P, R]) -> Callable[P, R]: + """注册一个在 BOT 启动时执行的函数""" + + if func not in self._startup_funcs: + self._startup_funcs.append(func) + + # noinspection PyTypeChecker + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapper + + def on_shutdown(self, func: Callable[P, R]) -> Callable[P, R]: + """注册一个在 BOT 停止时执行的函数""" + + if func not in self._shutdown_funcs: + self._shutdown_funcs.append(func) + + # noinspection PyTypeChecker + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapper diff --git a/gram_core/base_service.py b/gram_core/base_service.py new file mode 100644 index 
0000000..c61a6e8 --- /dev/null +++ b/gram_core/base_service.py @@ -0,0 +1,60 @@ +from abc import ABC +from itertools import chain +from typing import ClassVar, Iterable, Type, TypeVar + +from typing_extensions import Self + +from utils.helpers import isabstract + +__all__ = ("BaseService", "BaseServiceType", "DependenceType", "ComponentType", "get_all_services") + + +class _BaseService: + """服务基类""" + + _is_component: ClassVar[bool] = False + _is_dependence: ClassVar[bool] = False + + def __init_subclass__(cls, load: bool = True, **kwargs): + cls.is_dependence = cls._is_dependence + cls.is_component = cls._is_component + cls.load = load + + async def __aenter__(self) -> Self: + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.shutdown() + + async def initialize(self) -> None: + """Initialize resources used by this service""" + + async def shutdown(self) -> None: + """Stop & clear resources used by this service""" + + +class _Dependence(_BaseService, ABC): + _is_dependence: ClassVar[bool] = True + + +class _Component(_BaseService, ABC): + _is_component: ClassVar[bool] = True + + +class BaseService(_BaseService, ABC): + Dependence: Type[_BaseService] = _Dependence + Component: Type[_BaseService] = _Component + + +BaseServiceType = TypeVar("BaseServiceType", bound=_BaseService) +DependenceType = TypeVar("DependenceType", bound=_Dependence) +ComponentType = TypeVar("ComponentType", bound=_Component) + + +# noinspection PyProtectedMember +def get_all_services() -> Iterable[Type[_BaseService]]: + return filter( + lambda x: x.__name__[0] != "_" and x.load and not isabstract(x), + chain(BaseService.__subclasses__(), _Dependence.__subclasses__(), _Component.__subclasses__()), + ) diff --git a/gram_core/basemodel.py b/gram_core/basemodel.py new file mode 100644 index 0000000..c65f58a --- /dev/null +++ b/gram_core/basemodel.py @@ -0,0 +1,29 @@ +import enum + +try: + import ujson as jsonlib +except ImportError: + import json as jsonlib + +from pydantic import BaseSettings + +__all__ = ("RegionEnum", "Settings") + + +class RegionEnum(int, enum.Enum): + """账号数据所在服务器""" + + NULL = 0 + HYPERION = 1 # 米忽悠国服 hyperion + HOYOLAB = 2 # 米忽悠国际服 hoyolab + + +class Settings(BaseSettings): + def __new__(cls, *args, **kwargs): + cls.update_forward_refs() + return super(Settings, cls).__new__(cls) # pylint: disable=E1120 + + class Config(BaseSettings.Config): + case_sensitive = False + json_loads = jsonlib.loads + json_dumps = jsonlib.dumps diff --git a/gram_core/builtins/__init__.py b/gram_core/builtins/__init__.py new file mode 100644 index 0000000..4f29666 --- /dev/null +++ b/gram_core/builtins/__init__.py @@ -0,0 +1 @@ +"""bot builtins""" diff --git a/gram_core/builtins/contexts.py b/gram_core/builtins/contexts.py new file mode 100644 index 0000000..832c978 --- /dev/null +++ b/gram_core/builtins/contexts.py @@ -0,0 +1,38 @@ +"""上下文管理""" +from contextlib import contextmanager +from contextvars import ContextVar +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from telegram.ext import CallbackContext + from telegram import Update + +__all__ = [ + "CallbackContextCV", + "UpdateCV", + "handler_contexts", + "job_contexts", +] + +CallbackContextCV: ContextVar["CallbackContext"] = ContextVar("TelegramContextCallback") +UpdateCV: ContextVar["Update"] = ContextVar("TelegramUpdate") + + +@contextmanager +def handler_contexts(update: "Update", context: "CallbackContext") -> None: + context_token = CallbackContextCV.set(context) + update_token = 
UpdateCV.set(update) + try: + yield + finally: + CallbackContextCV.reset(context_token) + UpdateCV.reset(update_token) + + +@contextmanager +def job_contexts(context: "CallbackContext") -> None: + token = CallbackContextCV.set(context) + try: + yield + finally: + CallbackContextCV.reset(token) diff --git a/gram_core/builtins/dispatcher.py b/gram_core/builtins/dispatcher.py new file mode 100644 index 0000000..236af24 --- /dev/null +++ b/gram_core/builtins/dispatcher.py @@ -0,0 +1,309 @@ +"""参数分发器""" +import asyncio +import inspect +from abc import ABC, abstractmethod +from asyncio import AbstractEventLoop +from functools import cached_property, lru_cache, partial, wraps +from inspect import Parameter, Signature +from itertools import chain +from types import GenericAlias, MethodType +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Type, + Union, +) + +from arkowrapper import ArkoWrapper +from fastapi import FastAPI +from telegram import Bot as TelegramBot, Chat, Message, Update, User +from telegram.ext import Application as TelegramApplication, CallbackContext, Job +from typing_extensions import ParamSpec +from uvicorn import Server + +from gram_core.application import Application +from utils.const import WRAPPER_ASSIGNMENTS +from utils.typedefs import R, T + +__all__ = ( + "catch", + "AbstractDispatcher", + "BaseDispatcher", + "HandlerDispatcher", + "JobDispatcher", + "dispatched", +) + +P = ParamSpec("P") + +TargetType = Union[Type, str, Callable[[Any], bool]] + +_CATCH_TARGET_ATTR = "_catch_targets" + + +def catch(*targets: Union[str, Type]) -> Callable[[Callable[P, R]], Callable[P, R]]: + def decorate(func: Callable[P, R]) -> Callable[P, R]: + setattr(func, _CATCH_TARGET_ATTR, targets) + + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapper + + return decorate + + +@lru_cache(64) +def get_signature(func: Union[type, Callable]) -> Signature: + if isinstance(func, type): + return inspect.signature(func.__init__) + return inspect.signature(func) + + +class AbstractDispatcher(ABC): + """参数分发器""" + + IGNORED_ATTRS = [] + + _args: List[Any] = [] + _kwargs: Dict[Union[str, Type], Any] = {} + _application: "Optional[Application]" = None + + def set_application(self, application: "Application") -> None: + self._application = application + + @property + def application(self) -> "Application": + if self._application is None: + raise RuntimeError(f"No application was set for this {self.__class__.__name__}.") + return self._application + + def __init__(self, *args, **kwargs) -> None: + self._args = list(args) + self._kwargs = dict(kwargs) + + for _, value in kwargs.items(): + type_arg = type(value) + if type_arg != str: + self._kwargs[type_arg] = value + + for arg in args: + type_arg = type(arg) + if type_arg != str: + self._kwargs[type_arg] = arg + + @cached_property + def catch_funcs(self) -> List[MethodType]: + # noinspection PyTypeChecker + return list( + ArkoWrapper(dir(self)) + .filter(lambda x: not x.startswith("_")) + .filter( + lambda x: x not in self.IGNORED_ATTRS + ["dispatch", "catch_funcs", "catch_func_map", "dispatch_funcs"] + ) + .map(lambda x: getattr(self, x)) + .filter(lambda x: isinstance(x, MethodType)) + .filter(lambda x: hasattr(x, "_catch_targets")) + ) + + @cached_property + def catch_func_map(self) -> Dict[Union[str, Type[T]], Callable[..., T]]: + result = {} + for catch_func in self.catch_funcs: + catch_targets = getattr(catch_func, 
_CATCH_TARGET_ATTR) + for catch_target in catch_targets: + result[catch_target] = catch_func + return result + + @cached_property + def dispatch_funcs(self) -> List[MethodType]: + return list( + ArkoWrapper(dir(self)) + .filter(lambda x: x.startswith("dispatch_by_")) + .map(lambda x: getattr(self, x)) + .filter(lambda x: isinstance(x, MethodType)) + ) + + @abstractmethod + def dispatch_by_default(self, parameter: Parameter) -> Parameter: + """默认的 dispatch 方法""" + + @abstractmethod + def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter: + """使用 catch_func 获取并分配参数""" + + def dispatch(self, func: Callable[P, R]) -> Callable[..., R]: + """将参数分配给函数,从而合成一个无需参数即可执行的函数""" + params = {} + signature = get_signature(func) + parameters: Dict[str, Parameter] = dict(signature.parameters) + + for name, parameter in list(parameters.items()): + parameter: Parameter + if any( + [ + name == "self" and isinstance(func, (type, MethodType)), + parameter.kind in [Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL], + ] + ): + del parameters[name] + continue + + for dispatch_func in self.dispatch_funcs: + parameters[name] = dispatch_func(parameter) + + for name, parameter in parameters.items(): + if parameter.default != Parameter.empty: + params[name] = parameter.default + else: + params[name] = None + + return partial(func, **params) + + @catch(Application) + def catch_application(self) -> Application: + return self.application + + +class BaseDispatcher(AbstractDispatcher): + """默认参数分发器""" + + _instances: Sequence[Any] + + def _get_kwargs(self) -> Dict[Type[T], T]: + result = self._get_default_kwargs() + result[AbstractDispatcher] = self + result.update(self._kwargs) + return result + + def _get_default_kwargs(self) -> Dict[Type[T], T]: + application = self.application + _default_kwargs = { + FastAPI: application.web_app, + Server: application.web_server, + TelegramApplication: application.telegram, + TelegramBot: application.telegram.bot, + } + if not application.running: + for obj in chain( + application.managers.dependency, + application.managers.components, + application.managers.services, + application.managers.plugins, + ): + _default_kwargs[type(obj)] = obj + return {k: v for k, v in _default_kwargs.items() if v is not None} + + def dispatch_by_default(self, parameter: Parameter) -> Parameter: + annotation = parameter.annotation + # noinspection PyTypeChecker + if isinstance(annotation, type) and (value := self._get_kwargs().get(annotation, None)) is not None: + parameter._default = value # pylint: disable=W0212 + return parameter + + def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter: + annotation = parameter.annotation + if annotation != Any and isinstance(annotation, GenericAlias): + return parameter + + catch_func = self.catch_func_map.get(annotation) or self.catch_func_map.get(parameter.name) + if catch_func is not None: + # noinspection PyUnresolvedReferences,PyProtectedMember + parameter._default = catch_func() # pylint: disable=W0212 + return parameter + + @catch(AbstractEventLoop) + def catch_loop(self) -> AbstractEventLoop: + return asyncio.get_event_loop() + + +class HandlerDispatcher(BaseDispatcher): + """Handler 参数分发器""" + + def __init__(self, update: Optional[Update] = None, context: Optional[CallbackContext] = None, **kwargs) -> None: + super().__init__(update=update, context=context, **kwargs) + self._update = update + self._context = context + + def dispatch( + self, func: Callable[P, R], *, update: Optional[Update] = None, context: Optional[CallbackContext] 
= None + ) -> Callable[..., R]: + self._update = update or self._update + self._context = context or self._context + if self._update is None: + from gram_core.builtins.contexts import UpdateCV + + self._update = UpdateCV.get() + if self._context is None: + from gram_core.builtins.contexts import CallbackContextCV + + self._context = CallbackContextCV.get() + return super().dispatch(func) + + def dispatch_by_default(self, parameter: Parameter) -> Parameter: + """HandlerDispatcher 默认不使用 dispatch_by_default""" + return parameter + + @catch(Update) + def catch_update(self) -> Update: + return self._update + + @catch(CallbackContext) + def catch_context(self) -> CallbackContext: + return self._context + + @catch(Message) + def catch_message(self) -> Message: + return self._update.effective_message + + @catch(User) + def catch_user(self) -> User: + return self._update.effective_user + + @catch(Chat) + def catch_chat(self) -> Chat: + return self._update.effective_chat + + +class JobDispatcher(BaseDispatcher): + """Job 参数分发器""" + + def __init__(self, context: Optional[CallbackContext] = None, **kwargs) -> None: + super().__init__(context=context, **kwargs) + self._context = context + + def dispatch(self, func: Callable[P, R], *, context: Optional[CallbackContext] = None) -> Callable[..., R]: + self._context = context or self._context + if self._context is None: + from gram_core.builtins.contexts import CallbackContextCV + + self._context = CallbackContextCV.get() + return super().dispatch(func) + + @catch("data") + def catch_data(self) -> Any: + return self._context.job.data + + @catch(Job) + def catch_job(self) -> Job: + return self._context.job + + @catch(CallbackContext) + def catch_context(self) -> CallbackContext: + return self._context + + +def dispatched(dispatcher: Type[AbstractDispatcher] = BaseDispatcher): + def decorate(func: Callable[P, R]) -> Callable[P, R]: + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return dispatcher().dispatch(func)(*args, **kwargs) + + return wrapper + + return decorate diff --git a/gram_core/builtins/executor.py b/gram_core/builtins/executor.py new file mode 100644 index 0000000..ad16372 --- /dev/null +++ b/gram_core/builtins/executor.py @@ -0,0 +1,131 @@ +"""执行器""" +import inspect +from functools import cached_property +from multiprocessing import RLock as Lock +from typing import Callable, ClassVar, Dict, Generic, Optional, TYPE_CHECKING, Type, TypeVar + +from telegram import Update +from telegram.ext import CallbackContext +from typing_extensions import ParamSpec, Self + +from gram_core.builtins.contexts import handler_contexts, job_contexts + +if TYPE_CHECKING: + from gram_core.application import Application + from gram_core.builtins.dispatcher import AbstractDispatcher, HandlerDispatcher + from multiprocessing.synchronize import RLock as LockType + +__all__ = ("BaseExecutor", "Executor", "HandlerExecutor", "JobExecutor") + +T = TypeVar("T") +R = TypeVar("R") +P = ParamSpec("P") + + +class BaseExecutor: + """执行器 + Args: + name(str): 该执行器的名称。执行器的名称是唯一的。 + + 只支持执行只拥有 POSITIONAL_OR_KEYWORD 和 KEYWORD_ONLY 两种参数类型的函数 + """ + + _lock: ClassVar["LockType"] = Lock() + _instances: ClassVar[Dict[str, Self]] = {} + _application: "Optional[Application]" = None + + def set_application(self, application: "Application") -> None: + self._application = application + + @property + def application(self) -> "Application": + if self._application is None: + raise RuntimeError(f"No application was set for this 
{self.__class__.__name__}.") + return self._application + + def __new__(cls: Type[T], name: str, *args, **kwargs) -> T: + with cls._lock: + if (instance := cls._instances.get(name)) is None: + instance = object.__new__(cls) + instance.__init__(name, *args, **kwargs) + cls._instances.update({name: instance}) + return instance + + @cached_property + def name(self) -> str: + """当前执行器的名称""" + return self._name + + def __init__(self, name: str, dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None: + self._name = name + self._dispatcher = dispatcher + + +class Executor(BaseExecutor, Generic[P, R]): + async def __call__( + self, + target: Callable[P, R], + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + **kwargs, + ) -> R: + dispatcher = self._dispatcher or dispatcher + dispatcher_instance = dispatcher(**kwargs) + dispatcher_instance.set_application(application=self.application) + dispatched_func = dispatcher_instance.dispatch(target) # 分发参数,组成新函数 + + # 执行 + if inspect.iscoroutinefunction(target): + result = await dispatched_func() + else: + result = dispatched_func() + + return result + + +class HandlerExecutor(BaseExecutor, Generic[P, R]): + """Handler专用执行器""" + + _callback: Callable[P, R] + _dispatcher: "HandlerDispatcher" + + def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["HandlerDispatcher"]] = None) -> None: + if dispatcher is None: + from gram_core.builtins.dispatcher import HandlerDispatcher + + dispatcher = HandlerDispatcher + super().__init__("handler", dispatcher) + self._callback = func + self._dispatcher = dispatcher() + + def set_application(self, application: "Application") -> None: + self._application = application + if self._dispatcher is not None: + self._dispatcher.set_application(application) + + async def __call__(self, update: Update, context: CallbackContext) -> R: + with handler_contexts(update, context): + dispatched_func = self._dispatcher.dispatch(self._callback, update=update, context=context) + return await dispatched_func() + + +class JobExecutor(BaseExecutor): + """Job 专用执行器""" + + def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None: + if dispatcher is None: + from gram_core.builtins.dispatcher import JobDispatcher + + dispatcher = JobDispatcher + super().__init__("job", dispatcher) + self._callback = func + self._dispatcher = dispatcher() + + def set_application(self, application: "Application") -> None: + self._application = application + if self._dispatcher is not None: + self._dispatcher.set_application(application) + + async def __call__(self, context: CallbackContext) -> R: + with job_contexts(context): + dispatched_func = self._dispatcher.dispatch(self._callback, context=context) + return await dispatched_func() diff --git a/gram_core/builtins/reloader.py b/gram_core/builtins/reloader.py new file mode 100644 index 0000000..6b09f07 --- /dev/null +++ b/gram_core/builtins/reloader.py @@ -0,0 +1,185 @@ +import inspect +import multiprocessing +import os +import signal +import threading +from pathlib import Path +from typing import Callable, Iterator, List, Optional, TYPE_CHECKING + +from watchfiles import watch + +from utils.const import HANDLED_SIGNALS, PROJECT_ROOT +from utils.log import logger +from utils.typedefs import StrOrPath + +if TYPE_CHECKING: + from multiprocessing.process import BaseProcess + +__all__ = ("Reloader",) + +multiprocessing.allow_connection_pickling() +spawn = multiprocessing.get_context("spawn") + + +class FileFilter: + """监控文件过滤""" + + def 
__init__(self, includes: List[str], excludes: List[str]) -> None: + default_includes = ["*.py"] + self.includes = [default for default in default_includes if default not in excludes] + self.includes.extend(includes) + self.includes = list(set(self.includes)) + + default_excludes = [".*", ".py[cod]", ".sw.*", "~*", __file__] + self.excludes = [default for default in default_excludes if default not in includes] + self.exclude_dirs = [] + for e in excludes: + p = Path(e) + try: + is_dir = p.is_dir() + except OSError: + is_dir = False + + if is_dir: + self.exclude_dirs.append(p) + else: + self.excludes.append(e) + self.excludes = list(set(self.excludes)) + + def __call__(self, path: Path) -> bool: + for include_pattern in self.includes: + if path.match(include_pattern): + for exclude_dir in self.exclude_dirs: + if exclude_dir in path.parents: + return False + + for exclude_pattern in self.excludes: + if path.match(exclude_pattern): + return False + + return True + return False + + +class Reloader: + _target: Callable[..., None] + _process: "BaseProcess" + + @property + def process(self) -> "BaseProcess": + return self._process + + @property + def target(self) -> Callable[..., None]: + return self._target + + def __init__( + self, + target: Callable[..., None], + *, + reload_delay: float = 0.25, + reload_dirs: List[StrOrPath] = None, + reload_includes: List[str] = None, + reload_excludes: List[str] = None, + ): + if inspect.iscoroutinefunction(target): + raise ValueError("不支持异步函数") + self._target = target + + self.reload_delay = reload_delay + + _reload_dirs = [] + for reload_dir in reload_dirs or []: + _reload_dirs.append(PROJECT_ROOT.joinpath(Path(reload_dir))) + + self.reload_dirs = [] + for reload_dir in _reload_dirs: + append = True + for parent in reload_dir.parents: + if parent in _reload_dirs: + append = False + break + if append: + self.reload_dirs.append(reload_dir) + + if not self.reload_dirs: + logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"}) + + self._should_exit = threading.Event() + + frame = inspect.currentframe().f_back + + self.watch_filter = FileFilter(reload_includes or [], (reload_excludes or []) + [frame.f_globals["__file__"]]) + self.watcher = watch( + *self.reload_dirs, + watch_filter=None, + stop_event=self._should_exit, + yield_on_timeout=True, + ) + + def get_changes(self) -> Optional[List[Path]]: + if not self._process.is_alive(): + logger.info("目标进程已经关闭", extra={"tag": "Reloader"}) + self._should_exit.set() + try: + changes = next(self.watcher) + except StopIteration: + return None + if changes: + unique_paths = {Path(c[1]) for c in changes} + return [p for p in unique_paths if self.watch_filter(p)] + return None + + def __iter__(self) -> Iterator[Optional[List[Path]]]: + return self + + def __next__(self) -> Optional[List[Path]]: + return self.get_changes() + + def run(self) -> None: + self.startup() + for changes in self: + if changes: + logger.warning( + "检测到文件 %s 发生改变, 正在重载...", + [str(c.relative_to(PROJECT_ROOT)).replace(os.sep, "/") for c in changes], + extra={"tag": "Reloader"}, + ) + self.restart() + + self.shutdown() + + def signal_handler(self, *_) -> None: + """当接收到结束信号量时""" + self._process.join(3) + if self._process.is_alive(): + self._process.terminate() + self._process.join() + self._should_exit.set() + + def startup(self) -> None: + """启动进程""" + logger.info("目标进程正在启动", extra={"tag": "Reloader"}) + + for sig in HANDLED_SIGNALS: + signal.signal(sig, self.signal_handler) + + self._process = spawn.Process(target=self._target) + 
self._process.start() + logger.success("目标进程启动成功", extra={"tag": "Reloader"}) + + def restart(self) -> None: + """重启进程""" + self._process.terminate() + self._process.join(10) + + self._process = spawn.Process(target=self._target) + self._process.start() + logger.info("目标进程已经重载", extra={"tag": "Reloader"}) + + def shutdown(self) -> None: + """关闭进程""" + self._process.terminate() + self._process.join(10) + + logger.info("重载器已经关闭", extra={"tag": "Reloader"}) diff --git a/gram_core/config.py b/gram_core/config.py new file mode 100644 index 0000000..9f48794 --- /dev/null +++ b/gram_core/config.py @@ -0,0 +1,161 @@ +from enum import Enum +from pathlib import Path +from typing import List, Optional, Union + +import dotenv +from pydantic import AnyUrl, Field + +from gram_core.basemodel import Settings +from utils.const import PROJECT_ROOT +from utils.typedefs import NaturalNumber + +__all__ = ("ApplicationConfig", "config", "JoinGroups") + +dotenv.load_dotenv() + + +class JoinGroups(str, Enum): + NO_ALLOW = "NO_ALLOW" + ALLOW_AUTH_USER = "ALLOW_AUTH_USER" + ALLOW_USER = "ALLOW_USER" + ALLOW_ALL = "ALLOW_ALL" + + +class DatabaseConfig(Settings): + driver_name: str = "mysql+asyncmy" + host: Optional[str] = None + port: Optional[int] = None + username: Optional[str] = None + password: Optional[str] = None + database: Optional[str] = None + + class Config(Settings.Config): + env_prefix = "db_" + + +class RedisConfig(Settings): + host: str = "127.0.0.1" + port: int = 6379 + database: int = Field(default=0, env="redis_db") + password: Optional[str] = None + + class Config(Settings.Config): + env_prefix = "redis_" + + +class LoggerConfig(Settings): + name: str = "PaiGram" + width: Optional[int] = None + time_format: str = "[%Y-%m-%d %X]" + traceback_max_frames: int = 20 + path: Path = PROJECT_ROOT / "logs" + render_keywords: List[str] = ["BOT"] + locals_max_length: int = 10 + locals_max_string: int = 80 + locals_max_depth: Optional[NaturalNumber] = None + filtered_names: List[str] = ["uvicorn"] + + class Config(Settings.Config): + env_prefix = "logger_" + + +class MTProtoConfig(Settings): + api_id: Optional[int] = None + api_hash: Optional[str] = None + + +class WebServerConfig(Settings): + enable: bool = False + """是否启用WebServer""" + + url: AnyUrl = "http://localhost:8080" + host: str = "localhost" + port: int = 8080 + + class Config(Settings.Config): + env_prefix = "web_" + + +class ErrorConfig(Settings): + pb_url: str = "" + pb_sunset: int = 43200 + pb_max_lines: int = 1000 + sentry_dsn: str = "" + notification_chat_id: Optional[str] = None + + class Config(Settings.Config): + env_prefix = "error_" + + +class ReloadConfig(Settings): + delay: float = 0.25 + dirs: List[str] = [] + include: List[str] = [] + exclude: List[str] = [] + + class Config(Settings.Config): + env_prefix = "reload_" + + +class NoticeConfig(Settings): + user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!" 
+ + class Config(Settings.Config): + env_prefix = "notice_" + + +class ApplicationConfig(Settings): + debug: bool = False + """debug 开关""" + retry: int = 5 + """重试次数""" + auto_reload: bool = False + """自动重载""" + + proxy_url: Optional[AnyUrl] = None + """代理链接""" + upload_bbs_host: Optional[AnyUrl] = "https://upload-bbs.miyoushe.com" + + bot_token: str = "" + """BOT的token""" + + owner: Optional[int] = None + + channels: List[int] = [] + """文章推送群组""" + + verify_groups: List[Union[int, str]] = [] + """启用群验证功能的群组""" + join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW + """是否允许机器人被邀请到其它群组""" + + timeout: int = 10 + connection_pool_size: int = 256 + read_timeout: Optional[float] = None + write_timeout: Optional[float] = None + connect_timeout: Optional[float] = None + pool_timeout: Optional[float] = None + update_read_timeout: Optional[float] = None + update_write_timeout: Optional[float] = None + update_connect_timeout: Optional[float] = None + update_pool_timeout: Optional[float] = None + + genshin_ttl: Optional[int] = None + + enka_network_api_agent: str = "" + pass_challenge_api: str = "" + pass_challenge_app_key: str = "" + pass_challenge_user_web: str = "" + + reload: ReloadConfig = ReloadConfig() + database: DatabaseConfig = DatabaseConfig() + logger: LoggerConfig = LoggerConfig() + webserver: WebServerConfig = WebServerConfig() + redis: RedisConfig = RedisConfig() + mtproto: MTProtoConfig = MTProtoConfig() + error: ErrorConfig = ErrorConfig() + notice: NoticeConfig = NoticeConfig() + + +ApplicationConfig.update_forward_refs() +config = ApplicationConfig() diff --git a/gram_core/dependence/__init__.py b/gram_core/dependence/__init__.py new file mode 100644 index 0000000..4ac55c2 --- /dev/null +++ b/gram_core/dependence/__init__.py @@ -0,0 +1 @@ +"""基础服务""" diff --git a/gram_core/dependence/aiobrowser.py b/gram_core/dependence/aiobrowser.py new file mode 100644 index 0000000..a0ad299 --- /dev/null +++ b/gram_core/dependence/aiobrowser.py @@ -0,0 +1,56 @@ +from typing import Optional, TYPE_CHECKING + +from playwright.async_api import Error, async_playwright + +from gram_core.base_service import BaseService +from utils.log import logger + +if TYPE_CHECKING: + from playwright.async_api import Playwright as AsyncPlaywright, Browser + +__all__ = ("AioBrowser",) + + +class AioBrowser(BaseService.Dependence): + @property + def browser(self): + return self._browser + + def __init__(self, loop=None): + self._browser: Optional["Browser"] = None + self._playwright: Optional["AsyncPlaywright"] = None + self._loop = loop + + async def get_browser(self): + if self._browser is None: + await self.initialize() + return self._browser + + async def initialize(self): + if self._playwright is None: + logger.info("正在尝试启动 [blue]Playwright[/]", extra={"markup": True}) + self._playwright = await async_playwright().start() + logger.success("[blue]Playwright[/] 启动成功", extra={"markup": True}) + if self._browser is None: + logger.info("正在尝试启动 [blue]Browser[/]", extra={"markup": True}) + try: + self._browser = await self._playwright.chromium.launch(timeout=5000) + logger.success("[blue]Browser[/] 启动成功", extra={"markup": True}) + except Error as err: + if "playwright install" in str(err): + logger.error( + "检查到 [blue]playwright[/] 刚刚安装或者未升级\n" + "请运行以下命令下载新浏览器\n" + "[blue bold]playwright install chromium[/]", + extra={"markup": True}, + ) + raise RuntimeError("检查到 playwright 刚刚安装或者未升级\n请运行以下命令下载新浏览器\nplaywright install chromium") + raise err + + return self._browser + + async def shutdown(self): + if self._browser 
is not None: + await self._browser.close() + if self._playwright is not None: + self._playwright.stop() diff --git a/gram_core/dependence/aiobrowser.pyi b/gram_core/dependence/aiobrowser.pyi new file mode 100644 index 0000000..51694e5 --- /dev/null +++ b/gram_core/dependence/aiobrowser.pyi @@ -0,0 +1,16 @@ +from asyncio import AbstractEventLoop + +from playwright.async_api import Browser, Playwright as AsyncPlaywright + +from gram_core.base_service import BaseService + +__all__ = ("AioBrowser",) + +class AioBrowser(BaseService.Dependence): + _browser: Browser | None + _playwright: AsyncPlaywright | None + _loop: AbstractEventLoop + + @property + def browser(self) -> Browser | None: ... + async def get_browser(self) -> Browser: ... diff --git a/gram_core/dependence/database.py b/gram_core/dependence/database.py new file mode 100644 index 0000000..752e5e1 --- /dev/null +++ b/gram_core/dependence/database.py @@ -0,0 +1,51 @@ +import contextlib +from typing import Optional + +from sqlalchemy.engine import URL +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import sessionmaker +from typing_extensions import Self + +from gram_core.base_service import BaseService +from gram_core.config import ApplicationConfig +from gram_core.sqlmodel.session import AsyncSession + +__all__ = ("Database",) + + +class Database(BaseService.Dependence): + @classmethod + def from_config(cls, config: ApplicationConfig) -> Self: + return cls(**config.database.dict()) + + def __init__( + self, + driver_name: str, + host: Optional[str] = None, + port: Optional[int] = None, + username: Optional[str] = None, + password: Optional[str] = None, + database: Optional[str] = None, + ): + self.database = database # skipcq: PTC-W0052 + self.password = password + self.username = username + self.port = port + self.host = host + self.url = URL.create( + driver_name, + username=self.username, + password=self.password, + host=self.host, + port=self.port, + database=self.database, + ) + self.engine = create_async_engine(self.url) + self.Session = sessionmaker(bind=self.engine, class_=AsyncSession) + + @contextlib.asynccontextmanager + async def session(self) -> AsyncSession: + yield self.Session() + + async def shutdown(self): + self.Session.close_all() diff --git a/gram_core/dependence/mtproto.py b/gram_core/dependence/mtproto.py new file mode 100644 index 0000000..eadbb29 --- /dev/null +++ b/gram_core/dependence/mtproto.py @@ -0,0 +1,67 @@ +import os +from typing import Optional +from urllib.parse import urlparse + +import aiofiles + +from gram_core.base_service import BaseService +from gram_core.config import config as bot_config +from utils.log import logger + +try: + from pyrogram import Client + from pyrogram.session import session + + session.log.debug = lambda *args, **kwargs: None # 关闭日记 + PYROGRAM_AVAILABLE = True +except ImportError: + Client = None + session = None + PYROGRAM_AVAILABLE = False + + +class MTProto(BaseService.Dependence): + async def get_session(self): + async with aiofiles.open(self.session_path, mode="r") as f: + return await f.read() + + async def set_session(self, b: str): + async with aiofiles.open(self.session_path, mode="w+") as f: + await f.write(b) + + def session_exists(self): + return os.path.exists(self.session_path) + + def __init__(self): + self.name = "paigram" + current_dir = os.getcwd() + self.session_path = os.path.join(current_dir, "paigram.session") + self.client: Optional[Client] = None + self.proxy: Optional[dict] = None + http_proxy = 
os.environ.get("HTTP_PROXY") + if http_proxy is not None: + http_proxy_url = urlparse(http_proxy) + self.proxy = {"scheme": "http", "hostname": http_proxy_url.hostname, "port": http_proxy_url.port} + + async def initialize(self): # pylint: disable=W0221 + if not PYROGRAM_AVAILABLE: + logger.info("MTProto 服务需要的 pyrogram 模块未导入 本次服务 client 为 None") + return + if bot_config.mtproto.api_id is None: + logger.info("MTProto 服务需要的 api_id 未配置 本次服务 client 为 None") + return + if bot_config.mtproto.api_hash is None: + logger.info("MTProto 服务需要的 api_hash 未配置 本次服务 client 为 None") + return + self.client = Client( + api_id=bot_config.mtproto.api_id, + api_hash=bot_config.mtproto.api_hash, + name=self.name, + bot_token=bot_config.bot_token, + proxy=self.proxy, + ) + await self.client.start() + + async def shutdown(self): # pylint: disable=W0221 + if self.client is not None: + await self.client.stop(block=False) diff --git a/gram_core/dependence/mtproto.pyi b/gram_core/dependence/mtproto.pyi new file mode 100644 index 0000000..c780bb7 --- /dev/null +++ b/gram_core/dependence/mtproto.pyi @@ -0,0 +1,31 @@ +from __future__ import annotations +from typing import TypedDict + +from gram_core.base_service import BaseService + +try: + from pyrogram import Client + from pyrogram.session import session + + PYROGRAM_AVAILABLE = True +except ImportError: + Client = None + session = None + PYROGRAM_AVAILABLE = False + +__all__ = ("MTProto",) + +class _ProxyType(TypedDict): + scheme: str + hostname: str | None + port: int | None + +class MTProto(BaseService.Dependence): + name: str + session_path: str + client: Client | None + proxy: _ProxyType | None + + async def get_session(self) -> str: ... + async def set_session(self, b: str) -> None: ... + def session_exists(self) -> bool: ... 
diff --git a/gram_core/dependence/redisdb.py b/gram_core/dependence/redisdb.py new file mode 100644 index 0000000..6cada50 --- /dev/null +++ b/gram_core/dependence/redisdb.py @@ -0,0 +1,50 @@ +from typing import Optional, Union + +import fakeredis.aioredis +from redis import asyncio as aioredis +from redis.exceptions import ConnectionError as RedisConnectionError, TimeoutError as RedisTimeoutError +from typing_extensions import Self + +from gram_core.base_service import BaseService +from gram_core.config import ApplicationConfig +from utils.log import logger + +__all__ = ["RedisDB"] + + +class RedisDB(BaseService.Dependence): + @classmethod + def from_config(cls, config: ApplicationConfig) -> Self: + return cls(**config.redis.dict()) + + def __init__( + self, host: str = "127.0.0.1", port: int = 6379, database: Union[str, int] = 0, password: Optional[str] = None + ): + self.client = aioredis.Redis(host=host, port=port, db=database, password=password) + self.ttl = 600 + + async def ping(self): + # noinspection PyUnresolvedReferences + if await self.client.ping(): + logger.info("连接 [red]Redis[/] 成功", extra={"markup": True}) + else: + logger.info("连接 [red]Redis[/] 失败", extra={"markup": True}) + raise RuntimeError("连接 Redis 失败") + + async def start_fake_redis(self): + self.client = fakeredis.aioredis.FakeRedis() + await self.ping() + + async def initialize(self): + logger.info("正在尝试建立与 [red]Redis[/] 连接", extra={"markup": True}) + try: + await self.ping() + except (RedisTimeoutError, RedisConnectionError) as exc: + if isinstance(exc, RedisTimeoutError): + logger.warning("连接 [red]Redis[/] 超时,使用 [red]fakeredis[/] 模拟", extra={"markup": True}) + if isinstance(exc, RedisConnectionError): + logger.warning("连接 [red]Redis[/] 失败,使用 [red]fakeredis[/] 模拟", extra={"markup": True}) + await self.start_fake_redis() + + async def shutdown(self): + await self.client.close() diff --git a/gram_core/error.py b/gram_core/error.py new file mode 100644 index 0000000..344b684 --- /dev/null +++ b/gram_core/error.py @@ -0,0 +1,7 @@ +"""此模块包含核心模块的错误的基类""" +from typing import Union + + +class ServiceNotFoundError(Exception): + def __init__(self, name: Union[str, type]): + super().__init__(f"No service named '{name if isinstance(name, str) else name.__name__}'") diff --git a/gram_core/handler/__init__.py b/gram_core/handler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gram_core/handler/adminhandler.py b/gram_core/handler/adminhandler.py new file mode 100644 index 0000000..e8c7096 --- /dev/null +++ b/gram_core/handler/adminhandler.py @@ -0,0 +1,59 @@ +import asyncio +from typing import TypeVar, TYPE_CHECKING, Any, Optional + +from telegram import Update +from telegram.ext import ApplicationHandlerStop, BaseHandler + +from gram_core.error import ServiceNotFoundError +from gram_core.services.users.services import UserAdminService +from utils.log import logger + +if TYPE_CHECKING: + from gram_core.application import Application + from telegram.ext import Application as TelegramApplication + +RT = TypeVar("RT") +UT = TypeVar("UT") + +CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]") + + +class AdminHandler(BaseHandler[Update, CCT]): + _lock = asyncio.Lock() + + def __init__(self, handler: BaseHandler[Update, CCT], application: "Application") -> None: + self.handler = handler + self.application = application + self.user_service: Optional["UserAdminService"] = None + super().__init__(self.handler.callback, self.handler.block) + + def check_update(self, update: object) -> bool: + if not 
isinstance(update, Update): + return False + return self.handler.check_update(update) + + async def _user_service(self) -> "UserAdminService": + async with self._lock: + if self.user_service is not None: + return self.user_service + user_service: UserAdminService = self.application.managers.services_map.get(UserAdminService, None) + if user_service is None: + raise ServiceNotFoundError("UserAdminService") + self.user_service = user_service + return self.user_service + + async def handle_update( + self, + update: "UT", + application: "TelegramApplication[Any, CCT, Any, Any, Any, Any]", + check_result: Any, + context: "CCT", + ) -> RT: + user_service = await self._user_service() + user = update.effective_user + if await user_service.is_admin(user.id): + return await self.handler.handle_update(update, application, check_result, context) + message = update.effective_message + logger.warning("用户 %s[%s] 触发尝试调用Admin命令但权限不足", user.full_name, user.id) + await message.reply_text("权限不足") + raise ApplicationHandlerStop diff --git a/gram_core/handler/callbackqueryhandler.py b/gram_core/handler/callbackqueryhandler.py new file mode 100644 index 0000000..f931e4e --- /dev/null +++ b/gram_core/handler/callbackqueryhandler.py @@ -0,0 +1,62 @@ +import asyncio +from contextlib import AbstractAsyncContextManager +from types import TracebackType +from typing import TypeVar, TYPE_CHECKING, Any, Optional, Type + +from telegram.ext import CallbackQueryHandler as BaseCallbackQueryHandler, ApplicationHandlerStop + +from utils.log import logger + +if TYPE_CHECKING: + from telegram.ext import Application + +RT = TypeVar("RT") +UT = TypeVar("UT") +CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]") + + +class OverlappingException(Exception): + pass + + +class OverlappingContext(AbstractAsyncContextManager): + _lock = asyncio.Lock() + + def __init__(self, context: "CCT"): + self.context = context + + async def __aenter__(self) -> None: + async with self._lock: + flag = self.context.user_data.get("overlapping", False) + if flag: + raise OverlappingException + self.context.user_data["overlapping"] = True + return None + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + async with self._lock: + del self.context.user_data["overlapping"] + return None + + +class CallbackQueryHandler(BaseCallbackQueryHandler): + async def handle_update( + self, + update: "UT", + application: "Application[Any, CCT, Any, Any, Any, Any]", + check_result: Any, + context: "CCT", + ) -> RT: + self.collect_additional_context(context, update, application, check_result) + try: + async with OverlappingContext(context): + return await self.callback(update, context) + except OverlappingException as exc: + user = update.effective_user + logger.warning("用户 %s[%s] 触发 overlapping 该次命令已忽略", user.full_name, user.id) + raise ApplicationHandlerStop from exc diff --git a/gram_core/handler/limiterhandler.py b/gram_core/handler/limiterhandler.py new file mode 100644 index 0000000..53bc4c0 --- /dev/null +++ b/gram_core/handler/limiterhandler.py @@ -0,0 +1,71 @@ +import asyncio +from typing import TypeVar, Optional + +from telegram import Update +from telegram.ext import ContextTypes, ApplicationHandlerStop, TypeHandler + +from utils.log import logger + +UT = TypeVar("UT") +CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]") + + +class LimiterHandler(TypeHandler[UT, CCT]): + _lock = asyncio.Lock() + + def __init__( + self, max_rate: float 
= 5, time_period: float = 10, amount: float = 1, limit_time: Optional[float] = None + ): + """Limiter Handler 通过 + `Leaky bucket algorithm `_ + 实现对用户的输入的精确控制 + + 输入超过一定速率后,代码会抛出 + :class:`telegram.ext.ApplicationHandlerStop` + 异常并在一段时间内防止用户执行任何其他操作 + + :param max_rate: 在抛出异常之前最多允许 频率/秒 的速度 + :param time_period: 在限制速率的时间段的持续时间 + :param amount: 提供的容量 + :param limit_time: 限制时间 如果不提供限制时间为 max_rate / time_period * amount + """ + self.max_rate = max_rate + self.amount = amount + self._rate_per_sec = max_rate / time_period + self.limit_time = limit_time + super().__init__(Update, self.limiter_callback) + + async def limiter_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + if update.inline_query is not None: + return + loop = asyncio.get_running_loop() + async with self._lock: + time = loop.time() + user_data = context.user_data + if user_data is None: + return + user_limit_time = user_data.get("limit_time") + if user_limit_time is not None: + if time >= user_limit_time: + del user_data["limit_time"] + else: + raise ApplicationHandlerStop + last_task_time = user_data.get("last_task_time", 0) + if last_task_time: + task_level = user_data.get("task_level", 0) + elapsed = time - last_task_time + decrement = elapsed * self._rate_per_sec + task_level = max(task_level - decrement, 0) + user_data["task_level"] = task_level + if not task_level + self.amount <= self.max_rate: + if self.limit_time: + limit_time = self.limit_time + else: + limit_time = 1 / self._rate_per_sec * self.amount + user_data["limit_time"] = time + limit_time + user = update.effective_user + logger.warning("用户 %s[%s] 触发洪水限制 已被限制 %s 秒", user.full_name, user.id, limit_time) + raise ApplicationHandlerStop + user_data["last_task_time"] = time + task_level = user_data.get("task_level", 0) + user_data["task_level"] = task_level + self.amount diff --git a/gram_core/manager.py b/gram_core/manager.py new file mode 100644 index 0000000..3ea70fe --- /dev/null +++ b/gram_core/manager.py @@ -0,0 +1,286 @@ +import asyncio +from importlib import import_module +from pathlib import Path +from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar + +from arkowrapper import ArkoWrapper +from async_timeout import timeout +from typing_extensions import ParamSpec + +from gram_core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services +from gram_core.config import config as bot_config +from utils.const import PLUGIN_DIR, PROJECT_ROOT +from utils.helpers import gen_pkg +from utils.log import logger + +if TYPE_CHECKING: + from gram_core.application import Application + from gram_core.plugin import PluginType + from gram_core.builtins.executor import Executor + +__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers") + +R = TypeVar("R") +T = TypeVar("T") +P = ParamSpec("P") + + +def _load_module(path: Path) -> None: + for pkg in gen_pkg(path): + try: + logger.debug('正在导入 "%s"', pkg) + import_module(pkg) + except Exception as e: + logger.exception( + '在导入 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True} + ) + raise SystemExit from e + + +class Manager(Generic[T]): + """生命周期控制基类""" + + _executor: Optional["Executor"] = None + _lib: Dict[Type[T], T] = {} + _application: "Optional[Application]" = None + + def set_application(self, application: "Application") -> None: + self._application = application + + @property + def application(self) -> "Application": + if self._application is None: + raise 
RuntimeError(f"No application was set for this {self.__class__.__name__}.") + return self._application + + @property + def executor(self) -> "Executor": + """执行器""" + if self._executor is None: + raise RuntimeError(f"No executor was set for this {self.__class__.__name__}.") + return self._executor + + def build_executor(self, name: str): + from gram_core.builtins.executor import Executor + from gram_core.builtins.dispatcher import BaseDispatcher + + self._executor = Executor(name, dispatcher=BaseDispatcher) + self._executor.set_application(self.application) + + +class DependenceManager(Manager[DependenceType]): + """基础依赖管理""" + + _dependency: Dict[Type[DependenceType], DependenceType] = {} + + @property + def dependency(self) -> List[DependenceType]: + return list(self._dependency.values()) + + @property + def dependency_map(self) -> Dict[Type[DependenceType], DependenceType]: + return self._dependency + + async def start_dependency(self) -> None: + _load_module(PROJECT_ROOT / "core/dependence") + + for dependence in filter(lambda x: x.is_dependence, get_all_services()): + dependence: Type[DependenceType] + instance: DependenceType + try: + if hasattr(dependence, "from_config"): # 如果有 from_config 方法 + instance = dependence.from_config(bot_config) # 用 from_config 实例化服务 + else: + instance = await self.executor(dependence) + + await instance.initialize() + logger.success('基础服务 "%s" 启动成功', dependence.__name__) + + self._lib[dependence] = instance + self._dependency[dependence] = instance + + except Exception as e: + logger.exception('基础服务 "%s" 初始化失败,BOT 将自动关闭', dependence.__name__) + raise SystemExit from e + + async def stop_dependency(self) -> None: + async def task(d): + try: + async with timeout(5): + await d.shutdown() + logger.debug('基础服务 "%s" 关闭成功', d.__class__.__name__) + except asyncio.TimeoutError: + logger.warning('基础服务 "%s" 关闭超时', d.__class__.__name__) + except Exception as e: + logger.error('基础服务 "%s" 关闭错误', d.__class__.__name__, exc_info=e) + + tasks = [] + for dependence in self._dependency.values(): + tasks.append(asyncio.create_task(task(dependence))) + + await asyncio.gather(*tasks) + + +class ComponentManager(Manager[ComponentType]): + """组件管理""" + + _components: Dict[Type[ComponentType], ComponentType] = {} + + @property + def components(self) -> List[ComponentType]: + return list(self._components.values()) + + @property + def components_map(self) -> Dict[Type[ComponentType], ComponentType]: + return self._components + + async def init_components(self): + for path in filter( + lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir() + ): + _load_module(path) + components = ArkoWrapper(get_all_services()).filter(lambda x: x.is_component) + retry_times = 0 + max_retry_times = len(components) + while components: + start_len = len(components) + for component in list(components): + component: Type[ComponentType] + instance: ComponentType + try: + instance = await self.executor(component) + self._lib[component] = instance + self._components[component] = instance + components = components.remove(component) + except Exception as e: # pylint: disable=W0703 + logger.debug('组件 "%s" 初始化失败: [red]%s[/]', component.__name__, e, extra={"markup": True}) + end_len = len(list(components)) + if start_len == end_len: + retry_times += 1 + + if retry_times == max_retry_times and components: + for component in components: + logger.error('组件 "%s" 初始化失败', component.__name__) + raise SystemExit + + +class ServiceManager(Manager[BaseServiceType]): + 
"""服务控制类""" + + _services: Dict[Type[BaseServiceType], BaseServiceType] = {} + + @property + def services(self) -> List[BaseServiceType]: + return list(self._services.values()) + + @property + def services_map(self) -> Dict[Type[BaseServiceType], BaseServiceType]: + return self._services + + async def _initialize_service(self, target: Type[BaseServiceType]) -> BaseServiceType: + instance: BaseServiceType + try: + if hasattr(target, "from_config"): # 如果有 from_config 方法 + instance = target.from_config(bot_config) # 用 from_config 实例化服务 + else: + instance = await self.executor(target) + + await instance.initialize() + logger.success('服务 "%s" 启动成功', target.__name__) + + return instance + + except Exception as e: # pylint: disable=W0703 + logger.exception('服务 "%s" 初始化失败,BOT 将自动关闭', target.__name__) + raise SystemExit from e + + async def start_services(self) -> None: + for path in filter( + lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir() + ): + _load_module(path) + + for service in filter(lambda x: not x.is_component and not x.is_dependence, get_all_services()): # 遍历所有服务类 + instance = await self._initialize_service(service) + + self._lib[service] = instance + self._services[service] = instance + + async def stop_services(self) -> None: + """关闭服务""" + if not self._services: + return + + async def task(s): + try: + async with timeout(5): + await s.shutdown() + logger.success('服务 "%s" 关闭成功', s.__class__.__name__) + except asyncio.TimeoutError: + logger.warning('服务 "%s" 关闭超时', s.__class__.__name__) + except Exception as e: + logger.warning('服务 "%s" 关闭失败', s.__class__.__name__, exc_info=e) + + logger.info("正在关闭服务") + tasks = [] + for service in self._services.values(): + tasks.append(asyncio.create_task(task(service))) + + await asyncio.gather(*tasks) + + +class PluginManager(Manager["PluginType"]): + """插件管理""" + + _plugins: Dict[Type["PluginType"], "PluginType"] = {} + + @property + def plugins(self) -> List["PluginType"]: + """所有已经加载的插件""" + return list(self._plugins.values()) + + @property + def plugins_map(self) -> Dict[Type["PluginType"], "PluginType"]: + return self._plugins + + async def install_plugins(self) -> None: + """安装所有插件""" + from gram_core.plugin import get_all_plugins + + for path in filter(lambda x: x.is_dir(), PLUGIN_DIR.iterdir()): + _load_module(path) + + for plugin in get_all_plugins(): + plugin: Type["PluginType"] + + try: + instance: "PluginType" = await self.executor(plugin) + except Exception as e: # pylint: disable=W0703 + logger.error('插件 "%s" 初始化失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e) + continue + + self._plugins[plugin] = instance + + if self._application is not None: + instance.set_application(self._application) + + await asyncio.create_task(self.plugin_install_task(plugin, instance)) + + @staticmethod + async def plugin_install_task(plugin: Type["PluginType"], instance: "PluginType"): + try: + await instance.install() + logger.success('插件 "%s" 安装成功', f"{plugin.__module__}.{plugin.__name__}") + except Exception as e: # pylint: disable=W0703 + logger.error('插件 "%s" 安装失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e) + + async def uninstall_plugins(self) -> None: + for plugin in self._plugins.values(): + try: + await plugin.uninstall() + except Exception as e: # pylint: disable=W0703 + logger.error('插件 "%s" 卸载失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e) + + +class Managers(DependenceManager, ComponentManager, ServiceManager, PluginManager): + """BOT 除自身外的生命周期管理类""" diff --git 
a/gram_core/override/__init__.py b/gram_core/override/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gram_core/override/telegram.py b/gram_core/override/telegram.py new file mode 100644 index 0000000..d3fac5a --- /dev/null +++ b/gram_core/override/telegram.py @@ -0,0 +1,117 @@ +"""重写 telegram.request.HTTPXRequest 使其使用 ujson 库进行 json 序列化""" +from typing import Any, AsyncIterable, Optional + +import httpcore +from httpx import ( + AsyncByteStream, + AsyncHTTPTransport as DefaultAsyncHTTPTransport, + Limits, + Response as DefaultResponse, + Timeout, +) +from telegram.request import HTTPXRequest as DefaultHTTPXRequest + +try: + import ujson as jsonlib +except ImportError: + import json as jsonlib + +__all__ = ("HTTPXRequest",) + + +class Response(DefaultResponse): + def json(self, **kwargs: Any) -> Any: + # noinspection PyProtectedMember + from httpx._utils import guess_json_utf + + if self.charset_encoding is None and self.content and len(self.content) > 3: + encoding = guess_json_utf(self.content) + if encoding is not None: + return jsonlib.loads(self.content.decode(encoding), **kwargs) + return jsonlib.loads(self.text, **kwargs) + + +# noinspection PyProtectedMember +class AsyncHTTPTransport(DefaultAsyncHTTPTransport): + async def handle_async_request(self, request) -> Response: + from httpx._transports.default import ( + map_httpcore_exceptions, + AsyncResponseStream, + ) + + if not isinstance(request.stream, AsyncByteStream): + raise AssertionError + + req = httpcore.Request( + method=request.method, + url=httpcore.URL( + scheme=request.url.raw_scheme, + host=request.url.raw_host, + port=request.url.port, + target=request.url.raw_path, + ), + headers=request.headers.raw, + content=request.stream, + extensions=request.extensions, + ) + with map_httpcore_exceptions(): + resp = await self._pool.handle_async_request(req) + + if not isinstance(resp.stream, AsyncIterable): + raise AssertionError + + return Response( + status_code=resp.status, + headers=resp.headers, + stream=AsyncResponseStream(resp.stream), + extensions=resp.extensions, + ) + + +class HTTPXRequest(DefaultHTTPXRequest): + def __init__( # pylint: disable=W0231 + self, + connection_pool_size: int = 1, + proxy_url: str = None, + read_timeout: Optional[float] = 5.0, + write_timeout: Optional[float] = 5.0, + connect_timeout: Optional[float] = 5.0, + pool_timeout: Optional[float] = 1.0, + http_version: str = "1.1", + ): + self._http_version = http_version + timeout = Timeout( + connect=connect_timeout, + read=read_timeout, + write=write_timeout, + pool=pool_timeout, + ) + limits = Limits( + max_connections=connection_pool_size, + max_keepalive_connections=connection_pool_size, + ) + if http_version not in ("1.1", "2"): + raise ValueError("`http_version` must be either '1.1' or '2'.") + http1 = http_version == "1.1" + self._client_kwargs = dict( + timeout=timeout, + proxies=proxy_url, + limits=limits, + transport=AsyncHTTPTransport(limits=limits), + http1=http1, + http2=not http1, + ) + + try: + self._client = self._build_client() + except ImportError as exc: + if "httpx[http2]" not in str(exc) and "httpx[socks]" not in str(exc): + raise exc + + if "httpx[socks]" in str(exc): + raise RuntimeError( + "To use Socks5 proxies, PTB must be installed via `pip install " "python-telegram-bot[socks]`." + ) from exc + raise RuntimeError( + "To use HTTP/2, PTB must be installed via `pip install " "python-telegram-bot[http2]`." 
+ ) from exc diff --git a/gram_core/plugin/__init__.py b/gram_core/plugin/__init__.py new file mode 100644 index 0000000..e0b3051 --- /dev/null +++ b/gram_core/plugin/__init__.py @@ -0,0 +1,16 @@ +"""插件""" + +from gram_core.plugin._handler import conversation, error_handler, handler +from gram_core.plugin._job import TimeType, job +from gram_core.plugin._plugin import Plugin, PluginType, get_all_plugins + +__all__ = ( + "Plugin", + "PluginType", + "get_all_plugins", + "handler", + "error_handler", + "conversation", + "job", + "TimeType", +) diff --git a/gram_core/plugin/_funcs.py b/gram_core/plugin/_funcs.py new file mode 100644 index 0000000..03e90c1 --- /dev/null +++ b/gram_core/plugin/_funcs.py @@ -0,0 +1,178 @@ +from pathlib import Path +from typing import List, Optional, Union, TYPE_CHECKING + +import aiofiles +import httpx +from httpx import UnsupportedProtocol +from telegram import Chat, Message, ReplyKeyboardRemove, Update +from telegram.error import Forbidden, NetworkError +from telegram.ext import CallbackContext, ConversationHandler, Job + +from gram_core.dependence.redisdb import RedisDB +from gram_core.plugin._handler import conversation, handler +from utils.const import CACHE_DIR, REQUEST_HEADERS +from utils.error import UrlResourcesNotFoundError +from utils.helpers import sha1 +from utils.log import logger + +if TYPE_CHECKING: + from gram_core.application import Application + +try: + import ujson as json +except ImportError: + import json + +__all__ = ( + "PluginFuncs", + "ConversationFuncs", +) + + +class PluginFuncs: + _application: "Optional[Application]" = None + + def set_application(self, application: "Application") -> None: + self._application = application + + @property + def application(self) -> "Application": + if self._application is None: + raise RuntimeError("No application was set for this PluginManager.") + return self._application + + async def _delete_message(self, context: CallbackContext) -> None: + job = context.job + message_id = job.data + chat_info = f"chat_id[{job.chat_id}]" + + try: + chat = await self.get_chat(job.chat_id) + full_name = chat.full_name + if full_name: + chat_info = f"{full_name}[{chat.id}]" + else: + chat_info = f"{chat.title}[{chat.id}]" + except (NetworkError, Forbidden) as exc: + logger.warning("获取 chat info 失败 %s", exc.message) + except Exception as exc: + logger.warning("获取 chat info 消息失败 %s", str(exc)) + + logger.debug("删除消息 %s message_id[%s]", chat_info, message_id) + + try: + # noinspection PyTypeChecker + await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id) + except NetworkError as exc: + logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message) + except Forbidden as exc: + logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message) + except Exception as exc: + logger.error("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc_info=exc) + + async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, expire: int = 86400) -> Chat: + application = self.application + redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None) + + if not redis_db: + return await application.bot.get_chat(chat_id) + + qname = f"bot:chat:{chat_id}" + + data = await redis_db.client.get(qname) + if data: + json_data = json.loads(data) + return Chat.de_json(json_data, application.telegram.bot) + + chat_info = await application.telegram.bot.get_chat(chat_id) + await redis_db.client.set(qname, chat_info.to_json(), 
ex=expire) + return chat_info + + def add_delete_message_job( + self, + message: Optional[Union[int, Message]] = None, + *, + delay: int = 60, + name: Optional[str] = None, + chat: Optional[Union[int, Chat]] = None, + context: Optional[CallbackContext] = None, + ) -> Job: + """延迟删除消息""" + + if isinstance(message, Message): + if chat is None: + chat = message.chat_id + message = message.id + + chat = chat.id if isinstance(chat, Chat) else chat + + job_queue = self.application.job_queue or context.job_queue + + if job_queue is None or chat is None: + raise RuntimeError + + return job_queue.run_once( + callback=self._delete_message, + when=delay, + data=message, + name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message", + chat_id=chat, + job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"}, + ) + + @staticmethod + async def download_resource(url: str, return_path: bool = False) -> str: + url_sha1 = sha1(url) # url 的 hash 值 + pathed_url = Path(url) + + file_name = url_sha1 + pathed_url.suffix + file_path = CACHE_DIR.joinpath(file_name) + + if not file_path.exists(): # 若文件不存在,则下载 + async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client: + try: + response = await client.get(url) + except UnsupportedProtocol: + logger.error("链接不支持 url[%s]", url) + return "" + + if response.is_error: + logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code) + raise UrlResourcesNotFoundError(url) + + if response.status_code != 200: + logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code) + raise UrlResourcesNotFoundError(url) + + async with aiofiles.open(file_path, mode="wb") as f: + await f.write(response.content) + + logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path) + + return file_path if return_path else Path(file_path).as_uri() + + @staticmethod + def get_args(context: CallbackContext) -> List[str]: + args = context.args + match = context.match + + if args is None: + if match is not None and (command := match.groups()[0]): + temp = [] + command_parts = command.split(" ") + for command_part in command_parts: + if command_part: + temp.append(command_part) + return temp + return [] + if len(args) >= 1: + return args + return [] + + +class ConversationFuncs: + @conversation.fallback + @handler.command(command="cancel", block=False) + async def cancel(self, update: Update, _) -> int: + await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove()) + return ConversationHandler.END diff --git a/gram_core/plugin/_handler.py b/gram_core/plugin/_handler.py new file mode 100644 index 0000000..758f266 --- /dev/null +++ b/gram_core/plugin/_handler.py @@ -0,0 +1,380 @@ +from dataclasses import dataclass +from enum import Enum +from functools import wraps +from importlib import import_module +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + Optional, + Pattern, + TYPE_CHECKING, + Type, + TypeVar, + Union, +) + +from pydantic import BaseModel + +# noinspection PyProtectedMember +from telegram._utils.defaultvalue import DEFAULT_TRUE + +# noinspection PyProtectedMember +from telegram._utils.types import DVInput +from telegram.ext import BaseHandler +from telegram.ext.filters import BaseFilter +from typing_extensions import ParamSpec + +from gram_core.handler.callbackqueryhandler import CallbackQueryHandler +from utils.const import WRAPPER_ASSIGNMENTS as _WRAPPER_ASSIGNMENTS + +if TYPE_CHECKING: + from 
gram_core.builtins.dispatcher import AbstractDispatcher + +__all__ = ( + "handler", + "conversation", + "ConversationDataType", + "ConversationData", + "HandlerData", + "ErrorHandlerData", + "error_handler", +) + +P = ParamSpec("P") +T = TypeVar("T") +R = TypeVar("R") +UT = TypeVar("UT") + +HandlerType = TypeVar("HandlerType", bound=BaseHandler) +HandlerCls = Type[HandlerType] + +Module = import_module("telegram.ext") + +HANDLER_DATA_ATTR_NAME = "_handler_datas" +"""用于储存生成 handler 时所需要的参数(例如 block)的属性名""" + +ERROR_HANDLER_ATTR_NAME = "_error_handler_data" + +CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data" +"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名""" + +WRAPPER_ASSIGNMENTS = list( + set( + _WRAPPER_ASSIGNMENTS + + [ + HANDLER_DATA_ATTR_NAME, + ERROR_HANDLER_ATTR_NAME, + CONVERSATION_HANDLER_ATTR_NAME, + ] + ) +) + + +@dataclass(init=True) +class HandlerData: + type: Type[HandlerType] + admin: bool + kwargs: Dict[str, Any] + dispatcher: Optional[Type["AbstractDispatcher"]] = None + + +class _Handler: + _type: Type["HandlerType"] + + kwargs: Dict[str, Any] = {} + + def __init_subclass__(cls, **kwargs) -> None: + """用于获取 python-telegram-bot 中对应的 handler class""" + + handler_name = f"{cls.__name__.strip('_')}Handler" + + if handler_name == "CallbackQueryHandler": + cls._type = CallbackQueryHandler + return + + cls._type = getattr(Module, handler_name, None) + + def __init__(self, admin: bool = False, dispatcher: Optional[Type["AbstractDispatcher"]] = None, **kwargs) -> None: + self.dispatcher = dispatcher + self.admin = admin + self.kwargs = kwargs + + def __call__(self, func: Callable[P, R]) -> Callable[P, R]: + """decorator实现,从 func 生成 Handler""" + + handler_datas = getattr(func, HANDLER_DATA_ATTR_NAME, []) + handler_datas.append( + HandlerData(type=self._type, admin=self.admin, kwargs=self.kwargs, dispatcher=self.dispatcher) + ) + setattr(func, HANDLER_DATA_ATTR_NAME, handler_datas) + + return func + + +class _CallbackQuery(_Handler): + def __init__( + self, + pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None, + *, + block: DVInput[bool] = DEFAULT_TRUE, + admin: bool = False, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_CallbackQuery, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher) + + +class _ChatJoinRequest(_Handler): + def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): + super(_ChatJoinRequest, self).__init__(block=block, dispatcher=dispatcher) + + +class _ChatMember(_Handler): + def __init__( + self, + chat_member_types: int = -1, + *, + block: DVInput[bool] = DEFAULT_TRUE, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(chat_member_types=chat_member_types, block=block, dispatcher=dispatcher) + + +class _ChosenInlineResult(_Handler): + def __init__( + self, + block: DVInput[bool] = DEFAULT_TRUE, + *, + pattern: Union[str, Pattern] = None, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(block=block, pattern=pattern, dispatcher=dispatcher) + + +class _Command(_Handler): + def __init__( + self, + command: Union[str, List[str]], + filters: "BaseFilter" = None, + *, + block: DVInput[bool] = DEFAULT_TRUE, + admin: bool = False, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_Command, self).__init__( + command=command, filters=filters, block=block, admin=admin, dispatcher=dispatcher + ) + + +class _InlineQuery(_Handler): + 
def __init__( + self, + pattern: Union[str, Pattern] = None, + chat_types: List[str] = None, + *, + block: DVInput[bool] = DEFAULT_TRUE, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_InlineQuery, self).__init__(pattern=pattern, block=block, chat_types=chat_types, dispatcher=dispatcher) + + +class _Message(_Handler): + def __init__( + self, + filters: BaseFilter, + *, + block: DVInput[bool] = DEFAULT_TRUE, + admin: bool = False, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ) -> None: + super(_Message, self).__init__(filters=filters, block=block, admin=admin, dispatcher=dispatcher) + + +class _PollAnswer(_Handler): + def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): + super(_PollAnswer, self).__init__(block=block, dispatcher=dispatcher) + + +class _Poll(_Handler): + def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): + super(_Poll, self).__init__(block=block, dispatcher=dispatcher) + + +class _PreCheckoutQuery(_Handler): + def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): + super(_PreCheckoutQuery, self).__init__(block=block, dispatcher=dispatcher) + + +class _Prefix(_Handler): + def __init__( + self, + prefix: str, + command: str, + filters: BaseFilter = None, + *, + block: DVInput[bool] = DEFAULT_TRUE, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_Prefix, self).__init__( + prefix=prefix, command=command, filters=filters, block=block, dispatcher=dispatcher + ) + + +class _ShippingQuery(_Handler): + def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): + super(_ShippingQuery, self).__init__(block=block, dispatcher=dispatcher) + + +class _StringCommand(_Handler): + def __init__( + self, + command: str, + *, + admin: bool = False, + block: DVInput[bool] = DEFAULT_TRUE, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_StringCommand, self).__init__(command=command, block=block, admin=admin, dispatcher=dispatcher) + + +class _StringRegex(_Handler): + def __init__( + self, + pattern: Union[str, Pattern], + *, + block: DVInput[bool] = DEFAULT_TRUE, + admin: bool = False, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_StringRegex, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher) + + +class _Type(_Handler): + # noinspection PyShadowingBuiltins + def __init__( + self, + type: Type[UT], # pylint: disable=W0622 + strict: bool = False, + *, + block: DVInput[bool] = DEFAULT_TRUE, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): # pylint: disable=redefined-builtin + super(_Type, self).__init__(type=type, strict=strict, block=block, dispatcher=dispatcher) + + +# noinspection PyPep8Naming +class handler(_Handler): + callback_query = _CallbackQuery + chat_join_request = _ChatJoinRequest + chat_member = _ChatMember + chosen_inline_result = _ChosenInlineResult + command = _Command + inline_query = _InlineQuery + message = _Message + poll_answer = _PollAnswer + pool = _Poll + pre_checkout_query = _PreCheckoutQuery + prefix = _Prefix + shipping_query = _ShippingQuery + string_command = _StringCommand + string_regex = _StringRegex + type = _Type + + def __init__( + self, + handler_type: Union[Callable[P, "HandlerType"], Type["HandlerType"]], + *, + admin: bool = False, + dispatcher: 
Optional[Type["AbstractDispatcher"]] = None, + **kwargs: P.kwargs, + ) -> None: + self._type = handler_type + super().__init__(admin=admin, dispatcher=dispatcher, **kwargs) + + +class ConversationDataType(Enum): + """conversation handler 的类型""" + + Entry = "entry" + State = "state" + Fallback = "fallback" + + +class ConversationData(BaseModel): + """用于储存 conversation handler 的数据""" + + type: ConversationDataType + state: Optional[Any] = None + + +class _ConversationType: + _type: ClassVar[ConversationDataType] + + def __init_subclass__(cls, **kwargs) -> None: + cls._type = ConversationDataType(cls.__name__.lstrip("_").lower()) + + +def _entry(func: Callable[P, R]) -> Callable[P, R]: + setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Entry)) + + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapped + + +class _State(_ConversationType): + def __init__(self, state: Any) -> None: + self.state = state + + def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]: + setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=self._type, state=self.state)) + return func + + +def _fallback(func: Callable[P, R]) -> Callable[P, R]: + setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Fallback)) + + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapped + + +# noinspection PyPep8Naming +class conversation(_Handler): + entry_point = _entry + state = _State + fallback = _fallback + + +@dataclass(init=True) +class ErrorHandlerData: + block: bool + func: Optional[Callable] = None + + +# noinspection PyPep8Naming +class error_handler: + _func: Callable[P, R] + + def __init__( + self, + *, + block: bool = DEFAULT_TRUE, + ): + self._block = block + + def __call__(self, func: Callable[P, T]) -> Callable[P, T]: + self._func = func + wraps(func, assigned=WRAPPER_ASSIGNMENTS)(self) + + handler_datas = getattr(func, ERROR_HANDLER_ATTR_NAME, []) + handler_datas.append(ErrorHandlerData(block=self._block)) + setattr(self._func, ERROR_HANDLER_ATTR_NAME, handler_datas) + + return self._func diff --git a/gram_core/plugin/_job.py b/gram_core/plugin/_job.py new file mode 100644 index 0000000..30fafb7 --- /dev/null +++ b/gram_core/plugin/_job.py @@ -0,0 +1,173 @@ +"""插件""" +import datetime +import re +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union + +# noinspection PyProtectedMember +from telegram._utils.types import JSONDict + +# noinspection PyProtectedMember +from telegram.ext._utils.types import JobCallback +from typing_extensions import ParamSpec + +if TYPE_CHECKING: + from gram_core.builtins.dispatcher import AbstractDispatcher + +__all__ = ["TimeType", "job", "JobData"] + +P = ParamSpec("P") +T = TypeVar("T") +R = TypeVar("R") + +TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time] + +_JOB_ATTR_NAME = "_job_data" + + +@dataclass(init=True) +class JobData: + name: str + data: Any + chat_id: int + user_id: int + type: str + job_kwargs: JSONDict = field(default_factory=dict) + kwargs: JSONDict = field(default_factory=dict) + dispatcher: Optional[Type["AbstractDispatcher"]] = None + + +class _Job: + kwargs: Dict = {} + + def __init__( + self, + name: str = None, + data: object = None, + chat_id: int = None, + user_id: int = None, + 
job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + **kwargs, + ): + self.name = name + self.data = data + self.chat_id = chat_id + self.user_id = user_id + self.job_kwargs = {} if job_kwargs is None else job_kwargs + self.kwargs = kwargs + if dispatcher is None: + from gram_core.builtins.dispatcher import JobDispatcher + + dispatcher = JobDispatcher + + self.dispatcher = dispatcher + + def __call__(self, func: JobCallback) -> JobCallback: + data = JobData( + name=self.name, + data=self.data, + chat_id=self.chat_id, + user_id=self.user_id, + job_kwargs=self.job_kwargs, + kwargs=self.kwargs, + type=re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"), + dispatcher=self.dispatcher, + ) + if hasattr(func, _JOB_ATTR_NAME): + job_datas = getattr(func, _JOB_ATTR_NAME) + job_datas.append(data) + setattr(func, _JOB_ATTR_NAME, job_datas) + else: + setattr(func, _JOB_ATTR_NAME, [data]) + return func + + +class _RunOnce(_Job): + def __init__( + self, + when: TimeType, + data: object = None, + name: str = None, + chat_id: int = None, + user_id: int = None, + job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when) + + +class _RunRepeating(_Job): + def __init__( + self, + interval: Union[float, datetime.timedelta], + first: TimeType = None, + last: TimeType = None, + data: object = None, + name: str = None, + chat_id: int = None, + user_id: int = None, + job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__( + name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, interval=interval, first=first, last=last + ) + + +class _RunMonthly(_Job): + def __init__( + self, + when: datetime.time, + day: int, + data: object = None, + name: str = None, + chat_id: int = None, + user_id: int = None, + job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when, day=day) + + +class _RunDaily(_Job): + def __init__( + self, + time: datetime.time, + days: Tuple[int, ...] 
= tuple(range(7)), + data: object = None, + name: str = None, + chat_id: int = None, + user_id: int = None, + job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, time=time, days=days) + + +class _RunCustom(_Job): + def __init__( + self, + data: object = None, + name: str = None, + chat_id: int = None, + user_id: int = None, + job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher) + + +# noinspection PyPep8Naming +class job: + run_once = _RunOnce + run_repeating = _RunRepeating + run_monthly = _RunMonthly + run_daily = _RunDaily + run_custom = _RunCustom diff --git a/gram_core/plugin/_plugin.py b/gram_core/plugin/_plugin.py new file mode 100644 index 0000000..084bc38 --- /dev/null +++ b/gram_core/plugin/_plugin.py @@ -0,0 +1,314 @@ +"""插件""" +import asyncio +from abc import ABC +from dataclasses import asdict +from datetime import timedelta +from functools import partial, wraps +from itertools import chain +from multiprocessing import RLock as Lock +from types import MethodType +from typing import ( + Any, + ClassVar, + Dict, + Iterable, + List, + Optional, + TYPE_CHECKING, + Type, + TypeVar, + Union, +) + +from pydantic import BaseModel +from telegram.ext import BaseHandler, ConversationHandler, Job, TypeHandler +from typing_extensions import ParamSpec + +from gram_core.handler.adminhandler import AdminHandler +from gram_core.plugin._funcs import ConversationFuncs, PluginFuncs +from gram_core.plugin._handler import ConversationDataType +from utils.const import WRAPPER_ASSIGNMENTS +from utils.helpers import isabstract +from utils.log import logger + +if TYPE_CHECKING: + from gram_core.application import Application + from gram_core.plugin._handler import ConversationData, HandlerData, ErrorHandlerData + from gram_core.plugin._job import JobData + from multiprocessing.synchronize import RLock as LockType + +__all__ = ("Plugin", "PluginType", "get_all_plugins") + +wraps = partial(wraps, assigned=WRAPPER_ASSIGNMENTS) +P = ParamSpec("P") +T = TypeVar("T") +R = TypeVar("R") + +HandlerType = TypeVar("HandlerType", bound=BaseHandler) + +_HANDLER_DATA_ATTR_NAME = "_handler_datas" +"""用于储存生成 handler 时所需要的参数(例如 block)的属性名""" + +_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data" +"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名""" + +_ERROR_HANDLER_ATTR_NAME = "_error_handler_data" + +_JOB_ATTR_NAME = "_job_data" + +_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"] + + +class _Plugin(PluginFuncs): + """插件""" + + _lock: ClassVar["LockType"] = Lock() + _asyncio_lock: ClassVar["LockType"] = asyncio.Lock() + _installed: bool = False + + _handlers: Optional[List[HandlerType]] = None + _error_handlers: Optional[List["ErrorHandlerData"]] = None + _jobs: Optional[List[Job]] = None + _application: "Optional[Application]" = None + + def set_application(self, application: "Application") -> None: + self._application = application + + @property + def application(self) -> "Application": + if self._application is None: + raise RuntimeError("No application was set for this Plugin.") + return self._application + + @property + def handlers(self) -> List[HandlerType]: + """该插件的所有 handler""" + with self._lock: + if self._handlers is None: + self._handlers = [] + + for attr in dir(self): + if ( + not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) 
+ and isinstance(func := getattr(self, attr), MethodType) + and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, [])) + ): + for data in datas: + data: "HandlerData" + if data.admin: + self._handlers.append( + AdminHandler( + handler=data.type( + callback=func, + **data.kwargs, + ), + application=self.application, + ) + ) + else: + self._handlers.append( + data.type( + callback=func, + **data.kwargs, + ) + ) + return self._handlers + + @property + def error_handlers(self) -> List["ErrorHandlerData"]: + with self._lock: + if self._error_handlers is None: + self._error_handlers = [] + for attr in dir(self): + if ( + not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) + and isinstance(func := getattr(self, attr), MethodType) + and (datas := getattr(func, _ERROR_HANDLER_ATTR_NAME, [])) + ): + for data in datas: + data: "ErrorHandlerData" + data.func = func + self._error_handlers.append(data) + + return self._error_handlers + + def _install_jobs(self) -> None: + if self._jobs is None: + self._jobs = [] + for attr in dir(self): + # noinspection PyUnboundLocalVariable + if ( + not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) + and isinstance(func := getattr(self, attr), MethodType) + and (datas := getattr(func, _JOB_ATTR_NAME, [])) + ): + for data in datas: + data: "JobData" + self._jobs.append( + getattr(self.application.telegram.job_queue, data.type)( + callback=func, + **data.kwargs, + **{ + key: value + for key, value in asdict(data).items() + if key not in ["type", "kwargs", "dispatcher"] + }, + ) + ) + + @property + def jobs(self) -> List[Job]: + with self._lock: + if self._jobs is None: + self._jobs = [] + self._install_jobs() + return self._jobs + + async def initialize(self) -> None: + """初始化插件""" + + async def shutdown(self) -> None: + """销毁插件""" + + async def install(self) -> None: + """安装""" + group = id(self) + if not self._installed: + await self.initialize() + # initialize 必须先执行 如果出现异常不会执行 add_handler 以免出现问题 + async with self._asyncio_lock: + self._install_jobs() + + for h in self.handlers: + if not isinstance(h, TypeHandler): + self.application.telegram.add_handler(h, group) + else: + self.application.telegram.add_handler(h, -1) + + for h in self.error_handlers: + self.application.telegram.add_error_handler(h.func, h.block) + self._installed = True + + async def uninstall(self) -> None: + """卸载""" + group = id(self) + + with self._lock: + if self._installed: + if group in self.application.telegram.handlers: + del self.application.telegram.handlers[id(self)] + + for h in self.handlers: + if isinstance(h, TypeHandler): + self.application.telegram.remove_handler(h, -1) + for h in self.error_handlers: + self.application.telegram.remove_error_handler(h.func) + + for j in self.application.telegram.job_queue.jobs(): + j.schedule_removal() + await self.shutdown() + self._installed = False + + async def reload(self) -> None: + await self.uninstall() + await self.install() + + +class _Conversation(_Plugin, ConversationFuncs, ABC): + """Conversation类""" + + # noinspection SpellCheckingInspection + class Config(BaseModel): + allow_reentry: bool = False + per_chat: bool = True + per_user: bool = True + per_message: bool = False + conversation_timeout: Optional[Union[float, timedelta]] = None + name: Optional[str] = None + map_to_parent: Optional[Dict[object, object]] = None + block: bool = False + + def __init_subclass__(cls, **kwargs): + cls._conversation_kwargs = kwargs + super(_Conversation, cls).__init_subclass__() + return cls + + @property + def handlers(self) -> List[HandlerType]: 
+ with self._lock: + if self._handlers is None: + self._handlers = [] + + entry_points: List[HandlerType] = [] + states: Dict[Any, List[HandlerType]] = {} + fallbacks: List[HandlerType] = [] + for attr in dir(self): + if ( + not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) + and (func := getattr(self, attr, None)) is not None + and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, [])) + ): + conversation_data: "ConversationData" + + handlers: List[HandlerType] = [] + for data in datas: + if data.admin: + handlers.append( + AdminHandler( + handler=data.type( + callback=func, + **data.kwargs, + ), + application=self.application, + ) + ) + else: + handlers.append( + data.type( + callback=func, + **data.kwargs, + ) + ) + + if conversation_data := getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None): + if (_type := conversation_data.type) is ConversationDataType.Entry: + entry_points.extend(handlers) + elif _type is ConversationDataType.State: + if conversation_data.state in states: + states[conversation_data.state].extend(handlers) + else: + states[conversation_data.state] = handlers + elif _type is ConversationDataType.Fallback: + fallbacks.extend(handlers) + else: + self._handlers.extend(handlers) + else: + self._handlers.extend(handlers) + if entry_points and states and fallbacks: + kwargs = self._conversation_kwargs + kwargs.update(self.Config().dict()) + self._handlers.append(ConversationHandler(entry_points, states, fallbacks, **kwargs)) + else: + temp_dict = {"entry_points": entry_points, "states": states, "fallbacks": fallbacks} + reason = map(lambda x: f"'{x[0]}'", filter(lambda x: not x[1], temp_dict.items())) + logger.warning( + "'%s' 因缺少 '%s' 而生成无法生成 ConversationHandler", self.__class__.__name__, ", ".join(reason) + ) + return self._handlers + + +class Plugin(_Plugin, ABC): + """插件""" + + Conversation = _Conversation + + +PluginType = TypeVar("PluginType", bound=_Plugin) + + +def get_all_plugins() -> Iterable[Type[PluginType]]: + """获取所有 Plugin 的子类""" + return filter( + lambda x: x.__name__[0] != "_" and not isabstract(x), + chain(Plugin.__subclasses__(), _Conversation.__subclasses__()), + ) diff --git a/gram_core/ratelimiter.py b/gram_core/ratelimiter.py new file mode 100644 index 0000000..0e6d369 --- /dev/null +++ b/gram_core/ratelimiter.py @@ -0,0 +1,67 @@ +import asyncio +import contextlib +from typing import Callable, Coroutine, Any, Union, List, Dict, Optional, TypeVar, Type + +from telegram.error import RetryAfter +from telegram.ext import BaseRateLimiter, ApplicationHandlerStop + +from utils.log import logger + +JSONDict: Type[dict[str, Any]] = Dict[str, Any] +RL_ARGS = TypeVar("RL_ARGS") + + +class RateLimiter(BaseRateLimiter[int]): + _lock = asyncio.Lock() + __slots__ = ( + "_limiter_info", + "_retry_after_event", + ) + + def __init__(self): + self._limiter_info: Dict[Union[str, int], float] = {} + self._retry_after_event = asyncio.Event() + self._retry_after_event.set() + + async def process_request( + self, + callback: Callable[..., Coroutine[Any, Any, Union[bool, JSONDict, List[JSONDict]]]], + args: Any, + kwargs: Dict[str, Any], + endpoint: str, + data: Dict[str, Any], + rate_limit_args: Optional[RL_ARGS], + ) -> Union[bool, JSONDict, List[JSONDict]]: + chat_id = data.get("chat_id") + + with contextlib.suppress(ValueError, TypeError): + chat_id = int(chat_id) + + loop = asyncio.get_running_loop() + time = loop.time() + + await self._retry_after_event.wait() + + async with self._lock: + chat_limit_time = self._limiter_info.get(chat_id) + if chat_limit_time: + if 
time >= chat_limit_time: + raise ApplicationHandlerStop + del self._limiter_info[chat_id] + + try: + return await callback(*args, **kwargs) + except RetryAfter as exc: + logger.warning("chat_id[%s] 触发洪水限制 当前被服务器限制 retry_after[%s]秒", chat_id, exc.retry_after) + self._limiter_info[chat_id] = time + (exc.retry_after * 2) + sleep = exc.retry_after + 0.1 + self._retry_after_event.clear() + await asyncio.sleep(sleep) + finally: + self._retry_after_event.set() + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass diff --git a/gram_core/services/__init__.py b/gram_core/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gram_core/services/cookies/__init__.py b/gram_core/services/cookies/__init__.py new file mode 100644 index 0000000..6314ee9 --- /dev/null +++ b/gram_core/services/cookies/__init__.py @@ -0,0 +1,5 @@ +"""CookieService""" + +from gram_core.services.cookies.services import CookiesService, PublicCookiesService + +__all__ = ("CookiesService", "PublicCookiesService") diff --git a/gram_core/services/cookies/cache.py b/gram_core/services/cookies/cache.py new file mode 100644 index 0000000..bba671c --- /dev/null +++ b/gram_core/services/cookies/cache.py @@ -0,0 +1,97 @@ +from typing import List, Union + +from gram_core.base_service import BaseService +from gram_core.basemodel import RegionEnum +from gram_core.dependence.redisdb import RedisDB +from gram_core.services.cookies.error import CookiesCachePoolExhausted +from utils.error import RegionNotFoundError + +__all__ = ("PublicCookiesCache",) + + +class PublicCookiesCache(BaseService.Component): + """使用优先级(score)进行排序,对使用次数最少的Cookies进行审核""" + + def __init__(self, redis: RedisDB): + self.client = redis.client + self.score_qname = "cookie:public" + self.user_times_qname = "cookie:public:times" + self.end = 20 + self.user_times_ttl = 60 * 60 * 24 + + def get_public_cookies_queue_name(self, region: RegionEnum): + if region == RegionEnum.HYPERION: + return f"{self.score_qname}:yuanshen" + if region == RegionEnum.HOYOLAB: + return f"{self.score_qname}:genshin" + raise RegionNotFoundError(region.name) + + async def putback_public_cookies(self, uid: int, region: RegionEnum): + """重新添加单个到缓存列表 + :param uid: + :param region: + :return: + """ + qname = self.get_public_cookies_queue_name(region) + score_maps = {f"{uid}": 0} + result = await self.client.zrem(qname, f"{uid}") + if result == 1: + await self.client.zadd(qname, score_maps) + return result + + async def add_public_cookies(self, uid: Union[List[int], int], region: RegionEnum): + """单个或批量添加到缓存列表 + :param uid: + :param region: + :return: 成功返回列表大小 + """ + qname = self.get_public_cookies_queue_name(region) + if isinstance(uid, int): + score_maps = {f"{uid}": 0} + elif isinstance(uid, list): + score_maps = {f"{i}": 0 for i in uid} + else: + raise TypeError("uid variable type error") + async with self.client.pipeline(transaction=True) as pipe: + # nx:只添加新元素。不要更新已经存在的元素 + await pipe.zadd(qname, score_maps, nx=True) + await pipe.zcard(qname) + add, count = await pipe.execute() + return int(add), count + + async def get_public_cookies(self, region: RegionEnum): + """从缓存列表获取 + :param region: + :return: + """ + qname = self.get_public_cookies_queue_name(region) + scores = await self.client.zrange(qname, 0, self.end, withscores=True, score_cast_func=int) + if len(scores) <= 0: + raise CookiesCachePoolExhausted + key = scores[0][0] + score = scores[0][1] + async with self.client.pipeline(transaction=True) as pipe: + await pipe.zincrby(qname, 
1, key) + await pipe.execute() + return int(key), score + 1 + + async def delete_public_cookies(self, uid: int, region: RegionEnum): + qname = self.get_public_cookies_queue_name(region) + async with self.client.pipeline(transaction=True) as pipe: + await pipe.zrem(qname, uid) + return await pipe.execute() + + async def get_public_cookies_count(self, limit: bool = True): + async with self.client.pipeline(transaction=True) as pipe: + if limit: + await pipe.zcount(0, self.end) + else: + await pipe.zcard(self.score_qname) + return await pipe.execute() + + async def incr_by_user_times(self, user_id: Union[List[int], int], amount: int = 1): + qname = f"{self.user_times_qname}:{user_id}" + times = await self.client.incrby(qname, amount) + if times <= 1: + await self.client.expire(qname, self.user_times_ttl) + return times diff --git a/gram_core/services/cookies/error.py b/gram_core/services/cookies/error.py new file mode 100644 index 0000000..239110a --- /dev/null +++ b/gram_core/services/cookies/error.py @@ -0,0 +1,12 @@ +class CookieServiceError(Exception): + pass + + +class CookiesCachePoolExhausted(CookieServiceError): + def __init__(self): + super().__init__("Cookies cache pool is exhausted") + + +class TooManyRequestPublicCookies(CookieServiceError): + def __init__(self, user_id): + super().__init__(f"{user_id} too many request public cookies") diff --git a/gram_core/services/cookies/models.py b/gram_core/services/cookies/models.py new file mode 100644 index 0000000..4df0ba3 --- /dev/null +++ b/gram_core/services/cookies/models.py @@ -0,0 +1,39 @@ +import enum +from typing import Optional, Dict + +from sqlmodel import SQLModel, Field, Boolean, Column, Enum, JSON, Integer, BigInteger, Index + +from gram_core.basemodel import RegionEnum + +__all__ = ("Cookies", "CookiesDataBase", "CookiesStatusEnum") + + +class CookiesStatusEnum(int, enum.Enum): + STATUS_SUCCESS = 0 + INVALID_COOKIES = 1 + TOO_MANY_REQUESTS = 2 + + +class Cookies(SQLModel): + __table_args__ = ( + Index("index_user_account", "user_id", "account_id", unique=True), + dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"), + ) + id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True)) + user_id: int = Field( + sa_column=Column(BigInteger()), + ) + account_id: int = Field( + default=None, + sa_column=Column( + BigInteger(), + ), + ) + data: Optional[Dict[str, str]] = Field(sa_column=Column(JSON)) + status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum))) + region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum))) + is_share: Optional[bool] = Field(sa_column=Column(Boolean)) + + +class CookiesDataBase(Cookies, table=True): + __tablename__ = "cookies" diff --git a/gram_core/services/cookies/repositories.py b/gram_core/services/cookies/repositories.py new file mode 100644 index 0000000..4567b0a --- /dev/null +++ b/gram_core/services/cookies/repositories.py @@ -0,0 +1,55 @@ +from typing import Optional, List + +from sqlmodel import select + +from gram_core.base_service import BaseService +from gram_core.basemodel import RegionEnum +from gram_core.dependence.database import Database +from gram_core.services.cookies.models import CookiesDataBase as Cookies +from gram_core.sqlmodel.session import AsyncSession + +__all__ = ("CookiesRepository",) + + +class CookiesRepository(BaseService.Component): + def __init__(self, database: Database): + self.engine = database.engine + + async def get( + self, + user_id: int, + account_id: Optional[int] = 
None, + region: Optional[RegionEnum] = None, + ) -> Optional[Cookies]: + async with AsyncSession(self.engine) as session: + statement = select(Cookies).where(Cookies.user_id == user_id) + if account_id is not None: + statement = statement.where(Cookies.account_id == account_id) + if region is not None: + statement = statement.where(Cookies.region == region) + results = await session.exec(statement) + return results.first() + + async def add(self, cookies: Cookies) -> None: + async with AsyncSession(self.engine) as session: + session.add(cookies) + await session.commit() + + async def update(self, cookies: Cookies) -> Cookies: + async with AsyncSession(self.engine) as session: + session.add(cookies) + await session.commit() + await session.refresh(cookies) + return cookies + + async def delete(self, cookies: Cookies) -> None: + async with AsyncSession(self.engine) as session: + await session.delete(cookies) + await session.commit() + + async def get_all_by_region(self, region: RegionEnum) -> List[Cookies]: + async with AsyncSession(self.engine) as session: + statement = select(Cookies).where(Cookies.region == region) + results = await session.exec(statement) + cookies = results.all() + return cookies diff --git a/gram_core/services/cookies/services.py b/gram_core/services/cookies/services.py new file mode 100644 index 0000000..8a74a60 --- /dev/null +++ b/gram_core/services/cookies/services.py @@ -0,0 +1,159 @@ +from typing import List, Optional + +from simnet import StarRailClient, Region, Game +from simnet.errors import InvalidCookies, BadRequest as SimnetBadRequest, TooManyRequests + +from gram_core.base_service import BaseService +from gram_core.basemodel import RegionEnum +from gram_core.services.cookies.cache import PublicCookiesCache +from gram_core.services.cookies.error import CookieServiceError, TooManyRequestPublicCookies +from gram_core.services.cookies.models import CookiesDataBase as Cookies, CookiesStatusEnum +from gram_core.services.cookies.repositories import CookiesRepository +from utils.log import logger + +__all__ = ("CookiesService", "PublicCookiesService") + + +class CookiesService(BaseService): + def __init__(self, cookies_repository: CookiesRepository) -> None: + self._repository: CookiesRepository = cookies_repository + + async def update(self, cookies: Cookies): + await self._repository.update(cookies) + + async def add(self, cookies: Cookies): + await self._repository.add(cookies) + + async def get( + self, + user_id: int, + account_id: Optional[int] = None, + region: Optional[RegionEnum] = None, + ) -> Optional[Cookies]: + return await self._repository.get(user_id, account_id, region) + + async def delete(self, cookies: Cookies) -> None: + return await self._repository.delete(cookies) + + +class PublicCookiesService(BaseService): + def __init__(self, cookies_repository: CookiesRepository, public_cookies_cache: PublicCookiesCache): + self._cache = public_cookies_cache + self._repository: CookiesRepository = cookies_repository + self.count: int = 0 + self.user_times_limiter = 3 * 3 + + async def initialize(self) -> None: + logger.info("正在初始化公共Cookies池") + await self.refresh() + logger.success("刷新公共Cookies池成功") + + async def refresh(self): + """刷新公共Cookies 定时任务 + :return: + """ + user_list: List[int] = [] + cookies_list = await self._repository.get_all_by_region(RegionEnum.HYPERION) # 从数据库获取2 + for cookies in cookies_list: + if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS: + user_list.append(cookies.user_id) + if len(user_list) > 0: + 
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HYPERION) + logger.info("国服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count) + user_list.clear() + cookies_list = await self._repository.get_all_by_region(RegionEnum.HOYOLAB) + for cookies in cookies_list: + if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS: + user_list.append(cookies.user_id) + if len(user_list) > 0: + add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HOYOLAB) + logger.info("国际服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count) + + async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL): + """获取公共Cookies + :param user_id: 用户ID + :param region: 注册的服务器 + :return: + """ + user_times = await self._cache.incr_by_user_times(user_id) + if int(user_times) > self.user_times_limiter: + logger.warning("用户 %s 使用公共Cookies次数已经到达上限", user_id) + raise TooManyRequestPublicCookies(user_id) + while True: + public_id, count = await self._cache.get_public_cookies(region) + cookies = await self._repository.get(public_id, region=region) + if cookies is None: + await self._cache.delete_public_cookies(public_id, region) + continue + if region == RegionEnum.HYPERION: + client = StarRailClient(cookies=cookies.data, region=Region.CHINESE) + elif region == RegionEnum.HOYOLAB: + client = StarRailClient(cookies=cookies.data, region=Region.OVERSEAS, lang="zh-cn") + else: + raise CookieServiceError + try: + if client.account_id is None: + raise RuntimeError("account_id not found") + record_cards = await client.get_record_cards() + for record_card in record_cards: + if record_card.game == Game.STARRAIL: + await client.get_starrail_user(record_card.uid) + break + else: + accounts = await client.get_game_accounts() + for account in accounts: + if account.game == Game.STARRAIL: + await client.get_starrail_user(account.uid) + break + except InvalidCookies as exc: + if exc.ret_code in (10001, -100): + logger.warning("用户 [%s] Cookies无效", public_id) + elif exc.ret_code == 10103: + logger.warning("用户 [%s] Cookies有效,但没有绑定到游戏帐户", public_id) + else: + logger.warning("Cookies无效 ") + logger.exception(exc) + cookies.status = CookiesStatusEnum.INVALID_COOKIES + await self._repository.update(cookies) + await self._cache.delete_public_cookies(cookies.user_id, region) + continue + except TooManyRequests: + logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id) + cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS + await self._repository.update(cookies) + await self._cache.delete_public_cookies(cookies.user_id, region) + continue + except SimnetBadRequest as exc: + if "invalid content type" in exc.message: + raise exc + if exc.ret_code == 1034: + logger.warning("用户 [%s] 触发验证", public_id) + else: + logger.warning("用户 [%s] 获取账号信息发生错误,错误信息为", public_id) + logger.exception(exc) + await self._cache.delete_public_cookies(cookies.user_id, region) + continue + except RuntimeError as exc: + if "account_id not found" in str(exc): + cookies.status = CookiesStatusEnum.INVALID_COOKIES + await self._repository.update(cookies) + await self._cache.delete_public_cookies(cookies.user_id, region) + continue + raise exc + except Exception as exc: + await self._cache.delete_public_cookies(cookies.user_id, region) + raise exc + finally: + await client.shutdown() + logger.info("用户 user_id[%s] 请求用户 user_id[%s] 的公共Cookies 该Cookies使用次数为%s次 ", user_id, public_id, count) + return cookies + + async def undo(self, user_id: int, cookies: Optional[Cookies] = None, status: Optional[CookiesStatusEnum] = None): + await 
self._cache.incr_by_user_times(user_id, -1) + if cookies is not None and status is not None: + cookies.status = status + await self._repository.update(cookies) + await self._cache.delete_public_cookies(cookies.user_id, cookies.region) + logger.info("用户 user_id[%s] 反馈用户 user_id[%s] 的Cookies状态为 %s", user_id, cookies.user_id, status.name) + else: + logger.info("用户 user_id[%s] 撤销一次公共Cookies计数", user_id) diff --git a/gram_core/services/devices/__init__.py b/gram_core/services/devices/__init__.py new file mode 100644 index 0000000..b822810 --- /dev/null +++ b/gram_core/services/devices/__init__.py @@ -0,0 +1,5 @@ +"""DeviceService""" + +from gram_core.services.devices.services import DevicesService + +__all__ = "DevicesService" diff --git a/gram_core/services/devices/models.py b/gram_core/services/devices/models.py new file mode 100644 index 0000000..a7a9676 --- /dev/null +++ b/gram_core/services/devices/models.py @@ -0,0 +1,23 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field, Column, Integer, BigInteger + +__all__ = ("Devices", "DevicesDataBase") + + +class Devices(SQLModel): + __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") + id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True)) + account_id: int = Field( + default=None, + sa_column=Column( + BigInteger(), + ), + ) + device_id: str = Field() + device_fp: str = Field() + device_name: Optional[str] = Field(default=None) + + +class DevicesDataBase(Devices, table=True): + __tablename__ = "devices" diff --git a/gram_core/services/devices/repositories.py b/gram_core/services/devices/repositories.py new file mode 100644 index 0000000..4774c3e --- /dev/null +++ b/gram_core/services/devices/repositories.py @@ -0,0 +1,41 @@ +from typing import Optional + +from sqlmodel import select + +from gram_core.base_service import BaseService +from gram_core.dependence.database import Database +from gram_core.services.devices.models import DevicesDataBase as Devices +from gram_core.sqlmodel.session import AsyncSession + +__all__ = ("DevicesRepository",) + + +class DevicesRepository(BaseService.Component): + def __init__(self, database: Database): + self.engine = database.engine + + async def get( + self, + account_id: int, + ) -> Optional[Devices]: + async with AsyncSession(self.engine) as session: + statement = select(Devices).where(Devices.account_id == account_id) + results = await session.exec(statement) + return results.first() + + async def add(self, devices: Devices) -> None: + async with AsyncSession(self.engine) as session: + session.add(devices) + await session.commit() + + async def update(self, devices: Devices) -> Devices: + async with AsyncSession(self.engine) as session: + session.add(devices) + await session.commit() + await session.refresh(devices) + return devices + + async def delete(self, devices: Devices) -> None: + async with AsyncSession(self.engine) as session: + await session.delete(devices) + await session.commit() diff --git a/gram_core/services/devices/services.py b/gram_core/services/devices/services.py new file mode 100644 index 0000000..9f993e5 --- /dev/null +++ b/gram_core/services/devices/services.py @@ -0,0 +1,25 @@ +from typing import Optional + +from gram_core.base_service import BaseService +from gram_core.services.devices.repositories import DevicesRepository +from gram_core.services.devices.models import DevicesDataBase as Devices + + +class DevicesService(BaseService): + def __init__(self, devices_repository: 
DevicesRepository) -> None: + self._repository: DevicesRepository = devices_repository + + async def update(self, devices: Devices): + await self._repository.update(devices) + + async def add(self, devices: Devices): + await self._repository.add(devices) + + async def get( + self, + account_id: int, + ) -> Optional[Devices]: + return await self._repository.get(account_id) + + async def delete(self, devices: Devices) -> None: + return await self._repository.delete(devices) diff --git a/gram_core/services/players/__init__.py b/gram_core/services/players/__init__.py new file mode 100644 index 0000000..5cfee96 --- /dev/null +++ b/gram_core/services/players/__init__.py @@ -0,0 +1,3 @@ +from .services import PlayersService + +__all__ = ("PlayersService",) diff --git a/gram_core/services/players/error.py b/gram_core/services/players/error.py new file mode 100644 index 0000000..623bed8 --- /dev/null +++ b/gram_core/services/players/error.py @@ -0,0 +1,2 @@ +class PlayerNotFoundError(Exception): + pass diff --git a/gram_core/services/players/models.py b/gram_core/services/players/models.py new file mode 100644 index 0000000..beb78bb --- /dev/null +++ b/gram_core/services/players/models.py @@ -0,0 +1,96 @@ +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, BaseSettings +from sqlalchemy import TypeDecorator +from sqlmodel import Boolean, Column, Enum, Field, SQLModel, Integer, Index, BigInteger, VARCHAR, func, DateTime + +from gram_core.basemodel import RegionEnum + +try: + import ujson as jsonlib +except ImportError: + import json as jsonlib + +__all__ = ("Player", "PlayersDataBase", "PlayerInfo", "PlayerInfoSQLModel") + + +class Player(SQLModel): + __table_args__ = ( + Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True), + dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"), + ) + id: Optional[int] = Field( + default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) + ) + user_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) + account_id: int = Field(default=None, primary_key=True, sa_column=Column(BigInteger())) + player_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) + region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum))) + is_chosen: Optional[bool] = Field(sa_column=Column(Boolean)) + + +class PlayersDataBase(Player, table=True): + __tablename__ = "players" + + +class ExtraPlayerInfo(BaseModel): + class Config(BaseSettings.Config): + json_loads = jsonlib.loads + json_dumps = jsonlib.dumps + + waifu_id: Optional[int] = None + + +class ExtraPlayerType(TypeDecorator): # pylint: disable=W0223 + impl = VARCHAR(length=521) + + cache_ok = True + + def process_bind_param(self, value, dialect): + """ + :param value: ExtraPlayerInfo | obj | None + :param dialect: + :return: + """ + if value is not None: + if isinstance(value, ExtraPlayerInfo): + return value.json() + raise TypeError + return value + + def process_result_value(self, value, dialect): + """ + :param value: str | obj | None + :param dialect: + :return: + """ + if value is not None: + return ExtraPlayerInfo.parse_raw(value) + return None + + +class PlayerInfo(SQLModel): + __table_args__ = ( + Index("index_user_account_player", "user_id", "player_id", unique=True), + dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"), + ) + id: Optional[int] = Field( + default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) + ) + 
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) + player_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) + nickname: Optional[str] = Field() + signature: Optional[str] = Field() + hand_image: Optional[int] = Field() + name_card: Optional[int] = Field() + extra_data: Optional[ExtraPlayerInfo] = Field(sa_column=Column(ExtraPlayerType)) + create_time: Optional[datetime] = Field( + sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102 + ) + last_save_time: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102 + is_update: Optional[bool] = Field(sa_column=Column(Boolean)) + + +class PlayerInfoSQLModel(PlayerInfo, table=True): + __tablename__ = "players_info" diff --git a/gram_core/services/players/repositories.py b/gram_core/services/players/repositories.py new file mode 100644 index 0000000..c16de3f --- /dev/null +++ b/gram_core/services/players/repositories.py @@ -0,0 +1,110 @@ +from typing import List, Optional + +from sqlmodel import select, delete + +from gram_core.base_service import BaseService +from gram_core.basemodel import RegionEnum +from gram_core.dependence.database import Database +from gram_core.services.players.models import PlayerInfoSQLModel +from gram_core.services.players.models import PlayersDataBase as Player +from gram_core.sqlmodel.session import AsyncSession + +__all__ = ("PlayersRepository", "PlayerInfoRepository") + + +class PlayersRepository(BaseService.Component): + def __init__(self, database: Database): + self.engine = database.engine + + async def get( + self, + user_id: int, + player_id: Optional[int] = None, + account_id: Optional[int] = None, + region: Optional[RegionEnum] = None, + is_chosen: Optional[bool] = None, + ) -> Optional[Player]: + async with AsyncSession(self.engine) as session: + statement = select(Player).where(Player.user_id == user_id) + if player_id is not None: + statement = statement.where(Player.player_id == player_id) + if account_id is not None: + statement = statement.where(Player.account_id == account_id) + if region is not None: + statement = statement.where(Player.region == region) + if is_chosen is not None: + statement = statement.where(Player.is_chosen == is_chosen) + results = await session.exec(statement) + return results.first() + + async def add(self, player: Player) -> None: + async with AsyncSession(self.engine) as session: + session.add(player) + await session.commit() + await session.refresh(player) + + async def delete(self, player: Player) -> None: + async with AsyncSession(self.engine) as session: + await session.delete(player) + await session.commit() + + async def update(self, player: Player) -> None: + async with AsyncSession(self.engine) as session: + session.add(player) + await session.commit() + await session.refresh(player) + + async def get_all_by_user_id(self, user_id: int) -> List[Player]: + async with AsyncSession(self.engine) as session: + statement = select(Player).where(Player.user_id == user_id) + results = await session.exec(statement) + players = results.all() + return players + + +class PlayerInfoRepository(BaseService.Component): + def __init__(self, database: Database): + self.engine = database.engine + + async def get( + self, + user_id: int, + player_id: int, + ) -> Optional[PlayerInfoSQLModel]: + async with AsyncSession(self.engine) as session: + statement = ( + select(PlayerInfoSQLModel) + .where(PlayerInfoSQLModel.player_id == player_id) + .where(PlayerInfoSQLModel.user_id == user_id) + 
) + results = await session.exec(statement) + return results.first() + + async def add(self, player: PlayerInfoSQLModel) -> None: + async with AsyncSession(self.engine) as session: + session.add(player) + await session.commit() + + async def delete(self, player: PlayerInfoSQLModel) -> None: + async with AsyncSession(self.engine) as session: + await session.delete(player) + await session.commit() + + async def delete_by_id( + self, + user_id: int, + player_id: int, + ) -> None: + async with AsyncSession(self.engine) as session: + statement = ( + delete(PlayerInfoSQLModel) + .where(PlayerInfoSQLModel.player_id == player_id) + .where(PlayerInfoSQLModel.user_id == user_id) + ) + await session.execute(statement) + + async def update(self, player: PlayerInfoSQLModel) -> None: + async with AsyncSession(self.engine) as session: + session.add(player) + await session.commit() + await session.refresh(player) diff --git a/gram_core/services/players/services.py b/gram_core/services/players/services.py new file mode 100644 index 0000000..d73d89a --- /dev/null +++ b/gram_core/services/players/services.py @@ -0,0 +1,43 @@ +from typing import List, Optional + +from gram_core.base_service import BaseService +from gram_core.basemodel import RegionEnum +from gram_core.services.players.models import PlayersDataBase as Player +from gram_core.services.players.repositories import PlayersRepository + +__all__ = "PlayersService" + + +class PlayersService(BaseService): + def __init__(self, players_repository: PlayersRepository) -> None: + self._repository = players_repository + + async def get( + self, + user_id: int, + player_id: Optional[int] = None, + account_id: Optional[int] = None, + region: Optional[RegionEnum] = None, + is_chosen: Optional[bool] = None, + ) -> Optional[Player]: + return await self._repository.get(user_id, player_id, account_id, region, is_chosen) + + async def get_player(self, user_id: int, region: Optional[RegionEnum] = None) -> Optional[Player]: + return await self._repository.get(user_id, region=region, is_chosen=True) + + async def add(self, player: Player) -> None: + await self._repository.add(player) + + async def update(self, player: Player) -> None: + await self._repository.update(player) + + async def get_all_by_user_id(self, user_id: int) -> List[Player]: + return await self._repository.get_all_by_user_id(user_id) + + async def remove_all_by_user_id(self, user_id: int): + players = await self._repository.get_all_by_user_id(user_id) + for player in players: + await self._repository.delete(player) + + async def delete(self, player: Player): + await self._repository.delete(player) diff --git a/gram_core/services/task/__init__.py b/gram_core/services/task/__init__.py new file mode 100644 index 0000000..0835fdf --- /dev/null +++ b/gram_core/services/task/__init__.py @@ -0,0 +1 @@ +"""TaskService""" diff --git a/gram_core/services/task/models.py b/gram_core/services/task/models.py new file mode 100644 index 0000000..d8a3cd0 --- /dev/null +++ b/gram_core/services/task/models.py @@ -0,0 +1,44 @@ +import enum +from datetime import datetime +from typing import Optional, Dict, Any + +from sqlalchemy import func, BigInteger, JSON +from sqlmodel import Column, DateTime, Enum, Field, SQLModel, Integer + +__all__ = ("Task", "TaskStatusEnum", "TaskTypeEnum") + + +class TaskStatusEnum(int, enum.Enum): + STATUS_SUCCESS = 0 # 任务执行成功 + INVALID_COOKIES = 1 # Cookie无效 + ALREADY_CLAIMED = 2 # 已经获取奖励 + NEED_CHALLENGE = 3 # 需要验证码 + GENSHIN_EXCEPTION = 4 # API异常 + TIMEOUT_ERROR = 5 # 请求超时 + BAD_REQUEST = 6 
# 请求失败 + FORBIDDEN = 7 # 这错误一般为通知失败 机器人被用户BAN + + +class TaskTypeEnum(int, enum.Enum): + SIGN = 0 # 签到 + RESIN = 1 # 开拓力 + REALM = 2 # 洞天宝钱 + EXPEDITION = 3 # 委托 + TRANSFORMER = 4 # 参量质变仪 + CARD = 5 # 生日画片 + + +class Task(SQLModel, table=True): + __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") + id: Optional[int] = Field( + default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) + ) + user_id: int = Field(primary_key=True, sa_column=Column(BigInteger(), index=True)) + chat_id: Optional[int] = Field(default=None, sa_column=Column(BigInteger())) + time_created: Optional[datetime] = Field( + sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102 + ) + time_updated: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102 + type: TaskTypeEnum = Field(primary_key=True, sa_column=Column(Enum(TaskTypeEnum))) + status: Optional[TaskStatusEnum] = Field(sa_column=Column(Enum(TaskStatusEnum))) + data: Optional[Dict[str, Any]] = Field(sa_column=Column(JSON)) diff --git a/gram_core/services/task/repositories.py b/gram_core/services/task/repositories.py new file mode 100644 index 0000000..b45c3a8 --- /dev/null +++ b/gram_core/services/task/repositories.py @@ -0,0 +1,50 @@ +from typing import List, Optional + +from sqlmodel import select + +from gram_core.base_service import BaseService +from gram_core.dependence.database import Database +from gram_core.services.task.models import Task, TaskTypeEnum +from gram_core.sqlmodel.session import AsyncSession + +__all__ = ("TaskRepository",) + + +class TaskRepository(BaseService.Component): + def __init__(self, database: Database): + self.engine = database.engine + + async def add(self, task: Task): + async with AsyncSession(self.engine) as session: + session.add(task) + await session.commit() + + async def remove(self, task: Task): + async with AsyncSession(self.engine) as session: + await session.delete(task) + await session.commit() + + async def update(self, task: Task) -> Task: + async with AsyncSession(self.engine) as session: + session.add(task) + await session.commit() + await session.refresh(task) + return task + + async def get_by_user_id(self, user_id: int, task_type: TaskTypeEnum) -> Optional[Task]: + async with AsyncSession(self.engine) as session: + statement = select(Task).where(Task.user_id == user_id).where(Task.type == task_type) + results = await session.exec(statement) + return results.first() + + async def get_by_chat_id(self, chat_id: int, task_type: TaskTypeEnum) -> Optional[List[Task]]: + async with AsyncSession(self.engine) as session: + statement = select(Task).where(Task.chat_id == chat_id).where(Task.type == task_type) + results = await session.exec(statement) + return results.all() + + async def get_all(self, task_type: TaskTypeEnum) -> List[Task]: + async with AsyncSession(self.engine) as session: + query = select(Task).where(Task.type == task_type) + results = await session.exec(query) + return results.all() diff --git a/gram_core/services/task/services.py b/gram_core/services/task/services.py new file mode 100644 index 0000000..6f42a31 --- /dev/null +++ b/gram_core/services/task/services.py @@ -0,0 +1,63 @@ +import datetime +from typing import Optional, Dict, Any + +from gram_core.base_service import BaseService +from gram_core.services.task.models import Task, TaskTypeEnum +from gram_core.services.task.repositories import TaskRepository + +__all__ = [ + "TaskServices", + "SignServices", + 
"TaskCardServices", + "TaskResinServices", + "TaskExpeditionServices", +] + + +class TaskServices: + TASK_TYPE: TaskTypeEnum + + def __init__(self, task_repository: TaskRepository) -> None: + self._repository: TaskRepository = task_repository + + async def add(self, task: Task): + return await self._repository.add(task) + + async def remove(self, task: Task): + return await self._repository.remove(task) + + async def update(self, task: Task): + task.time_updated = datetime.datetime.now() + return await self._repository.update(task) + + async def get_by_user_id(self, user_id: int): + return await self._repository.get_by_user_id(user_id, self.TASK_TYPE) + + async def get_all(self): + return await self._repository.get_all(self.TASK_TYPE) + + def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None): + return Task( + user_id=user_id, + chat_id=chat_id, + time_created=datetime.datetime.now(), + status=status, + type=self.TASK_TYPE, + data=data, + ) + + +class SignServices(BaseService, TaskServices): + TASK_TYPE = TaskTypeEnum.SIGN + + +class TaskCardServices(BaseService, TaskServices): + TASK_TYPE = TaskTypeEnum.CARD + + +class TaskResinServices(BaseService, TaskServices): + TASK_TYPE = TaskTypeEnum.RESIN + + +class TaskExpeditionServices(BaseService, TaskServices): + TASK_TYPE = TaskTypeEnum.EXPEDITION diff --git a/gram_core/services/template/README.md b/gram_core/services/template/README.md new file mode 100644 index 0000000..b8a9bb2 --- /dev/null +++ b/gram_core/services/template/README.md @@ -0,0 +1,11 @@ +# TemplateService + +使用 jinja2 渲染 html 为图片的服务。 + +## 预览模板 + +为了方便调试 html,在开发环境中,我们会启动 web server 用于预览模板。(可以在 .env 里调整端口等参数,参数均为 `web_` 开头) + +在派蒙收到指令开始渲染某个模板的时候,控制台会输出一个预览链接,类似 `http://localhost:8080/preview/genshin/stats/stats.html?id=45f7d86a-058e-4f64-bdeb-42903d8415b2`,有效时间 8 小时。 + +如果是无需数据的模板,永久有效,比如 `http://localhost:8080/preview/bot/help/help.html` diff --git a/gram_core/services/template/__init__.py b/gram_core/services/template/__init__.py new file mode 100644 index 0000000..79551dd --- /dev/null +++ b/gram_core/services/template/__init__.py @@ -0,0 +1 @@ +"""TemplateService""" diff --git a/gram_core/services/template/cache.py b/gram_core/services/template/cache.py new file mode 100644 index 0000000..c891b35 --- /dev/null +++ b/gram_core/services/template/cache.py @@ -0,0 +1,58 @@ +import gzip +import pickle # nosec B403 +from hashlib import sha256 +from typing import Any, Optional + +from gram_core.base_service import BaseService +from gram_core.dependence.redisdb import RedisDB + +__all__ = ["TemplatePreviewCache", "HtmlToFileIdCache"] + + +class TemplatePreviewCache(BaseService.Component): + """暂存渲染模板的数据用于预览""" + + def __init__(self, redis: RedisDB): + self.client = redis.client + self.qname = "bot:template:preview" + + async def get_data(self, key: str) -> Any: + data = await self.client.get(self.cache_key(key)) + if data: + # skipcq: BAN-B301 + return pickle.loads(gzip.decompress(data)) # nosec B301 + + async def set_data(self, key: str, data: Any, ttl: int = 8 * 60 * 60): + ck = self.cache_key(key) + await self.client.set(ck, gzip.compress(pickle.dumps(data))) + if ttl != -1: + await self.client.expire(ck, ttl) + + def cache_key(self, key: str) -> str: + return f"{self.qname}:{key}" + + +class HtmlToFileIdCache(BaseService.Component): + """html to file_id 的缓存""" + + def __init__(self, redis: RedisDB): + self.client = redis.client + self.qname = "bot:template:html-to-file-id" + + async def get_data(self, html: str, file_type: str) -> 
Optional[str]: + data = await self.client.get(self.cache_key(html, file_type)) + if data: + return data.decode() + + async def set_data(self, html: str, file_type: str, file_id: str, ttl: int = 24 * 60 * 60): + ck = self.cache_key(html, file_type) + await self.client.set(ck, file_id) + if ttl != -1: + await self.client.expire(ck, ttl) + + async def delete_data(self, html: str, file_type: str) -> bool: + return await self.client.delete(self.cache_key(html, file_type)) + + def cache_key(self, html: str, file_type: str) -> str: + key = sha256(html.encode()).hexdigest() + return f"{self.qname}:{file_type}:{key}" diff --git a/gram_core/services/template/error.py b/gram_core/services/template/error.py new file mode 100644 index 0000000..197e06c --- /dev/null +++ b/gram_core/services/template/error.py @@ -0,0 +1,14 @@ +class TemplateException(Exception): + pass + + +class QuerySelectorNotFound(TemplateException): + pass + + +class ErrorFileType(TemplateException): + pass + + +class FileIdNotFound(TemplateException): + pass diff --git a/gram_core/services/template/models.py b/gram_core/services/template/models.py new file mode 100644 index 0000000..2b5c9cd --- /dev/null +++ b/gram_core/services/template/models.py @@ -0,0 +1,146 @@ +from enum import Enum +from typing import List, Optional, Union + +from telegram import InputMediaDocument, InputMediaPhoto, Message +from telegram.error import BadRequest + +from gram_core.services.template.cache import HtmlToFileIdCache +from gram_core.services.template.error import ErrorFileType, FileIdNotFound + +__all__ = ["FileType", "RenderResult", "RenderGroupResult"] + + +class FileType(Enum): + PHOTO = 1 + DOCUMENT = 2 + + @staticmethod + def media_type(file_type: "FileType"): + """对应的 Telegram media 类型""" + if file_type == FileType.PHOTO: + return InputMediaPhoto + if file_type == FileType.DOCUMENT: + return InputMediaDocument + raise ErrorFileType + + +class RenderResult: + """渲染结果""" + + def __init__( + self, + html: str, + photo: Union[bytes, str], + file_type: FileType, + cache: HtmlToFileIdCache, + ttl: int = 24 * 60 * 60, + caption: Optional[str] = None, + parse_mode: Optional[str] = None, + filename: Optional[str] = None, + ): + """ + `html`: str 渲染生成的 html + `photo`: Union[bytes, str] 渲染生成的图片。bytes 表示是图片,str 则为 file_id + """ + self.caption = caption + self.parse_mode = parse_mode + self.filename = filename + self.html = html + self.photo = photo + self.file_type = file_type + self._cache = cache + self.ttl = ttl + + async def reply_photo(self, message: Message, *args, **kwargs): + """是 `message.reply_photo` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用""" + if self.file_type != FileType.PHOTO: + raise ErrorFileType + + try: + reply = await message.reply_photo(photo=self.photo, *args, **kwargs) + except BadRequest as exc: + if "Wrong file identifier" in exc.message and isinstance(self.photo, str): + await self._cache.delete_data(self.html, self.file_type.name) + raise BadRequest(message="Wrong file identifier specified") + raise exc + + await self.cache_file_id(reply) + + return reply + + async def reply_document(self, message: Message, *args, **kwargs): + """是 `message.reply_document` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用""" + if self.file_type != FileType.DOCUMENT: + raise ErrorFileType + + try: + reply = await message.reply_document(document=self.photo, *args, **kwargs) + except BadRequest as exc: + if "Wrong file identifier" in exc.message and isinstance(self.photo, str): + await self._cache.delete_data(self.html, self.file_type.name) + raise 
BadRequest(message="Wrong file identifier specified") + raise exc + + await self.cache_file_id(reply) + + return reply + + async def edit_media(self, message: Message, *args, **kwargs): + """是 `message.edit_media` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用""" + if self.file_type != FileType.PHOTO: + raise ErrorFileType + + media = InputMediaPhoto( + media=self.photo, caption=self.caption, parse_mode=self.parse_mode, filename=self.filename + ) + + try: + edit_media = await message.edit_media(media, *args, **kwargs) + except BadRequest as exc: + if "Wrong file identifier" in exc.message and isinstance(self.photo, str): + await self._cache.delete_data(self.html, self.file_type.name) + raise BadRequest(message="Wrong file identifier specified") + raise exc + + await self.cache_file_id(edit_media) + + return edit_media + + async def cache_file_id(self, reply: Message): + """缓存 telegram 返回的 file_id""" + if self.is_file_id(): + return + + if self.file_type == FileType.PHOTO and reply.photo: + file_id = reply.photo[0].file_id + elif self.file_type == FileType.DOCUMENT and reply.document: + file_id = reply.document.file_id + else: + raise FileIdNotFound + await self._cache.set_data(self.html, self.file_type.name, file_id, self.ttl) + + def is_file_id(self) -> bool: + return isinstance(self.photo, str) + + +class RenderGroupResult: + def __init__(self, results: List[RenderResult]): + self.results = results + + async def reply_media_group(self, message: Message, *args, **kwargs): + """是 `message.reply_media_group` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用""" + + reply = await message.reply_media_group( + media=[ + FileType.media_type(result.file_type)( + media=result.photo, caption=result.caption, parse_mode=result.parse_mode, filename=result.filename + ) + for result in self.results + ], + *args, + **kwargs, + ) + + for index, value in enumerate(reply): + result = self.results[index] + await result.cache_file_id(value) diff --git a/gram_core/services/template/services.py b/gram_core/services/template/services.py new file mode 100644 index 0000000..2cf3273 --- /dev/null +++ b/gram_core/services/template/services.py @@ -0,0 +1,207 @@ +import asyncio +from typing import Optional +from urllib.parse import urlencode, urljoin, urlsplit +from uuid import uuid4 + +from fastapi import FastAPI, HTTPException +from fastapi.responses import FileResponse, HTMLResponse +from fastapi.staticfiles import StaticFiles +from jinja2 import Environment, FileSystemLoader, Template +from playwright.async_api import ViewportSize + +from gram_core.application import Application +from gram_core.base_service import BaseService +from gram_core.config import config as application_config +from gram_core.dependence.aiobrowser import AioBrowser +from gram_core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache +from gram_core.services.template.error import QuerySelectorNotFound +from gram_core.services.template.models import FileType, RenderResult +from utils.const import PROJECT_ROOT +from utils.log import logger + +__all__ = ("TemplateService", "TemplatePreviewer") + + +class TemplateService(BaseService): + def __init__( + self, + app: Application, + browser: AioBrowser, + html_to_file_id_cache: HtmlToFileIdCache, + preview_cache: TemplatePreviewCache, + template_dir: str = "resources", + ): + self._browser = browser + self.template_dir = PROJECT_ROOT / template_dir + + self._jinja2_env = Environment( + loader=FileSystemLoader(template_dir), + enable_async=True, + autoescape=True, + 
auto_reload=application_config.debug, + ) + self.using_preview = application_config.debug and application_config.webserver.enable + + if self.using_preview: + self.previewer = TemplatePreviewer(self, preview_cache, app.web_app) + + self.html_to_file_id_cache = html_to_file_id_cache + + def get_template(self, template_name: str) -> Template: + return self._jinja2_env.get_template(template_name) + + async def render_async(self, template_name: str, template_data: dict) -> str: + """模板渲染 + :param template_name: 模板文件名 + :param template_data: 模板数据 + """ + loop = asyncio.get_event_loop() + start_time = loop.time() + template = self.get_template(template_name) + html = await template.render_async(**template_data) + logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time)) + return html + + async def render( + self, + template_name: str, + template_data: dict, + viewport: Optional[ViewportSize] = None, + full_page: bool = True, + evaluate: Optional[str] = None, + query_selector: Optional[str] = None, + file_type: FileType = FileType.PHOTO, + ttl: int = 24 * 60 * 60, + caption: Optional[str] = None, + parse_mode: Optional[str] = None, + filename: Optional[str] = None, + ) -> RenderResult: + """模板渲染成图片 + :param template_name: 模板文件名 + :param template_data: 模板数据 + :param viewport: 截图大小 + :param full_page: 是否长截图 + :param evaluate: 页面加载后运行的 js + :param query_selector: 截图选择器 + :param file_type: 缓存的文件类型 + :param ttl: 缓存时间 + :param caption: 图片描述 + :param parse_mode: 图片描述解析模式 + :param filename: 文件名字 + :return: + """ + loop = asyncio.get_event_loop() + start_time = loop.time() + template = self.get_template(template_name) + + if self.using_preview: + preview_url = await self.previewer.get_preview_url(template_name, template_data) + logger.debug("调试模板 URL: \n%s", preview_url) + + html = await template.render_async(**template_data) + logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time)) + + file_id = await self.html_to_file_id_cache.get_data(html, file_type.name) + if file_id and not application_config.debug: + logger.debug("%s 命中缓存,返回 file_id[%s]", template_name, file_id) + return RenderResult( + html=html, + photo=file_id, + file_type=file_type, + cache=self.html_to_file_id_cache, + ttl=ttl, + caption=caption, + parse_mode=parse_mode, + filename=filename, + ) + + browser = await self._browser.get_browser() + start_time = loop.time() + page = await browser.new_page(viewport=viewport) + uri = (PROJECT_ROOT / template.filename).as_uri() + await page.goto(uri) + await page.set_content(html, wait_until="networkidle") + if evaluate: + await page.evaluate(evaluate) + clip = None + if query_selector: + try: + card = await page.query_selector(query_selector) + if not card: + raise QuerySelectorNotFound + clip = await card.bounding_box() + if not clip: + raise QuerySelectorNotFound + except QuerySelectorNotFound: + logger.warning("未找到 %s 元素", query_selector) + png_data = await page.screenshot(clip=clip, full_page=full_page) + await page.close() + logger.debug("%s 图片渲染使用了 %s", template_name, str(loop.time() - start_time)) + return RenderResult( + html=html, + photo=png_data, + file_type=file_type, + cache=self.html_to_file_id_cache, + ttl=ttl, + caption=caption, + parse_mode=parse_mode, + filename=filename, + ) + + +class TemplatePreviewer(BaseService, load=application_config.webserver.enable and application_config.debug): + def __init__( + self, + template_service: TemplateService, + cache: TemplatePreviewCache, + web_app: FastAPI, + ): + self.web_app = web_app + 
self.template_service = template_service + self.cache = cache + self.register_routes() + + async def get_preview_url(self, template: str, data: dict): + """获取预览 URL""" + components = urlsplit(application_config.webserver.url) + path = urljoin("/preview/", template) + query = {} + + # 如果有数据,暂存在 redis 中 + if data: + key = str(uuid4()) + await self.cache.set_data(key, data) + query["key"] = key + + # noinspection PyProtectedMember + return components._replace(path=path, query=urlencode(query)).geturl() + + def register_routes(self): + """注册预览用到的路由""" + + @self.web_app.get("/preview/{path:path}") + async def preview_template(path: str, key: Optional[str] = None): # pylint: disable=W0612 + # 如果是 /preview/ 开头的静态文件,直接返回内容。比如使用相对链接 ../ 引入的静态资源 + if not path.endswith((".html", ".jinja2")): + full_path = self.template_service.template_dir / path + if not full_path.is_file(): + raise HTTPException(status_code=404, detail=f"Template '{path}' not found") + return FileResponse(full_path) + + # 取回暂存的渲染数据 + data = await self.cache.get_data(key) if key else {} + if key and data is None: + raise HTTPException(status_code=404, detail=f"Template data {key} not found") + + # 渲染 jinja2 模板 + html = await self.template_service.render_async(path, data) + # 将本地 URL file:// 修改为 HTTP url,因为浏览器内不允许加载本地文件 + # file:///project_dir/cache/image.jpg => /cache/image.jpg + html = html.replace(PROJECT_ROOT.as_uri(), "") + return HTMLResponse(html) + + # 其他静态资源 + for name in ["cache", "resources"]: + directory = PROJECT_ROOT / name + directory.mkdir(exist_ok=True) + self.web_app.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name) diff --git a/gram_core/services/users/__init__.py b/gram_core/services/users/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gram_core/services/users/cache.py b/gram_core/services/users/cache.py new file mode 100644 index 0000000..3d7eda2 --- /dev/null +++ b/gram_core/services/users/cache.py @@ -0,0 +1,24 @@ +from typing import List + +from gram_core.base_service import BaseService +from gram_core.dependence.redisdb import RedisDB + +__all__ = ("UserAdminCache",) + + +class UserAdminCache(BaseService.Component): + def __init__(self, redis: RedisDB): + self.client = redis.client + self.qname = "users:admin" + + async def ismember(self, user_id: int) -> bool: + return await self.client.sismember(self.qname, user_id) + + async def get_all(self) -> List[int]: + return [int(str_data) for str_data in await self.client.smembers(self.qname)] + + async def set(self, user_id: int) -> bool: + return await self.client.sadd(self.qname, user_id) + + async def remove(self, user_id: int) -> bool: + return await self.client.srem(self.qname, user_id) diff --git a/gram_core/services/users/models.py b/gram_core/services/users/models.py new file mode 100644 index 0000000..5d5fa1c --- /dev/null +++ b/gram_core/services/users/models.py @@ -0,0 +1,34 @@ +import enum +from datetime import datetime +from typing import Optional + +from sqlmodel import SQLModel, Field, DateTime, Column, Enum, BigInteger, Integer + +__all__ = ( + "User", + "UserDataBase", + "PermissionsEnum", +) + + +class PermissionsEnum(int, enum.Enum): + OWNER = 1 + ADMIN = 2 + PUBLIC = 3 + + +class User(SQLModel): + __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") + id: Optional[int] = Field( + default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) + ) + user_id: int = Field(unique=True, sa_column=Column(BigInteger())) + permissions: 
Optional[PermissionsEnum] = Field(sa_column=Column(Enum(PermissionsEnum))) + locale: Optional[str] = Field() + ban_end_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True))) + ban_start_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True))) + is_banned: Optional[int] = Field() + + +class UserDataBase(User, table=True): + __tablename__ = "users" diff --git a/gram_core/services/users/repositories.py b/gram_core/services/users/repositories.py new file mode 100644 index 0000000..a52d0bf --- /dev/null +++ b/gram_core/services/users/repositories.py @@ -0,0 +1,44 @@ +from typing import Optional, List + +from sqlmodel import select + +from gram_core.base_service import BaseService +from gram_core.dependence.database import Database +from gram_core.services.users.models import UserDataBase as User +from gram_core.sqlmodel.session import AsyncSession + +__all__ = ("UserRepository",) + + +class UserRepository(BaseService.Component): + def __init__(self, database: Database): + self.engine = database.engine + + async def get_by_user_id(self, user_id: int) -> Optional[User]: + async with AsyncSession(self.engine) as session: + statement = select(User).where(User.user_id == user_id) + results = await session.exec(statement) + return results.first() + + async def add(self, user: User): + async with AsyncSession(self.engine) as session: + session.add(user) + await session.commit() + + async def update(self, user: User) -> User: + async with AsyncSession(self.engine) as session: + session.add(user) + await session.commit() + await session.refresh(user) + return user + + async def remove(self, user: User): + async with AsyncSession(self.engine) as session: + await session.delete(user) + await session.commit() + + async def get_all(self) -> List[User]: + async with AsyncSession(self.engine) as session: + statement = select(User) + results = await session.exec(statement) + return results.all() diff --git a/gram_core/services/users/services.py b/gram_core/services/users/services.py new file mode 100644 index 0000000..a593f7c --- /dev/null +++ b/gram_core/services/users/services.py @@ -0,0 +1,83 @@ +from typing import List, Optional + +from gram_core.base_service import BaseService +from gram_core.config import config +from gram_core.services.users.cache import UserAdminCache +from gram_core.services.users.models import PermissionsEnum, UserDataBase as User +from gram_core.services.users.repositories import UserRepository + +__all__ = ("UserService", "UserAdminService") + +from utils.log import logger + + +class UserService(BaseService): + def __init__(self, user_repository: UserRepository) -> None: + self._repository: UserRepository = user_repository + + async def get_user_by_id(self, user_id: int) -> Optional[User]: + """从数据库获取用户信息 + :param user_id:用户ID + :return: User + """ + return await self._repository.get_by_user_id(user_id) + + async def remove(self, user: User): + return await self._repository.remove(user) + + async def update_user(self, user: User): + return await self._repository.add(user) + + +class UserAdminService(BaseService): + def __init__(self, user_repository: UserRepository, cache: UserAdminCache): + self.user_repository = user_repository + self._cache = cache + + async def initialize(self): + owner = config.owner + if owner: + user = await self.user_repository.get_by_user_id(owner) + if user: + if user.permissions != PermissionsEnum.OWNER: + user.permissions = PermissionsEnum.OWNER + await self._cache.set(user.user_id) + await 
self.user_repository.update(user) + else: + user = User(user_id=owner, permissions=PermissionsEnum.OWNER) + await self._cache.set(user.user_id) + await self.user_repository.add(user) + else: + logger.warning("检测到未配置Bot所有者 会导无法正常使用管理员权限") + users = await self.user_repository.get_all() + for user in users: + await self._cache.set(user.user_id) + + async def is_admin(self, user_id: int) -> bool: + return await self._cache.ismember(user_id) + + async def get_admin_list(self) -> List[int]: + return await self._cache.get_all() + + async def add_admin(self, user_id: int) -> bool: + user = await self.user_repository.get_by_user_id(user_id) + if user: + if user.permissions == PermissionsEnum.OWNER: + return False + if user.permissions != PermissionsEnum.ADMIN: + user.permissions = PermissionsEnum.ADMIN + await self.user_repository.update(user) + else: + user = User(user_id=user_id, permissions=PermissionsEnum.ADMIN) + await self.user_repository.add(user) + return await self._cache.set(user_id) + + async def delete_admin(self, user_id: int) -> bool: + user = await self.user_repository.get_by_user_id(user_id) + if user: + if user.permissions == PermissionsEnum.OWNER: + return True # 假装移除成功 + user.permissions = PermissionsEnum.PUBLIC + await self.user_repository.update(user) + return await self._cache.remove(user.user_id) + return False diff --git a/gram_core/sqlmodel/__init__.py b/gram_core/sqlmodel/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gram_core/sqlmodel/session.py b/gram_core/sqlmodel/session.py new file mode 100644 index 0000000..88e4d3d --- /dev/null +++ b/gram_core/sqlmodel/session.py @@ -0,0 +1,118 @@ +from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload + +from sqlalchemy import util +from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession +from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine +from sqlalchemy.sql.base import Executable as _Executable +from sqlmodel.engine.result import Result, ScalarResult +from sqlmodel.orm.session import Session +from sqlmodel.sql.base import Executable +from sqlmodel.sql.expression import Select, SelectOfScalar +from typing_extensions import Literal + +_TSelectParam = TypeVar("_TSelectParam") + +__all__ = ("AsyncSession",) + + +class AsyncSession(_AsyncSession): # pylint: disable=W0223 + sync_session_class = Session + sync_session: Session + + def __init__( + self, + bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, + binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, + sync_session_class: Type[Session] = Session, + **kw: Any, + ): + super().__init__( + bind=bind, + binds=binds, + sync_session_class=sync_session_class, + **kw, + ) + + @overload + async def exec( + self, + statement: Select[_TSelectParam], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + **kw: Any, + ) -> Result[_TSelectParam]: + ... + + @overload + async def exec( + self, + statement: SelectOfScalar[_TSelectParam], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + **kw: Any, + ) -> ScalarResult[_TSelectParam]: + ... 
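+        # The two overloads above give type checkers a precise return type for
+        # exec(): a plain Select resolves to Result, while SelectOfScalar
+        # (typically what sqlmodel's select() builds for single-model queries)
+        # resolves to ScalarResult, matching the scalars() unwrapping done in
+        # the implementation below.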
+ + async def exec( + self, + statement: Union[ + Select[_TSelectParam], + SelectOfScalar[_TSelectParam], + Executable[_TSelectParam], + ], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + **kw: Any, + ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]: + results = super().execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + if isinstance(statement, SelectOfScalar): + return (await results).scalars() # type: ignore + return await results # type: ignore + + async def execute( # pylint: disable=W0221 + self, + statement: _Executable, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + **kw: Any, + ) -> Result[Any]: + return await super().execute( # type: ignore + statement=statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + async def get( # pylint: disable=W0221 + self, + entity: Type[_TSelectParam], + ident: Any, + options: Optional[Sequence[Any]] = None, + populate_existing: bool = False, + with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, + identity_token: Optional[Any] = None, + execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT, + ) -> Optional[_TSelectParam]: + return await super().get( + entity=entity, + ident=ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) diff --git a/gram_core/version.py b/gram_core/version.py new file mode 100644 index 0000000..3dc1f76 --- /dev/null +++ b/gram_core/version.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..65cc592 --- /dev/null +++ b/setup.py @@ -0,0 +1,34 @@ +"""Run setuptools.""" + +from setuptools import find_packages, setup + +from gram_core.version import __version__ + + +def get_setup_kwargs(): + """Builds a dictionary of kwargs for the setup function""" + kwargs = dict( + script_name="setup.py", + name="gram_core", + version=__version__, + author="PaiGramTeam", + url="https://github.com/PaiGramTeam/GramCore", + keywords="telegram robot base core", + description="telegram robot base core.", + long_description=open("README.md", "r", encoding="utf-8").read(), + long_description_content_type="text/markdown", + packages=find_packages(exclude=["tests*"]), + install_requires=[], + include_package_data=True, + python_requires=">=3.8", + ) + + return kwargs + + +def main(): # skipcq: PY-D0003 + setup(**get_setup_kwargs()) + + +if __name__ == "__main__": + main()
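
The services introduced above all share one shape: a BaseService.Component repository holds the async engine and opens an AsyncSession per call, while a thin BaseService wrapper delegates to it. The following minimal sketch shows how that repository layer could be exercised on its own, assuming only that gram_core, sqlmodel and aiosqlite are installed; the SQLite URL, the SimpleNamespace stand-in for the Database dependency and the demo values are illustrative placeholders, not anything defined by gram_core.

import asyncio
from types import SimpleNamespace

from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel

from gram_core.services.devices.models import DevicesDataBase as Devices
from gram_core.services.devices.repositories import DevicesRepository


async def demo() -> None:
    # Local SQLite file instead of the MySQL setup the models are tuned for.
    engine = create_async_engine("sqlite+aiosqlite:///./devices_demo.db")
    async with engine.begin() as conn:
        # Create whatever tables are registered on SQLModel.metadata (here: devices).
        await conn.run_sync(SQLModel.metadata.create_all)

    # DevicesRepository only reads `database.engine`, so a duck-typed stand-in
    # for the Database dependency is enough for a standalone run.
    repository = DevicesRepository(SimpleNamespace(engine=engine))

    await repository.add(Devices(account_id=123456789, device_id="demo-device-id", device_fp="demo-fp"))
    device = await repository.get(account_id=123456789)
    print(device.device_fp if device else "not found")


if __name__ == "__main__":
    asyncio.run(demo())

The service classes (for example DevicesService) add no behaviour beyond delegating to the repository, so the same calls work one level up once the BaseService wiring provided by the application is in place.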