From b1141c65b8df51bf0cbcecb8d9cd0b3f4ab49616 Mon Sep 17 00:00:00 2001 From: omg-xtao <100690902+omg-xtao@users.noreply.github.com> Date: Wed, 2 Aug 2023 20:10:08 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20separate=20core=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com> --- core/application.py | 286 +----------------- core/base_service.py | 59 +--- core/basemodel.py | 28 +- core/builtins/__init__.py | 1 - core/builtins/contexts.py | 38 --- core/builtins/dispatcher.py | 309 ------------------- core/builtins/executor.py | 131 -------- core/builtins/reloader.py | 185 ------------ core/config.py | 160 +--------- core/dependence/aiobrowser.py | 55 +--- core/dependence/aiobrowser.pyi | 16 - core/dependence/database.py | 50 +--- core/dependence/mtproto.py | 68 +---- core/dependence/mtproto.pyi | 31 -- core/dependence/redisdb.py | 49 +-- core/error.py | 7 +- core/handler/adminhandler.py | 60 +--- core/handler/callbackqueryhandler.py | 63 +--- core/handler/limiterhandler.py | 72 +---- core/manager.py | 286 ------------------ core/override/telegram.py | 117 -------- core/plugin/__init__.py | 6 +- core/plugin/_funcs.py | 178 ----------- core/plugin/_handler.py | 380 ------------------------ core/plugin/_job.py | 173 ----------- core/plugin/_plugin.py | 314 -------------------- core/ratelimiter.py | 67 ----- core/services/cookies/cache.py | 96 +----- core/services/cookies/error.py | 13 +- core/services/cookies/models.py | 38 +-- core/services/cookies/repositories.py | 54 +--- core/services/cookies/services.py | 213 +++++-------- core/services/devices/models.py | 22 +- core/services/devices/repositories.py | 40 +-- core/services/devices/services.py | 26 +- core/services/players/error.py | 5 +- core/services/players/models.py | 95 +----- core/services/players/repositories.py | 109 +------ core/services/players/services.py | 42 +-- core/services/task/models.py | 43 +-- core/services/task/repositories.py | 49 +-- core/services/task/services.py | 178 +---------- core/services/template/cache.py | 57 +--- core/services/template/error.py | 20 +- core/services/template/models.py | 145 +-------- core/services/template/services.py | 206 +------------ core/services/users/cache.py | 23 +- core/services/users/models.py | 29 +- core/services/users/repositories.py | 43 +-- core/services/users/services.py | 82 +---- core/{override => sqlmodel}/__init__.py | 0 core/sqlmodel/session.py | 117 +------- poetry.lock | 17 +- pyproject.toml | 1 + requirements.txt | 1 + run.py | 2 +- 56 files changed, 148 insertions(+), 4807 deletions(-) delete mode 100644 core/builtins/__init__.py delete mode 100644 core/builtins/contexts.py delete mode 100644 core/builtins/dispatcher.py delete mode 100644 core/builtins/executor.py delete mode 100644 core/builtins/reloader.py delete mode 100644 core/dependence/aiobrowser.pyi delete mode 100644 core/dependence/mtproto.pyi delete mode 100644 core/manager.py delete mode 100644 core/override/telegram.py delete mode 100644 core/plugin/_funcs.py delete mode 100644 core/plugin/_handler.py delete mode 100644 core/plugin/_job.py delete mode 100644 core/plugin/_plugin.py delete mode 100644 core/ratelimiter.py rename core/{override => sqlmodel}/__init__.py (100%) diff --git a/core/application.py b/core/application.py index 2dc0fab..d4984cb 100644 --- a/core/application.py +++ b/core/application.py @@ -1,289 +1,5 @@ """BOT""" -import asyncio -import signal -from functools import wraps -from signal import SIGABRT, SIGINT, SIGTERM, signal as signal_func -from ssl import SSLZeroReturnError -from typing import Callable, List, Optional, TYPE_CHECKING, TypeVar -import pytz -import uvicorn -from fastapi import FastAPI -from telegram import Bot, Update -from telegram.error import NetworkError, TelegramError, TimedOut -from telegram.ext import ( - Application as TelegramApplication, - ApplicationBuilder as TelegramApplicationBuilder, - Defaults, - JobQueue, -) -from typing_extensions import ParamSpec -from uvicorn import Server - -from core.config import config as application_config -from core.handler.limiterhandler import LimiterHandler -from core.manager import Managers -from core.override.telegram import HTTPXRequest -from core.ratelimiter import RateLimiter -from utils.const import WRAPPER_ASSIGNMENTS -from utils.log import logger -from utils.models.signal import Singleton - -if TYPE_CHECKING: - from asyncio import Task - from types import FrameType +from gram_core.application import Application __all__ = ("Application",) - -R = TypeVar("R") -T = TypeVar("T") -P = ParamSpec("P") - - -class Application(Singleton): - """Application""" - - _web_server_task: Optional["Task"] = None - - _startup_funcs: List[Callable] = [] - _shutdown_funcs: List[Callable] = [] - - def __init__(self, managers: "Managers", telegram: "TelegramApplication", web_server: "Server") -> None: - self._running = False - self.managers = managers - self.telegram = telegram - self.web_server = web_server - self.managers.set_application(application=self) # 给 managers 设置 application - self.managers.build_executor("Application") - - @classmethod - def build(cls): - managers = Managers() - telegram = ( - TelegramApplicationBuilder() - .get_updates_read_timeout(application_config.update_read_timeout) - .get_updates_write_timeout(application_config.update_write_timeout) - .get_updates_connect_timeout(application_config.update_connect_timeout) - .get_updates_pool_timeout(application_config.update_pool_timeout) - .defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai"))) - .token(application_config.bot_token) - .request( - HTTPXRequest( - connection_pool_size=application_config.connection_pool_size, - proxy_url=application_config.proxy_url, - read_timeout=application_config.read_timeout, - write_timeout=application_config.write_timeout, - connect_timeout=application_config.connect_timeout, - pool_timeout=application_config.pool_timeout, - ) - ) - .rate_limiter(RateLimiter()) - .build() - ) - web_server = Server( - uvicorn.Config( - app=FastAPI(debug=application_config.debug), - port=application_config.webserver.port, - host=application_config.webserver.host, - log_config=None, - ) - ) - return cls(managers, telegram, web_server) - - @property - def running(self) -> bool: - """bot 是否正在运行""" - with self._lock: - return self._running - - @property - def web_app(self) -> FastAPI: - """fastapi app""" - return self.web_server.config.app - - @property - def bot(self) -> Optional[Bot]: - return self.telegram.bot - - @property - def job_queue(self) -> Optional[JobQueue]: - return self.telegram.job_queue - - async def _on_startup(self) -> None: - for func in self._startup_funcs: - await self.managers.executor(func, block=getattr(func, "block", False)) - - async def _on_shutdown(self) -> None: - for func in self._shutdown_funcs: - await self.managers.executor(func, block=getattr(func, "block", False)) - - async def initialize(self): - """BOT 初始化""" - self.telegram.add_handler(LimiterHandler(limit_time=10), group=-1) # 启用入口洪水限制 - await self.managers.start_dependency() # 启动基础服务 - await self.managers.init_components() # 实例化组件 - await self.managers.start_services() # 启动其他服务 - await self.managers.install_plugins() # 安装插件 - - async def shutdown(self): - """BOT 关闭""" - await self.managers.uninstall_plugins() # 卸载插件 - await self.managers.stop_services() # 终止其他服务 - await self.managers.stop_dependency() # 终止基础服务 - - async def start(self) -> None: - """启动 BOT""" - logger.info("正在启动 BOT 中...") - - def error_callback(exc: TelegramError) -> None: - """错误信息回调""" - self.telegram.create_task(self.telegram.process_error(error=exc, update=None)) - - await self.telegram.initialize() - logger.info("[blue]Telegram[/] 初始化成功", extra={"markup": True}) - - if application_config.webserver.enable: # 如果使用 web app - server_config = self.web_server.config - server_config.setup_event_loop() - if not server_config.loaded: - server_config.load() - self.web_server.lifespan = server_config.lifespan_class(server_config) - try: - await self.web_server.startup() - except OSError as e: - if e.errno == 10048: - logger.error("Web Server 端口被占用:%s", e) - logger.error("Web Server 启动失败,正在退出") - raise SystemExit from None - - if self.web_server.should_exit: - logger.error("Web Server 启动失败,正在退出") - raise SystemExit from None - logger.success("Web Server 启动成功") - - self._web_server_task = asyncio.create_task(self.web_server.main_loop()) - - for _ in range(5): # 连接至 telegram 服务器 - try: - await self.telegram.updater.start_polling( - error_callback=error_callback, allowed_updates=Update.ALL_TYPES - ) - break - except TimedOut: - logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True}) - continue - except NetworkError as e: - logger.exception() - if isinstance(e, SSLZeroReturnError): - logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.") - else: - logger.error("网络连接出现问题, 请检查您的网络状况.") - raise SystemExit from e - - await self.initialize() - logger.success("BOT 初始化成功") - logger.debug("BOT 开始启动") - - await self._on_startup() - await self.telegram.start() - self._running = True - logger.success("BOT 启动成功") - - def stop_signal_handler(self, signum: int): - """终止信号处理""" - signals = {k: v for v, k in signal.__dict__.items() if v.startswith("SIG") and not v.startswith("SIG_")} - logger.debug("接收到了终止信号 %s 正在退出...", signals[signum]) - if self._web_server_task: - self._web_server_task.cancel() - - async def idle(self) -> None: - """在接收到中止信号之前,堵塞loop""" - - task = None - - def stop_handler(signum: int, _: "FrameType") -> None: - self.stop_signal_handler(signum) - task.cancel() - - for s in (SIGINT, SIGTERM, SIGABRT): - signal_func(s, stop_handler) - - while True: - task = asyncio.create_task(asyncio.sleep(600)) - - try: - await task - except asyncio.CancelledError: - break - - async def stop(self) -> None: - """关闭""" - logger.info("BOT 正在关闭") - self._running = False - - await self._on_shutdown() - - if self.telegram.updater.running: - await self.telegram.updater.stop() - - await self.shutdown() - - if self.telegram.running: - await self.telegram.stop() - - await self.telegram.shutdown() - if self.web_server is not None: - try: - await self.web_server.shutdown() - logger.info("Web Server 已经关闭") - except AttributeError: - pass - - logger.success("BOT 关闭成功") - - def launch(self) -> None: - """启动""" - loop = asyncio.get_event_loop() - try: - loop.run_until_complete(self.start()) - loop.run_until_complete(self.idle()) - except (SystemExit, KeyboardInterrupt) as exc: - logger.debug("接收到了终止信号,BOT 即将关闭", exc_info=exc) # 接收到了终止信号 - except NetworkError as e: - if isinstance(e, SSLZeroReturnError): - logger.critical("代理服务出现异常, 请检查您的代理服务是否配置成功.") - else: - logger.critical("网络连接出现问题, 请检查您的网络状况.") - except Exception as e: - logger.critical("遇到了未知错误: %s", {type(e)}, exc_info=e) - finally: - loop.run_until_complete(self.stop()) - - if application_config.reload: - raise SystemExit from None - - def on_startup(self, func: Callable[P, R]) -> Callable[P, R]: - """注册一个在 BOT 启动时执行的函数""" - - if func not in self._startup_funcs: - self._startup_funcs.append(func) - - # noinspection PyTypeChecker - @wraps(func, assigned=WRAPPER_ASSIGNMENTS) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - return func(*args, **kwargs) - - return wrapper - - def on_shutdown(self, func: Callable[P, R]) -> Callable[P, R]: - """注册一个在 BOT 停止时执行的函数""" - - if func not in self._shutdown_funcs: - self._shutdown_funcs.append(func) - - # noinspection PyTypeChecker - @wraps(func, assigned=WRAPPER_ASSIGNMENTS) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - return func(*args, **kwargs) - - return wrapper diff --git a/core/base_service.py b/core/base_service.py index c61a6e8..17873cb 100644 --- a/core/base_service.py +++ b/core/base_service.py @@ -1,60 +1,3 @@ -from abc import ABC -from itertools import chain -from typing import ClassVar, Iterable, Type, TypeVar - -from typing_extensions import Self - -from utils.helpers import isabstract +from gram_core.base_service import BaseService, BaseServiceType, DependenceType, ComponentType, get_all_services __all__ = ("BaseService", "BaseServiceType", "DependenceType", "ComponentType", "get_all_services") - - -class _BaseService: - """服务基类""" - - _is_component: ClassVar[bool] = False - _is_dependence: ClassVar[bool] = False - - def __init_subclass__(cls, load: bool = True, **kwargs): - cls.is_dependence = cls._is_dependence - cls.is_component = cls._is_component - cls.load = load - - async def __aenter__(self) -> Self: - await self.initialize() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - await self.shutdown() - - async def initialize(self) -> None: - """Initialize resources used by this service""" - - async def shutdown(self) -> None: - """Stop & clear resources used by this service""" - - -class _Dependence(_BaseService, ABC): - _is_dependence: ClassVar[bool] = True - - -class _Component(_BaseService, ABC): - _is_component: ClassVar[bool] = True - - -class BaseService(_BaseService, ABC): - Dependence: Type[_BaseService] = _Dependence - Component: Type[_BaseService] = _Component - - -BaseServiceType = TypeVar("BaseServiceType", bound=_BaseService) -DependenceType = TypeVar("DependenceType", bound=_Dependence) -ComponentType = TypeVar("ComponentType", bound=_Component) - - -# noinspection PyProtectedMember -def get_all_services() -> Iterable[Type[_BaseService]]: - return filter( - lambda x: x.__name__[0] != "_" and x.load and not isabstract(x), - chain(BaseService.__subclasses__(), _Dependence.__subclasses__(), _Component.__subclasses__()), - ) diff --git a/core/basemodel.py b/core/basemodel.py index c65f58a..ac0f42d 100644 --- a/core/basemodel.py +++ b/core/basemodel.py @@ -1,29 +1,3 @@ -import enum - -try: - import ujson as jsonlib -except ImportError: - import json as jsonlib - -from pydantic import BaseSettings +from gram_core.basemodel import RegionEnum, Settings __all__ = ("RegionEnum", "Settings") - - -class RegionEnum(int, enum.Enum): - """账号数据所在服务器""" - - NULL = 0 - HYPERION = 1 # 米忽悠国服 hyperion - HOYOLAB = 2 # 米忽悠国际服 hoyolab - - -class Settings(BaseSettings): - def __new__(cls, *args, **kwargs): - cls.update_forward_refs() - return super(Settings, cls).__new__(cls) # pylint: disable=E1120 - - class Config(BaseSettings.Config): - case_sensitive = False - json_loads = jsonlib.loads - json_dumps = jsonlib.dumps diff --git a/core/builtins/__init__.py b/core/builtins/__init__.py deleted file mode 100644 index 4f29666..0000000 --- a/core/builtins/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""bot builtins""" diff --git a/core/builtins/contexts.py b/core/builtins/contexts.py deleted file mode 100644 index 832c978..0000000 --- a/core/builtins/contexts.py +++ /dev/null @@ -1,38 +0,0 @@ -"""上下文管理""" -from contextlib import contextmanager -from contextvars import ContextVar -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from telegram.ext import CallbackContext - from telegram import Update - -__all__ = [ - "CallbackContextCV", - "UpdateCV", - "handler_contexts", - "job_contexts", -] - -CallbackContextCV: ContextVar["CallbackContext"] = ContextVar("TelegramContextCallback") -UpdateCV: ContextVar["Update"] = ContextVar("TelegramUpdate") - - -@contextmanager -def handler_contexts(update: "Update", context: "CallbackContext") -> None: - context_token = CallbackContextCV.set(context) - update_token = UpdateCV.set(update) - try: - yield - finally: - CallbackContextCV.reset(context_token) - UpdateCV.reset(update_token) - - -@contextmanager -def job_contexts(context: "CallbackContext") -> None: - token = CallbackContextCV.set(context) - try: - yield - finally: - CallbackContextCV.reset(token) diff --git a/core/builtins/dispatcher.py b/core/builtins/dispatcher.py deleted file mode 100644 index a51cfcd..0000000 --- a/core/builtins/dispatcher.py +++ /dev/null @@ -1,309 +0,0 @@ -"""参数分发器""" -import asyncio -import inspect -from abc import ABC, abstractmethod -from asyncio import AbstractEventLoop -from functools import cached_property, lru_cache, partial, wraps -from inspect import Parameter, Signature -from itertools import chain -from types import GenericAlias, MethodType -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Sequence, - Type, - Union, -) - -from arkowrapper import ArkoWrapper -from fastapi import FastAPI -from telegram import Bot as TelegramBot, Chat, Message, Update, User -from telegram.ext import Application as TelegramApplication, CallbackContext, Job -from typing_extensions import ParamSpec -from uvicorn import Server - -from core.application import Application -from utils.const import WRAPPER_ASSIGNMENTS -from utils.typedefs import R, T - -__all__ = ( - "catch", - "AbstractDispatcher", - "BaseDispatcher", - "HandlerDispatcher", - "JobDispatcher", - "dispatched", -) - -P = ParamSpec("P") - -TargetType = Union[Type, str, Callable[[Any], bool]] - -_CATCH_TARGET_ATTR = "_catch_targets" - - -def catch(*targets: Union[str, Type]) -> Callable[[Callable[P, R]], Callable[P, R]]: - def decorate(func: Callable[P, R]) -> Callable[P, R]: - setattr(func, _CATCH_TARGET_ATTR, targets) - - @wraps(func, assigned=WRAPPER_ASSIGNMENTS) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - return func(*args, **kwargs) - - return wrapper - - return decorate - - -@lru_cache(64) -def get_signature(func: Union[type, Callable]) -> Signature: - if isinstance(func, type): - return inspect.signature(func.__init__) - return inspect.signature(func) - - -class AbstractDispatcher(ABC): - """参数分发器""" - - IGNORED_ATTRS = [] - - _args: List[Any] = [] - _kwargs: Dict[Union[str, Type], Any] = {} - _application: "Optional[Application]" = None - - def set_application(self, application: "Application") -> None: - self._application = application - - @property - def application(self) -> "Application": - if self._application is None: - raise RuntimeError(f"No application was set for this {self.__class__.__name__}.") - return self._application - - def __init__(self, *args, **kwargs) -> None: - self._args = list(args) - self._kwargs = dict(kwargs) - - for _, value in kwargs.items(): - type_arg = type(value) - if type_arg != str: - self._kwargs[type_arg] = value - - for arg in args: - type_arg = type(arg) - if type_arg != str: - self._kwargs[type_arg] = arg - - @cached_property - def catch_funcs(self) -> List[MethodType]: - # noinspection PyTypeChecker - return list( - ArkoWrapper(dir(self)) - .filter(lambda x: not x.startswith("_")) - .filter( - lambda x: x not in self.IGNORED_ATTRS + ["dispatch", "catch_funcs", "catch_func_map", "dispatch_funcs"] - ) - .map(lambda x: getattr(self, x)) - .filter(lambda x: isinstance(x, MethodType)) - .filter(lambda x: hasattr(x, "_catch_targets")) - ) - - @cached_property - def catch_func_map(self) -> Dict[Union[str, Type[T]], Callable[..., T]]: - result = {} - for catch_func in self.catch_funcs: - catch_targets = getattr(catch_func, _CATCH_TARGET_ATTR) - for catch_target in catch_targets: - result[catch_target] = catch_func - return result - - @cached_property - def dispatch_funcs(self) -> List[MethodType]: - return list( - ArkoWrapper(dir(self)) - .filter(lambda x: x.startswith("dispatch_by_")) - .map(lambda x: getattr(self, x)) - .filter(lambda x: isinstance(x, MethodType)) - ) - - @abstractmethod - def dispatch_by_default(self, parameter: Parameter) -> Parameter: - """默认的 dispatch 方法""" - - @abstractmethod - def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter: - """使用 catch_func 获取并分配参数""" - - def dispatch(self, func: Callable[P, R]) -> Callable[..., R]: - """将参数分配给函数,从而合成一个无需参数即可执行的函数""" - params = {} - signature = get_signature(func) - parameters: Dict[str, Parameter] = dict(signature.parameters) - - for name, parameter in list(parameters.items()): - parameter: Parameter - if any( - [ - name == "self" and isinstance(func, (type, MethodType)), - parameter.kind in [Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL], - ] - ): - del parameters[name] - continue - - for dispatch_func in self.dispatch_funcs: - parameters[name] = dispatch_func(parameter) - - for name, parameter in parameters.items(): - if parameter.default != Parameter.empty: - params[name] = parameter.default - else: - params[name] = None - - return partial(func, **params) - - @catch(Application) - def catch_application(self) -> Application: - return self.application - - -class BaseDispatcher(AbstractDispatcher): - """默认参数分发器""" - - _instances: Sequence[Any] - - def _get_kwargs(self) -> Dict[Type[T], T]: - result = self._get_default_kwargs() - result[AbstractDispatcher] = self - result.update(self._kwargs) - return result - - def _get_default_kwargs(self) -> Dict[Type[T], T]: - application = self.application - _default_kwargs = { - FastAPI: application.web_app, - Server: application.web_server, - TelegramApplication: application.telegram, - TelegramBot: application.telegram.bot, - } - if not application.running: - for obj in chain( - application.managers.dependency, - application.managers.components, - application.managers.services, - application.managers.plugins, - ): - _default_kwargs[type(obj)] = obj - return {k: v for k, v in _default_kwargs.items() if v is not None} - - def dispatch_by_default(self, parameter: Parameter) -> Parameter: - annotation = parameter.annotation - # noinspection PyTypeChecker - if isinstance(annotation, type) and (value := self._get_kwargs().get(annotation, None)) is not None: - parameter._default = value # pylint: disable=W0212 - return parameter - - def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter: - annotation = parameter.annotation - if annotation != Any and isinstance(annotation, GenericAlias): - return parameter - - catch_func = self.catch_func_map.get(annotation) or self.catch_func_map.get(parameter.name) - if catch_func is not None: - # noinspection PyUnresolvedReferences,PyProtectedMember - parameter._default = catch_func() # pylint: disable=W0212 - return parameter - - @catch(AbstractEventLoop) - def catch_loop(self) -> AbstractEventLoop: - return asyncio.get_event_loop() - - -class HandlerDispatcher(BaseDispatcher): - """Handler 参数分发器""" - - def __init__(self, update: Optional[Update] = None, context: Optional[CallbackContext] = None, **kwargs) -> None: - super().__init__(update=update, context=context, **kwargs) - self._update = update - self._context = context - - def dispatch( - self, func: Callable[P, R], *, update: Optional[Update] = None, context: Optional[CallbackContext] = None - ) -> Callable[..., R]: - self._update = update or self._update - self._context = context or self._context - if self._update is None: - from core.builtins.contexts import UpdateCV - - self._update = UpdateCV.get() - if self._context is None: - from core.builtins.contexts import CallbackContextCV - - self._context = CallbackContextCV.get() - return super().dispatch(func) - - def dispatch_by_default(self, parameter: Parameter) -> Parameter: - """HandlerDispatcher 默认不使用 dispatch_by_default""" - return parameter - - @catch(Update) - def catch_update(self) -> Update: - return self._update - - @catch(CallbackContext) - def catch_context(self) -> CallbackContext: - return self._context - - @catch(Message) - def catch_message(self) -> Message: - return self._update.effective_message - - @catch(User) - def catch_user(self) -> User: - return self._update.effective_user - - @catch(Chat) - def catch_chat(self) -> Chat: - return self._update.effective_chat - - -class JobDispatcher(BaseDispatcher): - """Job 参数分发器""" - - def __init__(self, context: Optional[CallbackContext] = None, **kwargs) -> None: - super().__init__(context=context, **kwargs) - self._context = context - - def dispatch(self, func: Callable[P, R], *, context: Optional[CallbackContext] = None) -> Callable[..., R]: - self._context = context or self._context - if self._context is None: - from core.builtins.contexts import CallbackContextCV - - self._context = CallbackContextCV.get() - return super().dispatch(func) - - @catch("data") - def catch_data(self) -> Any: - return self._context.job.data - - @catch(Job) - def catch_job(self) -> Job: - return self._context.job - - @catch(CallbackContext) - def catch_context(self) -> CallbackContext: - return self._context - - -def dispatched(dispatcher: Type[AbstractDispatcher] = BaseDispatcher): - def decorate(func: Callable[P, R]) -> Callable[P, R]: - @wraps(func, assigned=WRAPPER_ASSIGNMENTS) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - return dispatcher().dispatch(func)(*args, **kwargs) - - return wrapper - - return decorate diff --git a/core/builtins/executor.py b/core/builtins/executor.py deleted file mode 100644 index 7fd1a54..0000000 --- a/core/builtins/executor.py +++ /dev/null @@ -1,131 +0,0 @@ -"""执行器""" -import inspect -from functools import cached_property -from multiprocessing import RLock as Lock -from typing import Callable, ClassVar, Dict, Generic, Optional, TYPE_CHECKING, Type, TypeVar - -from telegram import Update -from telegram.ext import CallbackContext -from typing_extensions import ParamSpec, Self - -from core.builtins.contexts import handler_contexts, job_contexts - -if TYPE_CHECKING: - from core.application import Application - from core.builtins.dispatcher import AbstractDispatcher, HandlerDispatcher - from multiprocessing.synchronize import RLock as LockType - -__all__ = ("BaseExecutor", "Executor", "HandlerExecutor", "JobExecutor") - -T = TypeVar("T") -R = TypeVar("R") -P = ParamSpec("P") - - -class BaseExecutor: - """执行器 - Args: - name(str): 该执行器的名称。执行器的名称是唯一的。 - - 只支持执行只拥有 POSITIONAL_OR_KEYWORD 和 KEYWORD_ONLY 两种参数类型的函数 - """ - - _lock: ClassVar["LockType"] = Lock() - _instances: ClassVar[Dict[str, Self]] = {} - _application: "Optional[Application]" = None - - def set_application(self, application: "Application") -> None: - self._application = application - - @property - def application(self) -> "Application": - if self._application is None: - raise RuntimeError(f"No application was set for this {self.__class__.__name__}.") - return self._application - - def __new__(cls: Type[T], name: str, *args, **kwargs) -> T: - with cls._lock: - if (instance := cls._instances.get(name)) is None: - instance = object.__new__(cls) - instance.__init__(name, *args, **kwargs) - cls._instances.update({name: instance}) - return instance - - @cached_property - def name(self) -> str: - """当前执行器的名称""" - return self._name - - def __init__(self, name: str, dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None: - self._name = name - self._dispatcher = dispatcher - - -class Executor(BaseExecutor, Generic[P, R]): - async def __call__( - self, - target: Callable[P, R], - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - **kwargs, - ) -> R: - dispatcher = self._dispatcher or dispatcher - dispatcher_instance = dispatcher(**kwargs) - dispatcher_instance.set_application(application=self.application) - dispatched_func = dispatcher_instance.dispatch(target) # 分发参数,组成新函数 - - # 执行 - if inspect.iscoroutinefunction(target): - result = await dispatched_func() - else: - result = dispatched_func() - - return result - - -class HandlerExecutor(BaseExecutor, Generic[P, R]): - """Handler专用执行器""" - - _callback: Callable[P, R] - _dispatcher: "HandlerDispatcher" - - def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["HandlerDispatcher"]] = None) -> None: - if dispatcher is None: - from core.builtins.dispatcher import HandlerDispatcher - - dispatcher = HandlerDispatcher - super().__init__("handler", dispatcher) - self._callback = func - self._dispatcher = dispatcher() - - def set_application(self, application: "Application") -> None: - self._application = application - if self._dispatcher is not None: - self._dispatcher.set_application(application) - - async def __call__(self, update: Update, context: CallbackContext) -> R: - with handler_contexts(update, context): - dispatched_func = self._dispatcher.dispatch(self._callback, update=update, context=context) - return await dispatched_func() - - -class JobExecutor(BaseExecutor): - """Job 专用执行器""" - - def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None: - if dispatcher is None: - from core.builtins.dispatcher import JobDispatcher - - dispatcher = JobDispatcher - super().__init__("job", dispatcher) - self._callback = func - self._dispatcher = dispatcher() - - def set_application(self, application: "Application") -> None: - self._application = application - if self._dispatcher is not None: - self._dispatcher.set_application(application) - - async def __call__(self, context: CallbackContext) -> R: - with job_contexts(context): - dispatched_func = self._dispatcher.dispatch(self._callback, context=context) - return await dispatched_func() diff --git a/core/builtins/reloader.py b/core/builtins/reloader.py deleted file mode 100644 index 6b09f07..0000000 --- a/core/builtins/reloader.py +++ /dev/null @@ -1,185 +0,0 @@ -import inspect -import multiprocessing -import os -import signal -import threading -from pathlib import Path -from typing import Callable, Iterator, List, Optional, TYPE_CHECKING - -from watchfiles import watch - -from utils.const import HANDLED_SIGNALS, PROJECT_ROOT -from utils.log import logger -from utils.typedefs import StrOrPath - -if TYPE_CHECKING: - from multiprocessing.process import BaseProcess - -__all__ = ("Reloader",) - -multiprocessing.allow_connection_pickling() -spawn = multiprocessing.get_context("spawn") - - -class FileFilter: - """监控文件过滤""" - - def __init__(self, includes: List[str], excludes: List[str]) -> None: - default_includes = ["*.py"] - self.includes = [default for default in default_includes if default not in excludes] - self.includes.extend(includes) - self.includes = list(set(self.includes)) - - default_excludes = [".*", ".py[cod]", ".sw.*", "~*", __file__] - self.excludes = [default for default in default_excludes if default not in includes] - self.exclude_dirs = [] - for e in excludes: - p = Path(e) - try: - is_dir = p.is_dir() - except OSError: - is_dir = False - - if is_dir: - self.exclude_dirs.append(p) - else: - self.excludes.append(e) - self.excludes = list(set(self.excludes)) - - def __call__(self, path: Path) -> bool: - for include_pattern in self.includes: - if path.match(include_pattern): - for exclude_dir in self.exclude_dirs: - if exclude_dir in path.parents: - return False - - for exclude_pattern in self.excludes: - if path.match(exclude_pattern): - return False - - return True - return False - - -class Reloader: - _target: Callable[..., None] - _process: "BaseProcess" - - @property - def process(self) -> "BaseProcess": - return self._process - - @property - def target(self) -> Callable[..., None]: - return self._target - - def __init__( - self, - target: Callable[..., None], - *, - reload_delay: float = 0.25, - reload_dirs: List[StrOrPath] = None, - reload_includes: List[str] = None, - reload_excludes: List[str] = None, - ): - if inspect.iscoroutinefunction(target): - raise ValueError("不支持异步函数") - self._target = target - - self.reload_delay = reload_delay - - _reload_dirs = [] - for reload_dir in reload_dirs or []: - _reload_dirs.append(PROJECT_ROOT.joinpath(Path(reload_dir))) - - self.reload_dirs = [] - for reload_dir in _reload_dirs: - append = True - for parent in reload_dir.parents: - if parent in _reload_dirs: - append = False - break - if append: - self.reload_dirs.append(reload_dir) - - if not self.reload_dirs: - logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"}) - - self._should_exit = threading.Event() - - frame = inspect.currentframe().f_back - - self.watch_filter = FileFilter(reload_includes or [], (reload_excludes or []) + [frame.f_globals["__file__"]]) - self.watcher = watch( - *self.reload_dirs, - watch_filter=None, - stop_event=self._should_exit, - yield_on_timeout=True, - ) - - def get_changes(self) -> Optional[List[Path]]: - if not self._process.is_alive(): - logger.info("目标进程已经关闭", extra={"tag": "Reloader"}) - self._should_exit.set() - try: - changes = next(self.watcher) - except StopIteration: - return None - if changes: - unique_paths = {Path(c[1]) for c in changes} - return [p for p in unique_paths if self.watch_filter(p)] - return None - - def __iter__(self) -> Iterator[Optional[List[Path]]]: - return self - - def __next__(self) -> Optional[List[Path]]: - return self.get_changes() - - def run(self) -> None: - self.startup() - for changes in self: - if changes: - logger.warning( - "检测到文件 %s 发生改变, 正在重载...", - [str(c.relative_to(PROJECT_ROOT)).replace(os.sep, "/") for c in changes], - extra={"tag": "Reloader"}, - ) - self.restart() - - self.shutdown() - - def signal_handler(self, *_) -> None: - """当接收到结束信号量时""" - self._process.join(3) - if self._process.is_alive(): - self._process.terminate() - self._process.join() - self._should_exit.set() - - def startup(self) -> None: - """启动进程""" - logger.info("目标进程正在启动", extra={"tag": "Reloader"}) - - for sig in HANDLED_SIGNALS: - signal.signal(sig, self.signal_handler) - - self._process = spawn.Process(target=self._target) - self._process.start() - logger.success("目标进程启动成功", extra={"tag": "Reloader"}) - - def restart(self) -> None: - """重启进程""" - self._process.terminate() - self._process.join(10) - - self._process = spawn.Process(target=self._target) - self._process.start() - logger.info("目标进程已经重载", extra={"tag": "Reloader"}) - - def shutdown(self) -> None: - """关闭进程""" - self._process.terminate() - self._process.join(10) - - logger.info("重载器已经关闭", extra={"tag": "Reloader"}) diff --git a/core/config.py b/core/config.py index ce5c859..74a6622 100644 --- a/core/config.py +++ b/core/config.py @@ -1,161 +1,3 @@ -from enum import Enum -from pathlib import Path -from typing import List, Optional, Union - -import dotenv -from pydantic import AnyUrl, Field - -from core.basemodel import Settings -from utils.const import PROJECT_ROOT -from utils.typedefs import NaturalNumber +from gram_core.config import ApplicationConfig, config, JoinGroups __all__ = ("ApplicationConfig", "config", "JoinGroups") - -dotenv.load_dotenv() - - -class JoinGroups(str, Enum): - NO_ALLOW = "NO_ALLOW" - ALLOW_AUTH_USER = "ALLOW_AUTH_USER" - ALLOW_USER = "ALLOW_USER" - ALLOW_ALL = "ALLOW_ALL" - - -class DatabaseConfig(Settings): - driver_name: str = "mysql+asyncmy" - host: Optional[str] = None - port: Optional[int] = None - username: Optional[str] = None - password: Optional[str] = None - database: Optional[str] = None - - class Config(Settings.Config): - env_prefix = "db_" - - -class RedisConfig(Settings): - host: str = "127.0.0.1" - port: int = 6379 - database: int = Field(default=0, env="redis_db") - password: Optional[str] = None - - class Config(Settings.Config): - env_prefix = "redis_" - - -class LoggerConfig(Settings): - name: str = "PaiGram" - width: Optional[int] = None - time_format: str = "[%Y-%m-%d %X]" - traceback_max_frames: int = 20 - path: Path = PROJECT_ROOT / "logs" - render_keywords: List[str] = ["BOT"] - locals_max_length: int = 10 - locals_max_string: int = 80 - locals_max_depth: Optional[NaturalNumber] = None - filtered_names: List[str] = ["uvicorn"] - - class Config(Settings.Config): - env_prefix = "logger_" - - -class MTProtoConfig(Settings): - api_id: Optional[int] = None - api_hash: Optional[str] = None - - -class WebServerConfig(Settings): - enable: bool = False - """是否启用WebServer""" - - url: AnyUrl = "http://localhost:8080" - host: str = "localhost" - port: int = 8080 - - class Config(Settings.Config): - env_prefix = "web_" - - -class ErrorConfig(Settings): - pb_url: str = "" - pb_sunset: int = 43200 - pb_max_lines: int = 1000 - sentry_dsn: str = "" - notification_chat_id: Optional[str] = None - - class Config(Settings.Config): - env_prefix = "error_" - - -class ReloadConfig(Settings): - delay: float = 0.25 - dirs: List[str] = [] - include: List[str] = [] - exclude: List[str] = [] - - class Config(Settings.Config): - env_prefix = "reload_" - - -class NoticeConfig(Settings): - user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!" - - class Config(Settings.Config): - env_prefix = "notice_" - - -class ApplicationConfig(Settings): - debug: bool = False - """debug 开关""" - retry: int = 5 - """重试次数""" - auto_reload: bool = False - """自动重载""" - - proxy_url: Optional[AnyUrl] = None - """代理链接""" - upload_bbs_host: Optional[AnyUrl] = "https://upload-bbs.miyoushe.com" - - bot_token: str = "" - """BOT的token""" - - owner: Optional[int] = None - - channels: List[int] = [] - """文章推送群组""" - - verify_groups: List[Union[int, str]] = [] - """启用群验证功能的群组""" - join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW - """是否允许机器人被邀请到其它群组""" - - timeout: int = 10 - connection_pool_size: int = 256 - read_timeout: Optional[float] = None - write_timeout: Optional[float] = None - connect_timeout: Optional[float] = None - pool_timeout: Optional[float] = None - update_read_timeout: Optional[float] = None - update_write_timeout: Optional[float] = None - update_connect_timeout: Optional[float] = None - update_pool_timeout: Optional[float] = None - - genshin_ttl: Optional[int] = None - - enka_network_api_agent: str = "" - pass_challenge_api: str = "" - pass_challenge_app_key: str = "" - pass_challenge_user_web: str = "" - - reload: ReloadConfig = ReloadConfig() - database: DatabaseConfig = DatabaseConfig() - logger: LoggerConfig = LoggerConfig() - webserver: WebServerConfig = WebServerConfig() - redis: RedisConfig = RedisConfig() - mtproto: MTProtoConfig = MTProtoConfig() - error: ErrorConfig = ErrorConfig() - notice: NoticeConfig = NoticeConfig() - - -ApplicationConfig.update_forward_refs() -config = ApplicationConfig() diff --git a/core/dependence/aiobrowser.py b/core/dependence/aiobrowser.py index 50c4037..7559873 100644 --- a/core/dependence/aiobrowser.py +++ b/core/dependence/aiobrowser.py @@ -1,56 +1,3 @@ -from typing import Optional, TYPE_CHECKING - -from playwright.async_api import Error, async_playwright - -from core.base_service import BaseService -from utils.log import logger - -if TYPE_CHECKING: - from playwright.async_api import Playwright as AsyncPlaywright, Browser +from gram_core.dependence.aiobrowser import AioBrowser __all__ = ("AioBrowser",) - - -class AioBrowser(BaseService.Dependence): - @property - def browser(self): - return self._browser - - def __init__(self, loop=None): - self._browser: Optional["Browser"] = None - self._playwright: Optional["AsyncPlaywright"] = None - self._loop = loop - - async def get_browser(self): - if self._browser is None: - await self.initialize() - return self._browser - - async def initialize(self): - if self._playwright is None: - logger.info("正在尝试启动 [blue]Playwright[/]", extra={"markup": True}) - self._playwright = await async_playwright().start() - logger.success("[blue]Playwright[/] 启动成功", extra={"markup": True}) - if self._browser is None: - logger.info("正在尝试启动 [blue]Browser[/]", extra={"markup": True}) - try: - self._browser = await self._playwright.chromium.launch(timeout=5000) - logger.success("[blue]Browser[/] 启动成功", extra={"markup": True}) - except Error as err: - if "playwright install" in str(err): - logger.error( - "检查到 [blue]playwright[/] 刚刚安装或者未升级\n" - "请运行以下命令下载新浏览器\n" - "[blue bold]playwright install chromium[/]", - extra={"markup": True}, - ) - raise RuntimeError("检查到 playwright 刚刚安装或者未升级\n请运行以下命令下载新浏览器\nplaywright install chromium") - raise err - - return self._browser - - async def shutdown(self): - if self._browser is not None: - await self._browser.close() - if self._playwright is not None: - self._playwright.stop() diff --git a/core/dependence/aiobrowser.pyi b/core/dependence/aiobrowser.pyi deleted file mode 100644 index b823a61..0000000 --- a/core/dependence/aiobrowser.pyi +++ /dev/null @@ -1,16 +0,0 @@ -from asyncio import AbstractEventLoop - -from playwright.async_api import Browser, Playwright as AsyncPlaywright - -from core.base_service import BaseService - -__all__ = ("AioBrowser",) - -class AioBrowser(BaseService.Dependence): - _browser: Browser | None - _playwright: AsyncPlaywright | None - _loop: AbstractEventLoop - - @property - def browser(self) -> Browser | None: ... - async def get_browser(self) -> Browser: ... diff --git a/core/dependence/database.py b/core/dependence/database.py index 67dc156..a6693bf 100644 --- a/core/dependence/database.py +++ b/core/dependence/database.py @@ -1,51 +1,3 @@ -import contextlib -from typing import Optional - -from sqlalchemy.engine import URL -from sqlalchemy.ext.asyncio import create_async_engine -from sqlalchemy.orm import sessionmaker -from typing_extensions import Self - -from core.base_service import BaseService -from core.config import ApplicationConfig -from core.sqlmodel.session import AsyncSession +from gram_core.dependence.database import Database __all__ = ("Database",) - - -class Database(BaseService.Dependence): - @classmethod - def from_config(cls, config: ApplicationConfig) -> Self: - return cls(**config.database.dict()) - - def __init__( - self, - driver_name: str, - host: Optional[str] = None, - port: Optional[int] = None, - username: Optional[str] = None, - password: Optional[str] = None, - database: Optional[str] = None, - ): - self.database = database # skipcq: PTC-W0052 - self.password = password - self.username = username - self.port = port - self.host = host - self.url = URL.create( - driver_name, - username=self.username, - password=self.password, - host=self.host, - port=self.port, - database=self.database, - ) - self.engine = create_async_engine(self.url) - self.Session = sessionmaker(bind=self.engine, class_=AsyncSession) - - @contextlib.asynccontextmanager - async def session(self) -> AsyncSession: - yield self.Session() - - async def shutdown(self): - self.Session.close_all() diff --git a/core/dependence/mtproto.py b/core/dependence/mtproto.py index b6f9ddc..1386d57 100644 --- a/core/dependence/mtproto.py +++ b/core/dependence/mtproto.py @@ -1,67 +1,3 @@ -import os -from typing import Optional -from urllib.parse import urlparse +from gram_core.dependence.mtproto import MTProto -import aiofiles - -from core.base_service import BaseService -from core.config import config as bot_config -from utils.log import logger - -try: - from pyrogram import Client - from pyrogram.session import session - - session.log.debug = lambda *args, **kwargs: None # 关闭日记 - PYROGRAM_AVAILABLE = True -except ImportError: - Client = None - session = None - PYROGRAM_AVAILABLE = False - - -class MTProto(BaseService.Dependence): - async def get_session(self): - async with aiofiles.open(self.session_path, mode="r") as f: - return await f.read() - - async def set_session(self, b: str): - async with aiofiles.open(self.session_path, mode="w+") as f: - await f.write(b) - - def session_exists(self): - return os.path.exists(self.session_path) - - def __init__(self): - self.name = "paigram" - current_dir = os.getcwd() - self.session_path = os.path.join(current_dir, "paigram.session") - self.client: Optional[Client] = None - self.proxy: Optional[dict] = None - http_proxy = os.environ.get("HTTP_PROXY") - if http_proxy is not None: - http_proxy_url = urlparse(http_proxy) - self.proxy = {"scheme": "http", "hostname": http_proxy_url.hostname, "port": http_proxy_url.port} - - async def initialize(self): # pylint: disable=W0221 - if not PYROGRAM_AVAILABLE: - logger.info("MTProto 服务需要的 pyrogram 模块未导入 本次服务 client 为 None") - return - if bot_config.mtproto.api_id is None: - logger.info("MTProto 服务需要的 api_id 未配置 本次服务 client 为 None") - return - if bot_config.mtproto.api_hash is None: - logger.info("MTProto 服务需要的 api_hash 未配置 本次服务 client 为 None") - return - self.client = Client( - api_id=bot_config.mtproto.api_id, - api_hash=bot_config.mtproto.api_hash, - name=self.name, - bot_token=bot_config.bot_token, - proxy=self.proxy, - ) - await self.client.start() - - async def shutdown(self): # pylint: disable=W0221 - if self.client is not None: - await self.client.stop(block=False) +__all__ = ("MTProto",) diff --git a/core/dependence/mtproto.pyi b/core/dependence/mtproto.pyi deleted file mode 100644 index a5f69a1..0000000 --- a/core/dependence/mtproto.pyi +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations -from typing import TypedDict - -from core.base_service import BaseService - -try: - from pyrogram import Client - from pyrogram.session import session - - PYROGRAM_AVAILABLE = True -except ImportError: - Client = None - session = None - PYROGRAM_AVAILABLE = False - -__all__ = ("MTProto",) - -class _ProxyType(TypedDict): - scheme: str - hostname: str | None - port: int | None - -class MTProto(BaseService.Dependence): - name: str - session_path: str - client: Client | None - proxy: _ProxyType | None - - async def get_session(self) -> str: ... - async def set_session(self, b: str) -> None: ... - def session_exists(self) -> bool: ... diff --git a/core/dependence/redisdb.py b/core/dependence/redisdb.py index c94f067..f494b7e 100644 --- a/core/dependence/redisdb.py +++ b/core/dependence/redisdb.py @@ -1,50 +1,3 @@ -from typing import Optional, Union - -import fakeredis.aioredis -from redis import asyncio as aioredis -from redis.exceptions import ConnectionError as RedisConnectionError, TimeoutError as RedisTimeoutError -from typing_extensions import Self - -from core.base_service import BaseService -from core.config import ApplicationConfig -from utils.log import logger +from gram_core.dependence.redisdb import RedisDB __all__ = ["RedisDB"] - - -class RedisDB(BaseService.Dependence): - @classmethod - def from_config(cls, config: ApplicationConfig) -> Self: - return cls(**config.redis.dict()) - - def __init__( - self, host: str = "127.0.0.1", port: int = 6379, database: Union[str, int] = 0, password: Optional[str] = None - ): - self.client = aioredis.Redis(host=host, port=port, db=database, password=password) - self.ttl = 600 - - async def ping(self): - # noinspection PyUnresolvedReferences - if await self.client.ping(): - logger.info("连接 [red]Redis[/] 成功", extra={"markup": True}) - else: - logger.info("连接 [red]Redis[/] 失败", extra={"markup": True}) - raise RuntimeError("连接 Redis 失败") - - async def start_fake_redis(self): - self.client = fakeredis.aioredis.FakeRedis() - await self.ping() - - async def initialize(self): - logger.info("正在尝试建立与 [red]Redis[/] 连接", extra={"markup": True}) - try: - await self.ping() - except (RedisTimeoutError, RedisConnectionError) as exc: - if isinstance(exc, RedisTimeoutError): - logger.warning("连接 [red]Redis[/] 超时,使用 [red]fakeredis[/] 模拟", extra={"markup": True}) - if isinstance(exc, RedisConnectionError): - logger.warning("连接 [red]Redis[/] 失败,使用 [red]fakeredis[/] 模拟", extra={"markup": True}) - await self.start_fake_redis() - - async def shutdown(self): - await self.client.close() diff --git a/core/error.py b/core/error.py index 344b684..ce92952 100644 --- a/core/error.py +++ b/core/error.py @@ -1,7 +1,4 @@ """此模块包含核心模块的错误的基类""" -from typing import Union +from gram_core.error import ServiceNotFoundError - -class ServiceNotFoundError(Exception): - def __init__(self, name: Union[str, type]): - super().__init__(f"No service named '{name if isinstance(name, str) else name.__name__}'") +__all__ = ("ServiceNotFoundError",) diff --git a/core/handler/adminhandler.py b/core/handler/adminhandler.py index c972a84..12cac02 100644 --- a/core/handler/adminhandler.py +++ b/core/handler/adminhandler.py @@ -1,59 +1,3 @@ -import asyncio -from typing import TypeVar, TYPE_CHECKING, Any, Optional +from gram_core.handler.adminhandler import AdminHandler -from telegram import Update -from telegram.ext import ApplicationHandlerStop, BaseHandler - -from core.error import ServiceNotFoundError -from core.services.users.services import UserAdminService -from utils.log import logger - -if TYPE_CHECKING: - from core.application import Application - from telegram.ext import Application as TelegramApplication - -RT = TypeVar("RT") -UT = TypeVar("UT") - -CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]") - - -class AdminHandler(BaseHandler[Update, CCT]): - _lock = asyncio.Lock() - - def __init__(self, handler: BaseHandler[Update, CCT], application: "Application") -> None: - self.handler = handler - self.application = application - self.user_service: Optional["UserAdminService"] = None - super().__init__(self.handler.callback, self.handler.block) - - def check_update(self, update: object) -> bool: - if not isinstance(update, Update): - return False - return self.handler.check_update(update) - - async def _user_service(self) -> "UserAdminService": - async with self._lock: - if self.user_service is not None: - return self.user_service - user_service: UserAdminService = self.application.managers.services_map.get(UserAdminService, None) - if user_service is None: - raise ServiceNotFoundError("UserAdminService") - self.user_service = user_service - return self.user_service - - async def handle_update( - self, - update: "UT", - application: "TelegramApplication[Any, CCT, Any, Any, Any, Any]", - check_result: Any, - context: "CCT", - ) -> RT: - user_service = await self._user_service() - user = update.effective_user - if await user_service.is_admin(user.id): - return await self.handler.handle_update(update, application, check_result, context) - message = update.effective_message - logger.warning("用户 %s[%s] 触发尝试调用Admin命令但权限不足", user.full_name, user.id) - await message.reply_text("权限不足") - raise ApplicationHandlerStop +__all__ = ("AdminHandler",) diff --git a/core/handler/callbackqueryhandler.py b/core/handler/callbackqueryhandler.py index f931e4e..72fe086 100644 --- a/core/handler/callbackqueryhandler.py +++ b/core/handler/callbackqueryhandler.py @@ -1,62 +1,3 @@ -import asyncio -from contextlib import AbstractAsyncContextManager -from types import TracebackType -from typing import TypeVar, TYPE_CHECKING, Any, Optional, Type +from gram_core.handler.callbackqueryhandler import CallbackQueryHandler, OverlappingException, OverlappingContext -from telegram.ext import CallbackQueryHandler as BaseCallbackQueryHandler, ApplicationHandlerStop - -from utils.log import logger - -if TYPE_CHECKING: - from telegram.ext import Application - -RT = TypeVar("RT") -UT = TypeVar("UT") -CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]") - - -class OverlappingException(Exception): - pass - - -class OverlappingContext(AbstractAsyncContextManager): - _lock = asyncio.Lock() - - def __init__(self, context: "CCT"): - self.context = context - - async def __aenter__(self) -> None: - async with self._lock: - flag = self.context.user_data.get("overlapping", False) - if flag: - raise OverlappingException - self.context.user_data["overlapping"] = True - return None - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], - ) -> None: - async with self._lock: - del self.context.user_data["overlapping"] - return None - - -class CallbackQueryHandler(BaseCallbackQueryHandler): - async def handle_update( - self, - update: "UT", - application: "Application[Any, CCT, Any, Any, Any, Any]", - check_result: Any, - context: "CCT", - ) -> RT: - self.collect_additional_context(context, update, application, check_result) - try: - async with OverlappingContext(context): - return await self.callback(update, context) - except OverlappingException as exc: - user = update.effective_user - logger.warning("用户 %s[%s] 触发 overlapping 该次命令已忽略", user.full_name, user.id) - raise ApplicationHandlerStop from exc +__all__ = ("CallbackQueryHandler", "OverlappingException", "OverlappingContext") diff --git a/core/handler/limiterhandler.py b/core/handler/limiterhandler.py index 53bc4c0..cc4bfc5 100644 --- a/core/handler/limiterhandler.py +++ b/core/handler/limiterhandler.py @@ -1,71 +1,3 @@ -import asyncio -from typing import TypeVar, Optional +from gram_core.handler.limiterhandler import LimiterHandler -from telegram import Update -from telegram.ext import ContextTypes, ApplicationHandlerStop, TypeHandler - -from utils.log import logger - -UT = TypeVar("UT") -CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]") - - -class LimiterHandler(TypeHandler[UT, CCT]): - _lock = asyncio.Lock() - - def __init__( - self, max_rate: float = 5, time_period: float = 10, amount: float = 1, limit_time: Optional[float] = None - ): - """Limiter Handler 通过 - `Leaky bucket algorithm `_ - 实现对用户的输入的精确控制 - - 输入超过一定速率后,代码会抛出 - :class:`telegram.ext.ApplicationHandlerStop` - 异常并在一段时间内防止用户执行任何其他操作 - - :param max_rate: 在抛出异常之前最多允许 频率/秒 的速度 - :param time_period: 在限制速率的时间段的持续时间 - :param amount: 提供的容量 - :param limit_time: 限制时间 如果不提供限制时间为 max_rate / time_period * amount - """ - self.max_rate = max_rate - self.amount = amount - self._rate_per_sec = max_rate / time_period - self.limit_time = limit_time - super().__init__(Update, self.limiter_callback) - - async def limiter_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE): - if update.inline_query is not None: - return - loop = asyncio.get_running_loop() - async with self._lock: - time = loop.time() - user_data = context.user_data - if user_data is None: - return - user_limit_time = user_data.get("limit_time") - if user_limit_time is not None: - if time >= user_limit_time: - del user_data["limit_time"] - else: - raise ApplicationHandlerStop - last_task_time = user_data.get("last_task_time", 0) - if last_task_time: - task_level = user_data.get("task_level", 0) - elapsed = time - last_task_time - decrement = elapsed * self._rate_per_sec - task_level = max(task_level - decrement, 0) - user_data["task_level"] = task_level - if not task_level + self.amount <= self.max_rate: - if self.limit_time: - limit_time = self.limit_time - else: - limit_time = 1 / self._rate_per_sec * self.amount - user_data["limit_time"] = time + limit_time - user = update.effective_user - logger.warning("用户 %s[%s] 触发洪水限制 已被限制 %s 秒", user.full_name, user.id, limit_time) - raise ApplicationHandlerStop - user_data["last_task_time"] = time - task_level = user_data.get("task_level", 0) - user_data["task_level"] = task_level + self.amount +__all__ = ("LimiterHandler",) diff --git a/core/manager.py b/core/manager.py deleted file mode 100644 index aad1512..0000000 --- a/core/manager.py +++ /dev/null @@ -1,286 +0,0 @@ -import asyncio -from importlib import import_module -from pathlib import Path -from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar - -from arkowrapper import ArkoWrapper -from async_timeout import timeout -from typing_extensions import ParamSpec - -from core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services -from core.config import config as bot_config -from utils.const import PLUGIN_DIR, PROJECT_ROOT -from utils.helpers import gen_pkg -from utils.log import logger - -if TYPE_CHECKING: - from core.application import Application - from core.plugin import PluginType - from core.builtins.executor import Executor - -__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers") - -R = TypeVar("R") -T = TypeVar("T") -P = ParamSpec("P") - - -def _load_module(path: Path) -> None: - for pkg in gen_pkg(path): - try: - logger.debug('正在导入 "%s"', pkg) - import_module(pkg) - except Exception as e: - logger.exception( - '在导入 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True} - ) - raise SystemExit from e - - -class Manager(Generic[T]): - """生命周期控制基类""" - - _executor: Optional["Executor"] = None - _lib: Dict[Type[T], T] = {} - _application: "Optional[Application]" = None - - def set_application(self, application: "Application") -> None: - self._application = application - - @property - def application(self) -> "Application": - if self._application is None: - raise RuntimeError(f"No application was set for this {self.__class__.__name__}.") - return self._application - - @property - def executor(self) -> "Executor": - """执行器""" - if self._executor is None: - raise RuntimeError(f"No executor was set for this {self.__class__.__name__}.") - return self._executor - - def build_executor(self, name: str): - from core.builtins.executor import Executor - from core.builtins.dispatcher import BaseDispatcher - - self._executor = Executor(name, dispatcher=BaseDispatcher) - self._executor.set_application(self.application) - - -class DependenceManager(Manager[DependenceType]): - """基础依赖管理""" - - _dependency: Dict[Type[DependenceType], DependenceType] = {} - - @property - def dependency(self) -> List[DependenceType]: - return list(self._dependency.values()) - - @property - def dependency_map(self) -> Dict[Type[DependenceType], DependenceType]: - return self._dependency - - async def start_dependency(self) -> None: - _load_module(PROJECT_ROOT / "core/dependence") - - for dependence in filter(lambda x: x.is_dependence, get_all_services()): - dependence: Type[DependenceType] - instance: DependenceType - try: - if hasattr(dependence, "from_config"): # 如果有 from_config 方法 - instance = dependence.from_config(bot_config) # 用 from_config 实例化服务 - else: - instance = await self.executor(dependence) - - await instance.initialize() - logger.success('基础服务 "%s" 启动成功', dependence.__name__) - - self._lib[dependence] = instance - self._dependency[dependence] = instance - - except Exception as e: - logger.exception('基础服务 "%s" 初始化失败,BOT 将自动关闭', dependence.__name__) - raise SystemExit from e - - async def stop_dependency(self) -> None: - async def task(d): - try: - async with timeout(5): - await d.shutdown() - logger.debug('基础服务 "%s" 关闭成功', d.__class__.__name__) - except asyncio.TimeoutError: - logger.warning('基础服务 "%s" 关闭超时', d.__class__.__name__) - except Exception as e: - logger.error('基础服务 "%s" 关闭错误', d.__class__.__name__, exc_info=e) - - tasks = [] - for dependence in self._dependency.values(): - tasks.append(asyncio.create_task(task(dependence))) - - await asyncio.gather(*tasks) - - -class ComponentManager(Manager[ComponentType]): - """组件管理""" - - _components: Dict[Type[ComponentType], ComponentType] = {} - - @property - def components(self) -> List[ComponentType]: - return list(self._components.values()) - - @property - def components_map(self) -> Dict[Type[ComponentType], ComponentType]: - return self._components - - async def init_components(self): - for path in filter( - lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir() - ): - _load_module(path) - components = ArkoWrapper(get_all_services()).filter(lambda x: x.is_component) - retry_times = 0 - max_retry_times = len(components) - while components: - start_len = len(components) - for component in list(components): - component: Type[ComponentType] - instance: ComponentType - try: - instance = await self.executor(component) - self._lib[component] = instance - self._components[component] = instance - components = components.remove(component) - except Exception as e: # pylint: disable=W0703 - logger.debug('组件 "%s" 初始化失败: [red]%s[/]', component.__name__, e, extra={"markup": True}) - end_len = len(list(components)) - if start_len == end_len: - retry_times += 1 - - if retry_times == max_retry_times and components: - for component in components: - logger.error('组件 "%s" 初始化失败', component.__name__) - raise SystemExit - - -class ServiceManager(Manager[BaseServiceType]): - """服务控制类""" - - _services: Dict[Type[BaseServiceType], BaseServiceType] = {} - - @property - def services(self) -> List[BaseServiceType]: - return list(self._services.values()) - - @property - def services_map(self) -> Dict[Type[BaseServiceType], BaseServiceType]: - return self._services - - async def _initialize_service(self, target: Type[BaseServiceType]) -> BaseServiceType: - instance: BaseServiceType - try: - if hasattr(target, "from_config"): # 如果有 from_config 方法 - instance = target.from_config(bot_config) # 用 from_config 实例化服务 - else: - instance = await self.executor(target) - - await instance.initialize() - logger.success('服务 "%s" 启动成功', target.__name__) - - return instance - - except Exception as e: # pylint: disable=W0703 - logger.exception('服务 "%s" 初始化失败,BOT 将自动关闭', target.__name__) - raise SystemExit from e - - async def start_services(self) -> None: - for path in filter( - lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir() - ): - _load_module(path) - - for service in filter(lambda x: not x.is_component and not x.is_dependence, get_all_services()): # 遍历所有服务类 - instance = await self._initialize_service(service) - - self._lib[service] = instance - self._services[service] = instance - - async def stop_services(self) -> None: - """关闭服务""" - if not self._services: - return - - async def task(s): - try: - async with timeout(5): - await s.shutdown() - logger.success('服务 "%s" 关闭成功', s.__class__.__name__) - except asyncio.TimeoutError: - logger.warning('服务 "%s" 关闭超时', s.__class__.__name__) - except Exception as e: - logger.warning('服务 "%s" 关闭失败', s.__class__.__name__, exc_info=e) - - logger.info("正在关闭服务") - tasks = [] - for service in self._services.values(): - tasks.append(asyncio.create_task(task(service))) - - await asyncio.gather(*tasks) - - -class PluginManager(Manager["PluginType"]): - """插件管理""" - - _plugins: Dict[Type["PluginType"], "PluginType"] = {} - - @property - def plugins(self) -> List["PluginType"]: - """所有已经加载的插件""" - return list(self._plugins.values()) - - @property - def plugins_map(self) -> Dict[Type["PluginType"], "PluginType"]: - return self._plugins - - async def install_plugins(self) -> None: - """安装所有插件""" - from core.plugin import get_all_plugins - - for path in filter(lambda x: x.is_dir(), PLUGIN_DIR.iterdir()): - _load_module(path) - - for plugin in get_all_plugins(): - plugin: Type["PluginType"] - - try: - instance: "PluginType" = await self.executor(plugin) - except Exception as e: # pylint: disable=W0703 - logger.error('插件 "%s" 初始化失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e) - continue - - self._plugins[plugin] = instance - - if self._application is not None: - instance.set_application(self._application) - - await asyncio.create_task(self.plugin_install_task(plugin, instance)) - - @staticmethod - async def plugin_install_task(plugin: Type["PluginType"], instance: "PluginType"): - try: - await instance.install() - logger.success('插件 "%s" 安装成功', f"{plugin.__module__}.{plugin.__name__}") - except Exception as e: # pylint: disable=W0703 - logger.error('插件 "%s" 安装失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e) - - async def uninstall_plugins(self) -> None: - for plugin in self._plugins.values(): - try: - await plugin.uninstall() - except Exception as e: # pylint: disable=W0703 - logger.error('插件 "%s" 卸载失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e) - - -class Managers(DependenceManager, ComponentManager, ServiceManager, PluginManager): - """BOT 除自身外的生命周期管理类""" diff --git a/core/override/telegram.py b/core/override/telegram.py deleted file mode 100644 index d3fac5a..0000000 --- a/core/override/telegram.py +++ /dev/null @@ -1,117 +0,0 @@ -"""重写 telegram.request.HTTPXRequest 使其使用 ujson 库进行 json 序列化""" -from typing import Any, AsyncIterable, Optional - -import httpcore -from httpx import ( - AsyncByteStream, - AsyncHTTPTransport as DefaultAsyncHTTPTransport, - Limits, - Response as DefaultResponse, - Timeout, -) -from telegram.request import HTTPXRequest as DefaultHTTPXRequest - -try: - import ujson as jsonlib -except ImportError: - import json as jsonlib - -__all__ = ("HTTPXRequest",) - - -class Response(DefaultResponse): - def json(self, **kwargs: Any) -> Any: - # noinspection PyProtectedMember - from httpx._utils import guess_json_utf - - if self.charset_encoding is None and self.content and len(self.content) > 3: - encoding = guess_json_utf(self.content) - if encoding is not None: - return jsonlib.loads(self.content.decode(encoding), **kwargs) - return jsonlib.loads(self.text, **kwargs) - - -# noinspection PyProtectedMember -class AsyncHTTPTransport(DefaultAsyncHTTPTransport): - async def handle_async_request(self, request) -> Response: - from httpx._transports.default import ( - map_httpcore_exceptions, - AsyncResponseStream, - ) - - if not isinstance(request.stream, AsyncByteStream): - raise AssertionError - - req = httpcore.Request( - method=request.method, - url=httpcore.URL( - scheme=request.url.raw_scheme, - host=request.url.raw_host, - port=request.url.port, - target=request.url.raw_path, - ), - headers=request.headers.raw, - content=request.stream, - extensions=request.extensions, - ) - with map_httpcore_exceptions(): - resp = await self._pool.handle_async_request(req) - - if not isinstance(resp.stream, AsyncIterable): - raise AssertionError - - return Response( - status_code=resp.status, - headers=resp.headers, - stream=AsyncResponseStream(resp.stream), - extensions=resp.extensions, - ) - - -class HTTPXRequest(DefaultHTTPXRequest): - def __init__( # pylint: disable=W0231 - self, - connection_pool_size: int = 1, - proxy_url: str = None, - read_timeout: Optional[float] = 5.0, - write_timeout: Optional[float] = 5.0, - connect_timeout: Optional[float] = 5.0, - pool_timeout: Optional[float] = 1.0, - http_version: str = "1.1", - ): - self._http_version = http_version - timeout = Timeout( - connect=connect_timeout, - read=read_timeout, - write=write_timeout, - pool=pool_timeout, - ) - limits = Limits( - max_connections=connection_pool_size, - max_keepalive_connections=connection_pool_size, - ) - if http_version not in ("1.1", "2"): - raise ValueError("`http_version` must be either '1.1' or '2'.") - http1 = http_version == "1.1" - self._client_kwargs = dict( - timeout=timeout, - proxies=proxy_url, - limits=limits, - transport=AsyncHTTPTransport(limits=limits), - http1=http1, - http2=not http1, - ) - - try: - self._client = self._build_client() - except ImportError as exc: - if "httpx[http2]" not in str(exc) and "httpx[socks]" not in str(exc): - raise exc - - if "httpx[socks]" in str(exc): - raise RuntimeError( - "To use Socks5 proxies, PTB must be installed via `pip install " "python-telegram-bot[socks]`." - ) from exc - raise RuntimeError( - "To use HTTP/2, PTB must be installed via `pip install " "python-telegram-bot[http2]`." - ) from exc diff --git a/core/plugin/__init__.py b/core/plugin/__init__.py index fcd11c6..e0b3051 100644 --- a/core/plugin/__init__.py +++ b/core/plugin/__init__.py @@ -1,8 +1,8 @@ """插件""" -from core.plugin._handler import conversation, error_handler, handler -from core.plugin._job import TimeType, job -from core.plugin._plugin import Plugin, PluginType, get_all_plugins +from gram_core.plugin._handler import conversation, error_handler, handler +from gram_core.plugin._job import TimeType, job +from gram_core.plugin._plugin import Plugin, PluginType, get_all_plugins __all__ = ( "Plugin", diff --git a/core/plugin/_funcs.py b/core/plugin/_funcs.py deleted file mode 100644 index 1f222ed..0000000 --- a/core/plugin/_funcs.py +++ /dev/null @@ -1,178 +0,0 @@ -from pathlib import Path -from typing import List, Optional, Union, TYPE_CHECKING - -import aiofiles -import httpx -from httpx import UnsupportedProtocol -from telegram import Chat, Message, ReplyKeyboardRemove, Update -from telegram.error import Forbidden, NetworkError -from telegram.ext import CallbackContext, ConversationHandler, Job - -from core.dependence.redisdb import RedisDB -from core.plugin._handler import conversation, handler -from utils.const import CACHE_DIR, REQUEST_HEADERS -from utils.error import UrlResourcesNotFoundError -from utils.helpers import sha1 -from utils.log import logger - -if TYPE_CHECKING: - from core.application import Application - -try: - import ujson as json -except ImportError: - import json - -__all__ = ( - "PluginFuncs", - "ConversationFuncs", -) - - -class PluginFuncs: - _application: "Optional[Application]" = None - - def set_application(self, application: "Application") -> None: - self._application = application - - @property - def application(self) -> "Application": - if self._application is None: - raise RuntimeError("No application was set for this PluginManager.") - return self._application - - async def _delete_message(self, context: CallbackContext) -> None: - job = context.job - message_id = job.data - chat_info = f"chat_id[{job.chat_id}]" - - try: - chat = await self.get_chat(job.chat_id) - full_name = chat.full_name - if full_name: - chat_info = f"{full_name}[{chat.id}]" - else: - chat_info = f"{chat.title}[{chat.id}]" - except (NetworkError, Forbidden) as exc: - logger.warning("获取 chat info 失败 %s", exc.message) - except Exception as exc: - logger.warning("获取 chat info 消息失败 %s", str(exc)) - - logger.debug("删除消息 %s message_id[%s]", chat_info, message_id) - - try: - # noinspection PyTypeChecker - await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id) - except NetworkError as exc: - logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message) - except Forbidden as exc: - logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message) - except Exception as exc: - logger.error("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc_info=exc) - - async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, expire: int = 86400) -> Chat: - application = self.application - redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None) - - if not redis_db: - return await application.bot.get_chat(chat_id) - - qname = f"bot:chat:{chat_id}" - - data = await redis_db.client.get(qname) - if data: - json_data = json.loads(data) - return Chat.de_json(json_data, application.telegram.bot) - - chat_info = await application.telegram.bot.get_chat(chat_id) - await redis_db.client.set(qname, chat_info.to_json(), ex=expire) - return chat_info - - def add_delete_message_job( - self, - message: Optional[Union[int, Message]] = None, - *, - delay: int = 60, - name: Optional[str] = None, - chat: Optional[Union[int, Chat]] = None, - context: Optional[CallbackContext] = None, - ) -> Job: - """延迟删除消息""" - - if isinstance(message, Message): - if chat is None: - chat = message.chat_id - message = message.id - - chat = chat.id if isinstance(chat, Chat) else chat - - job_queue = self.application.job_queue or context.job_queue - - if job_queue is None or chat is None: - raise RuntimeError - - return job_queue.run_once( - callback=self._delete_message, - when=delay, - data=message, - name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message", - chat_id=chat, - job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"}, - ) - - @staticmethod - async def download_resource(url: str, return_path: bool = False) -> str: - url_sha1 = sha1(url) # url 的 hash 值 - pathed_url = Path(url) - - file_name = url_sha1 + pathed_url.suffix - file_path = CACHE_DIR.joinpath(file_name) - - if not file_path.exists(): # 若文件不存在,则下载 - async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client: - try: - response = await client.get(url) - except UnsupportedProtocol: - logger.error("链接不支持 url[%s]", url) - return "" - - if response.is_error: - logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code) - raise UrlResourcesNotFoundError(url) - - if response.status_code != 200: - logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code) - raise UrlResourcesNotFoundError(url) - - async with aiofiles.open(file_path, mode="wb") as f: - await f.write(response.content) - - logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path) - - return file_path if return_path else Path(file_path).as_uri() - - @staticmethod - def get_args(context: CallbackContext) -> List[str]: - args = context.args - match = context.match - - if args is None: - if match is not None and (command := match.groups()[0]): - temp = [] - command_parts = command.split(" ") - for command_part in command_parts: - if command_part: - temp.append(command_part) - return temp - return [] - if len(args) >= 1: - return args - return [] - - -class ConversationFuncs: - @conversation.fallback - @handler.command(command="cancel", block=False) - async def cancel(self, update: Update, _) -> int: - await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove()) - return ConversationHandler.END diff --git a/core/plugin/_handler.py b/core/plugin/_handler.py deleted file mode 100644 index 223b631..0000000 --- a/core/plugin/_handler.py +++ /dev/null @@ -1,380 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from functools import wraps -from importlib import import_module -from typing import ( - Any, - Callable, - ClassVar, - Dict, - List, - Optional, - Pattern, - TYPE_CHECKING, - Type, - TypeVar, - Union, -) - -from pydantic import BaseModel - -# noinspection PyProtectedMember -from telegram._utils.defaultvalue import DEFAULT_TRUE - -# noinspection PyProtectedMember -from telegram._utils.types import DVInput -from telegram.ext import BaseHandler -from telegram.ext.filters import BaseFilter -from typing_extensions import ParamSpec - -from core.handler.callbackqueryhandler import CallbackQueryHandler -from utils.const import WRAPPER_ASSIGNMENTS as _WRAPPER_ASSIGNMENTS - -if TYPE_CHECKING: - from core.builtins.dispatcher import AbstractDispatcher - -__all__ = ( - "handler", - "conversation", - "ConversationDataType", - "ConversationData", - "HandlerData", - "ErrorHandlerData", - "error_handler", -) - -P = ParamSpec("P") -T = TypeVar("T") -R = TypeVar("R") -UT = TypeVar("UT") - -HandlerType = TypeVar("HandlerType", bound=BaseHandler) -HandlerCls = Type[HandlerType] - -Module = import_module("telegram.ext") - -HANDLER_DATA_ATTR_NAME = "_handler_datas" -"""用于储存生成 handler 时所需要的参数(例如 block)的属性名""" - -ERROR_HANDLER_ATTR_NAME = "_error_handler_data" - -CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data" -"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名""" - -WRAPPER_ASSIGNMENTS = list( - set( - _WRAPPER_ASSIGNMENTS - + [ - HANDLER_DATA_ATTR_NAME, - ERROR_HANDLER_ATTR_NAME, - CONVERSATION_HANDLER_ATTR_NAME, - ] - ) -) - - -@dataclass(init=True) -class HandlerData: - type: Type[HandlerType] - admin: bool - kwargs: Dict[str, Any] - dispatcher: Optional[Type["AbstractDispatcher"]] = None - - -class _Handler: - _type: Type["HandlerType"] - - kwargs: Dict[str, Any] = {} - - def __init_subclass__(cls, **kwargs) -> None: - """用于获取 python-telegram-bot 中对应的 handler class""" - - handler_name = f"{cls.__name__.strip('_')}Handler" - - if handler_name == "CallbackQueryHandler": - cls._type = CallbackQueryHandler - return - - cls._type = getattr(Module, handler_name, None) - - def __init__(self, admin: bool = False, dispatcher: Optional[Type["AbstractDispatcher"]] = None, **kwargs) -> None: - self.dispatcher = dispatcher - self.admin = admin - self.kwargs = kwargs - - def __call__(self, func: Callable[P, R]) -> Callable[P, R]: - """decorator实现,从 func 生成 Handler""" - - handler_datas = getattr(func, HANDLER_DATA_ATTR_NAME, []) - handler_datas.append( - HandlerData(type=self._type, admin=self.admin, kwargs=self.kwargs, dispatcher=self.dispatcher) - ) - setattr(func, HANDLER_DATA_ATTR_NAME, handler_datas) - - return func - - -class _CallbackQuery(_Handler): - def __init__( - self, - pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None, - *, - block: DVInput[bool] = DEFAULT_TRUE, - admin: bool = False, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super(_CallbackQuery, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher) - - -class _ChatJoinRequest(_Handler): - def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): - super(_ChatJoinRequest, self).__init__(block=block, dispatcher=dispatcher) - - -class _ChatMember(_Handler): - def __init__( - self, - chat_member_types: int = -1, - *, - block: DVInput[bool] = DEFAULT_TRUE, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super().__init__(chat_member_types=chat_member_types, block=block, dispatcher=dispatcher) - - -class _ChosenInlineResult(_Handler): - def __init__( - self, - block: DVInput[bool] = DEFAULT_TRUE, - *, - pattern: Union[str, Pattern] = None, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super().__init__(block=block, pattern=pattern, dispatcher=dispatcher) - - -class _Command(_Handler): - def __init__( - self, - command: Union[str, List[str]], - filters: "BaseFilter" = None, - *, - block: DVInput[bool] = DEFAULT_TRUE, - admin: bool = False, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super(_Command, self).__init__( - command=command, filters=filters, block=block, admin=admin, dispatcher=dispatcher - ) - - -class _InlineQuery(_Handler): - def __init__( - self, - pattern: Union[str, Pattern] = None, - chat_types: List[str] = None, - *, - block: DVInput[bool] = DEFAULT_TRUE, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super(_InlineQuery, self).__init__(pattern=pattern, block=block, chat_types=chat_types, dispatcher=dispatcher) - - -class _Message(_Handler): - def __init__( - self, - filters: BaseFilter, - *, - block: DVInput[bool] = DEFAULT_TRUE, - admin: bool = False, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ) -> None: - super(_Message, self).__init__(filters=filters, block=block, admin=admin, dispatcher=dispatcher) - - -class _PollAnswer(_Handler): - def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): - super(_PollAnswer, self).__init__(block=block, dispatcher=dispatcher) - - -class _Poll(_Handler): - def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): - super(_Poll, self).__init__(block=block, dispatcher=dispatcher) - - -class _PreCheckoutQuery(_Handler): - def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): - super(_PreCheckoutQuery, self).__init__(block=block, dispatcher=dispatcher) - - -class _Prefix(_Handler): - def __init__( - self, - prefix: str, - command: str, - filters: BaseFilter = None, - *, - block: DVInput[bool] = DEFAULT_TRUE, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super(_Prefix, self).__init__( - prefix=prefix, command=command, filters=filters, block=block, dispatcher=dispatcher - ) - - -class _ShippingQuery(_Handler): - def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): - super(_ShippingQuery, self).__init__(block=block, dispatcher=dispatcher) - - -class _StringCommand(_Handler): - def __init__( - self, - command: str, - *, - admin: bool = False, - block: DVInput[bool] = DEFAULT_TRUE, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super(_StringCommand, self).__init__(command=command, block=block, admin=admin, dispatcher=dispatcher) - - -class _StringRegex(_Handler): - def __init__( - self, - pattern: Union[str, Pattern], - *, - block: DVInput[bool] = DEFAULT_TRUE, - admin: bool = False, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super(_StringRegex, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher) - - -class _Type(_Handler): - # noinspection PyShadowingBuiltins - def __init__( - self, - type: Type[UT], # pylint: disable=W0622 - strict: bool = False, - *, - block: DVInput[bool] = DEFAULT_TRUE, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): # pylint: disable=redefined-builtin - super(_Type, self).__init__(type=type, strict=strict, block=block, dispatcher=dispatcher) - - -# noinspection PyPep8Naming -class handler(_Handler): - callback_query = _CallbackQuery - chat_join_request = _ChatJoinRequest - chat_member = _ChatMember - chosen_inline_result = _ChosenInlineResult - command = _Command - inline_query = _InlineQuery - message = _Message - poll_answer = _PollAnswer - pool = _Poll - pre_checkout_query = _PreCheckoutQuery - prefix = _Prefix - shipping_query = _ShippingQuery - string_command = _StringCommand - string_regex = _StringRegex - type = _Type - - def __init__( - self, - handler_type: Union[Callable[P, "HandlerType"], Type["HandlerType"]], - *, - admin: bool = False, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - **kwargs: P.kwargs, - ) -> None: - self._type = handler_type - super().__init__(admin=admin, dispatcher=dispatcher, **kwargs) - - -class ConversationDataType(Enum): - """conversation handler 的类型""" - - Entry = "entry" - State = "state" - Fallback = "fallback" - - -class ConversationData(BaseModel): - """用于储存 conversation handler 的数据""" - - type: ConversationDataType - state: Optional[Any] = None - - -class _ConversationType: - _type: ClassVar[ConversationDataType] - - def __init_subclass__(cls, **kwargs) -> None: - cls._type = ConversationDataType(cls.__name__.lstrip("_").lower()) - - -def _entry(func: Callable[P, R]) -> Callable[P, R]: - setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Entry)) - - @wraps(func, assigned=WRAPPER_ASSIGNMENTS) - def wrapped(*args: P.args, **kwargs: P.kwargs) -> R: - return func(*args, **kwargs) - - return wrapped - - -class _State(_ConversationType): - def __init__(self, state: Any) -> None: - self.state = state - - def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]: - setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=self._type, state=self.state)) - return func - - -def _fallback(func: Callable[P, R]) -> Callable[P, R]: - setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Fallback)) - - @wraps(func, assigned=WRAPPER_ASSIGNMENTS) - def wrapped(*args: P.args, **kwargs: P.kwargs) -> R: - return func(*args, **kwargs) - - return wrapped - - -# noinspection PyPep8Naming -class conversation(_Handler): - entry_point = _entry - state = _State - fallback = _fallback - - -@dataclass(init=True) -class ErrorHandlerData: - block: bool - func: Optional[Callable] = None - - -# noinspection PyPep8Naming -class error_handler: - _func: Callable[P, R] - - def __init__( - self, - *, - block: bool = DEFAULT_TRUE, - ): - self._block = block - - def __call__(self, func: Callable[P, T]) -> Callable[P, T]: - self._func = func - wraps(func, assigned=WRAPPER_ASSIGNMENTS)(self) - - handler_datas = getattr(func, ERROR_HANDLER_ATTR_NAME, []) - handler_datas.append(ErrorHandlerData(block=self._block)) - setattr(self._func, ERROR_HANDLER_ATTR_NAME, handler_datas) - - return self._func diff --git a/core/plugin/_job.py b/core/plugin/_job.py deleted file mode 100644 index 393ad87..0000000 --- a/core/plugin/_job.py +++ /dev/null @@ -1,173 +0,0 @@ -"""插件""" -import datetime -import re -from dataclasses import dataclass, field -from typing import Any, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union - -# noinspection PyProtectedMember -from telegram._utils.types import JSONDict - -# noinspection PyProtectedMember -from telegram.ext._utils.types import JobCallback -from typing_extensions import ParamSpec - -if TYPE_CHECKING: - from core.builtins.dispatcher import AbstractDispatcher - -__all__ = ["TimeType", "job", "JobData"] - -P = ParamSpec("P") -T = TypeVar("T") -R = TypeVar("R") - -TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time] - -_JOB_ATTR_NAME = "_job_data" - - -@dataclass(init=True) -class JobData: - name: str - data: Any - chat_id: int - user_id: int - type: str - job_kwargs: JSONDict = field(default_factory=dict) - kwargs: JSONDict = field(default_factory=dict) - dispatcher: Optional[Type["AbstractDispatcher"]] = None - - -class _Job: - kwargs: Dict = {} - - def __init__( - self, - name: str = None, - data: object = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - *, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - **kwargs, - ): - self.name = name - self.data = data - self.chat_id = chat_id - self.user_id = user_id - self.job_kwargs = {} if job_kwargs is None else job_kwargs - self.kwargs = kwargs - if dispatcher is None: - from core.builtins.dispatcher import JobDispatcher - - dispatcher = JobDispatcher - - self.dispatcher = dispatcher - - def __call__(self, func: JobCallback) -> JobCallback: - data = JobData( - name=self.name, - data=self.data, - chat_id=self.chat_id, - user_id=self.user_id, - job_kwargs=self.job_kwargs, - kwargs=self.kwargs, - type=re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"), - dispatcher=self.dispatcher, - ) - if hasattr(func, _JOB_ATTR_NAME): - job_datas = getattr(func, _JOB_ATTR_NAME) - job_datas.append(data) - setattr(func, _JOB_ATTR_NAME, job_datas) - else: - setattr(func, _JOB_ATTR_NAME, [data]) - return func - - -class _RunOnce(_Job): - def __init__( - self, - when: TimeType, - data: object = None, - name: str = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - *, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when) - - -class _RunRepeating(_Job): - def __init__( - self, - interval: Union[float, datetime.timedelta], - first: TimeType = None, - last: TimeType = None, - data: object = None, - name: str = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - *, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super().__init__( - name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, interval=interval, first=first, last=last - ) - - -class _RunMonthly(_Job): - def __init__( - self, - when: datetime.time, - day: int, - data: object = None, - name: str = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - *, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when, day=day) - - -class _RunDaily(_Job): - def __init__( - self, - time: datetime.time, - days: Tuple[int, ...] = tuple(range(7)), - data: object = None, - name: str = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - *, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, time=time, days=days) - - -class _RunCustom(_Job): - def __init__( - self, - data: object = None, - name: str = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - *, - dispatcher: Optional[Type["AbstractDispatcher"]] = None, - ): - super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher) - - -# noinspection PyPep8Naming -class job: - run_once = _RunOnce - run_repeating = _RunRepeating - run_monthly = _RunMonthly - run_daily = _RunDaily - run_custom = _RunCustom diff --git a/core/plugin/_plugin.py b/core/plugin/_plugin.py deleted file mode 100644 index c5b655e..0000000 --- a/core/plugin/_plugin.py +++ /dev/null @@ -1,314 +0,0 @@ -"""插件""" -import asyncio -from abc import ABC -from dataclasses import asdict -from datetime import timedelta -from functools import partial, wraps -from itertools import chain -from multiprocessing import RLock as Lock -from types import MethodType -from typing import ( - Any, - ClassVar, - Dict, - Iterable, - List, - Optional, - TYPE_CHECKING, - Type, - TypeVar, - Union, -) - -from pydantic import BaseModel -from telegram.ext import BaseHandler, ConversationHandler, Job, TypeHandler -from typing_extensions import ParamSpec - -from core.handler.adminhandler import AdminHandler -from core.plugin._funcs import ConversationFuncs, PluginFuncs -from core.plugin._handler import ConversationDataType -from utils.const import WRAPPER_ASSIGNMENTS -from utils.helpers import isabstract -from utils.log import logger - -if TYPE_CHECKING: - from core.application import Application - from core.plugin._handler import ConversationData, HandlerData, ErrorHandlerData - from core.plugin._job import JobData - from multiprocessing.synchronize import RLock as LockType - -__all__ = ("Plugin", "PluginType", "get_all_plugins") - -wraps = partial(wraps, assigned=WRAPPER_ASSIGNMENTS) -P = ParamSpec("P") -T = TypeVar("T") -R = TypeVar("R") - -HandlerType = TypeVar("HandlerType", bound=BaseHandler) - -_HANDLER_DATA_ATTR_NAME = "_handler_datas" -"""用于储存生成 handler 时所需要的参数(例如 block)的属性名""" - -_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data" -"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名""" - -_ERROR_HANDLER_ATTR_NAME = "_error_handler_data" - -_JOB_ATTR_NAME = "_job_data" - -_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"] - - -class _Plugin(PluginFuncs): - """插件""" - - _lock: ClassVar["LockType"] = Lock() - _asyncio_lock: ClassVar["LockType"] = asyncio.Lock() - _installed: bool = False - - _handlers: Optional[List[HandlerType]] = None - _error_handlers: Optional[List["ErrorHandlerData"]] = None - _jobs: Optional[List[Job]] = None - _application: "Optional[Application]" = None - - def set_application(self, application: "Application") -> None: - self._application = application - - @property - def application(self) -> "Application": - if self._application is None: - raise RuntimeError("No application was set for this Plugin.") - return self._application - - @property - def handlers(self) -> List[HandlerType]: - """该插件的所有 handler""" - with self._lock: - if self._handlers is None: - self._handlers = [] - - for attr in dir(self): - if ( - not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) - and isinstance(func := getattr(self, attr), MethodType) - and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, [])) - ): - for data in datas: - data: "HandlerData" - if data.admin: - self._handlers.append( - AdminHandler( - handler=data.type( - callback=func, - **data.kwargs, - ), - application=self.application, - ) - ) - else: - self._handlers.append( - data.type( - callback=func, - **data.kwargs, - ) - ) - return self._handlers - - @property - def error_handlers(self) -> List["ErrorHandlerData"]: - with self._lock: - if self._error_handlers is None: - self._error_handlers = [] - for attr in dir(self): - if ( - not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) - and isinstance(func := getattr(self, attr), MethodType) - and (datas := getattr(func, _ERROR_HANDLER_ATTR_NAME, [])) - ): - for data in datas: - data: "ErrorHandlerData" - data.func = func - self._error_handlers.append(data) - - return self._error_handlers - - def _install_jobs(self) -> None: - if self._jobs is None: - self._jobs = [] - for attr in dir(self): - # noinspection PyUnboundLocalVariable - if ( - not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) - and isinstance(func := getattr(self, attr), MethodType) - and (datas := getattr(func, _JOB_ATTR_NAME, [])) - ): - for data in datas: - data: "JobData" - self._jobs.append( - getattr(self.application.telegram.job_queue, data.type)( - callback=func, - **data.kwargs, - **{ - key: value - for key, value in asdict(data).items() - if key not in ["type", "kwargs", "dispatcher"] - }, - ) - ) - - @property - def jobs(self) -> List[Job]: - with self._lock: - if self._jobs is None: - self._jobs = [] - self._install_jobs() - return self._jobs - - async def initialize(self) -> None: - """初始化插件""" - - async def shutdown(self) -> None: - """销毁插件""" - - async def install(self) -> None: - """安装""" - group = id(self) - if not self._installed: - await self.initialize() - # initialize 必须先执行 如果出现异常不会执行 add_handler 以免出现问题 - async with self._asyncio_lock: - self._install_jobs() - - for h in self.handlers: - if not isinstance(h, TypeHandler): - self.application.telegram.add_handler(h, group) - else: - self.application.telegram.add_handler(h, -1) - - for h in self.error_handlers: - self.application.telegram.add_error_handler(h.func, h.block) - self._installed = True - - async def uninstall(self) -> None: - """卸载""" - group = id(self) - - with self._lock: - if self._installed: - if group in self.application.telegram.handlers: - del self.application.telegram.handlers[id(self)] - - for h in self.handlers: - if isinstance(h, TypeHandler): - self.application.telegram.remove_handler(h, -1) - for h in self.error_handlers: - self.application.telegram.remove_error_handler(h.func) - - for j in self.application.telegram.job_queue.jobs(): - j.schedule_removal() - await self.shutdown() - self._installed = False - - async def reload(self) -> None: - await self.uninstall() - await self.install() - - -class _Conversation(_Plugin, ConversationFuncs, ABC): - """Conversation类""" - - # noinspection SpellCheckingInspection - class Config(BaseModel): - allow_reentry: bool = False - per_chat: bool = True - per_user: bool = True - per_message: bool = False - conversation_timeout: Optional[Union[float, timedelta]] = None - name: Optional[str] = None - map_to_parent: Optional[Dict[object, object]] = None - block: bool = False - - def __init_subclass__(cls, **kwargs): - cls._conversation_kwargs = kwargs - super(_Conversation, cls).__init_subclass__() - return cls - - @property - def handlers(self) -> List[HandlerType]: - with self._lock: - if self._handlers is None: - self._handlers = [] - - entry_points: List[HandlerType] = [] - states: Dict[Any, List[HandlerType]] = {} - fallbacks: List[HandlerType] = [] - for attr in dir(self): - if ( - not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) - and (func := getattr(self, attr, None)) is not None - and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, [])) - ): - conversation_data: "ConversationData" - - handlers: List[HandlerType] = [] - for data in datas: - if data.admin: - handlers.append( - AdminHandler( - handler=data.type( - callback=func, - **data.kwargs, - ), - application=self.application, - ) - ) - else: - handlers.append( - data.type( - callback=func, - **data.kwargs, - ) - ) - - if conversation_data := getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None): - if (_type := conversation_data.type) is ConversationDataType.Entry: - entry_points.extend(handlers) - elif _type is ConversationDataType.State: - if conversation_data.state in states: - states[conversation_data.state].extend(handlers) - else: - states[conversation_data.state] = handlers - elif _type is ConversationDataType.Fallback: - fallbacks.extend(handlers) - else: - self._handlers.extend(handlers) - else: - self._handlers.extend(handlers) - if entry_points and states and fallbacks: - kwargs = self._conversation_kwargs - kwargs.update(self.Config().dict()) - self._handlers.append(ConversationHandler(entry_points, states, fallbacks, **kwargs)) - else: - temp_dict = {"entry_points": entry_points, "states": states, "fallbacks": fallbacks} - reason = map(lambda x: f"'{x[0]}'", filter(lambda x: not x[1], temp_dict.items())) - logger.warning( - "'%s' 因缺少 '%s' 而生成无法生成 ConversationHandler", self.__class__.__name__, ", ".join(reason) - ) - return self._handlers - - -class Plugin(_Plugin, ABC): - """插件""" - - Conversation = _Conversation - - -PluginType = TypeVar("PluginType", bound=_Plugin) - - -def get_all_plugins() -> Iterable[Type[PluginType]]: - """获取所有 Plugin 的子类""" - return filter( - lambda x: x.__name__[0] != "_" and not isabstract(x), - chain(Plugin.__subclasses__(), _Conversation.__subclasses__()), - ) diff --git a/core/ratelimiter.py b/core/ratelimiter.py deleted file mode 100644 index 0e6d369..0000000 --- a/core/ratelimiter.py +++ /dev/null @@ -1,67 +0,0 @@ -import asyncio -import contextlib -from typing import Callable, Coroutine, Any, Union, List, Dict, Optional, TypeVar, Type - -from telegram.error import RetryAfter -from telegram.ext import BaseRateLimiter, ApplicationHandlerStop - -from utils.log import logger - -JSONDict: Type[dict[str, Any]] = Dict[str, Any] -RL_ARGS = TypeVar("RL_ARGS") - - -class RateLimiter(BaseRateLimiter[int]): - _lock = asyncio.Lock() - __slots__ = ( - "_limiter_info", - "_retry_after_event", - ) - - def __init__(self): - self._limiter_info: Dict[Union[str, int], float] = {} - self._retry_after_event = asyncio.Event() - self._retry_after_event.set() - - async def process_request( - self, - callback: Callable[..., Coroutine[Any, Any, Union[bool, JSONDict, List[JSONDict]]]], - args: Any, - kwargs: Dict[str, Any], - endpoint: str, - data: Dict[str, Any], - rate_limit_args: Optional[RL_ARGS], - ) -> Union[bool, JSONDict, List[JSONDict]]: - chat_id = data.get("chat_id") - - with contextlib.suppress(ValueError, TypeError): - chat_id = int(chat_id) - - loop = asyncio.get_running_loop() - time = loop.time() - - await self._retry_after_event.wait() - - async with self._lock: - chat_limit_time = self._limiter_info.get(chat_id) - if chat_limit_time: - if time >= chat_limit_time: - raise ApplicationHandlerStop - del self._limiter_info[chat_id] - - try: - return await callback(*args, **kwargs) - except RetryAfter as exc: - logger.warning("chat_id[%s] 触发洪水限制 当前被服务器限制 retry_after[%s]秒", chat_id, exc.retry_after) - self._limiter_info[chat_id] = time + (exc.retry_after * 2) - sleep = exc.retry_after + 0.1 - self._retry_after_event.clear() - await asyncio.sleep(sleep) - finally: - self._retry_after_event.set() - - async def initialize(self) -> None: - pass - - async def shutdown(self) -> None: - pass diff --git a/core/services/cookies/cache.py b/core/services/cookies/cache.py index a666aee..8f99a01 100644 --- a/core/services/cookies/cache.py +++ b/core/services/cookies/cache.py @@ -1,97 +1,3 @@ -from typing import List, Union - -from core.base_service import BaseService -from core.basemodel import RegionEnum -from core.dependence.redisdb import RedisDB -from core.services.cookies.error import CookiesCachePoolExhausted -from utils.error import RegionNotFoundError +from gram_core.services.cookies.cache import PublicCookiesCache __all__ = ("PublicCookiesCache",) - - -class PublicCookiesCache(BaseService.Component): - """使用优先级(score)进行排序,对使用次数最少的Cookies进行审核""" - - def __init__(self, redis: RedisDB): - self.client = redis.client - self.score_qname = "cookie:public" - self.user_times_qname = "cookie:public:times" - self.end = 20 - self.user_times_ttl = 60 * 60 * 24 - - def get_public_cookies_queue_name(self, region: RegionEnum): - if region == RegionEnum.HYPERION: - return f"{self.score_qname}:yuanshen" - if region == RegionEnum.HOYOLAB: - return f"{self.score_qname}:genshin" - raise RegionNotFoundError(region.name) - - async def putback_public_cookies(self, uid: int, region: RegionEnum): - """重新添加单个到缓存列表 - :param uid: - :param region: - :return: - """ - qname = self.get_public_cookies_queue_name(region) - score_maps = {f"{uid}": 0} - result = await self.client.zrem(qname, f"{uid}") - if result == 1: - await self.client.zadd(qname, score_maps) - return result - - async def add_public_cookies(self, uid: Union[List[int], int], region: RegionEnum): - """单个或批量添加到缓存列表 - :param uid: - :param region: - :return: 成功返回列表大小 - """ - qname = self.get_public_cookies_queue_name(region) - if isinstance(uid, int): - score_maps = {f"{uid}": 0} - elif isinstance(uid, list): - score_maps = {f"{i}": 0 for i in uid} - else: - raise TypeError("uid variable type error") - async with self.client.pipeline(transaction=True) as pipe: - # nx:只添加新元素。不要更新已经存在的元素 - await pipe.zadd(qname, score_maps, nx=True) - await pipe.zcard(qname) - add, count = await pipe.execute() - return int(add), count - - async def get_public_cookies(self, region: RegionEnum): - """从缓存列表获取 - :param region: - :return: - """ - qname = self.get_public_cookies_queue_name(region) - scores = await self.client.zrange(qname, 0, self.end, withscores=True, score_cast_func=int) - if len(scores) <= 0: - raise CookiesCachePoolExhausted - key = scores[0][0] - score = scores[0][1] - async with self.client.pipeline(transaction=True) as pipe: - await pipe.zincrby(qname, 1, key) - await pipe.execute() - return int(key), score + 1 - - async def delete_public_cookies(self, uid: int, region: RegionEnum): - qname = self.get_public_cookies_queue_name(region) - async with self.client.pipeline(transaction=True) as pipe: - await pipe.zrem(qname, uid) - return await pipe.execute() - - async def get_public_cookies_count(self, limit: bool = True): - async with self.client.pipeline(transaction=True) as pipe: - if limit: - await pipe.zcount(0, self.end) - else: - await pipe.zcard(self.score_qname) - return await pipe.execute() - - async def incr_by_user_times(self, user_id: Union[List[int], int], amount: int = 1): - qname = f"{self.user_times_qname}:{user_id}" - times = await self.client.incrby(qname, amount) - if times <= 1: - await self.client.expire(qname, self.user_times_ttl) - return times diff --git a/core/services/cookies/error.py b/core/services/cookies/error.py index 239110a..cf0d1d2 100644 --- a/core/services/cookies/error.py +++ b/core/services/cookies/error.py @@ -1,12 +1,3 @@ -class CookieServiceError(Exception): - pass +from gram_core.services.cookies.error import CookieServiceError, CookiesCachePoolExhausted, TooManyRequestPublicCookies - -class CookiesCachePoolExhausted(CookieServiceError): - def __init__(self): - super().__init__("Cookies cache pool is exhausted") - - -class TooManyRequestPublicCookies(CookieServiceError): - def __init__(self, user_id): - super().__init__(f"{user_id} too many request public cookies") +__all__ = ("CookieServiceError", "CookiesCachePoolExhausted", "TooManyRequestPublicCookies") diff --git a/core/services/cookies/models.py b/core/services/cookies/models.py index 0cafa34..79dc3ac 100644 --- a/core/services/cookies/models.py +++ b/core/services/cookies/models.py @@ -1,39 +1,3 @@ -import enum -from typing import Optional, Dict - -from sqlmodel import SQLModel, Field, Boolean, Column, Enum, JSON, Integer, BigInteger, Index - -from core.basemodel import RegionEnum +from gram_core.services.cookies.models import Cookies, CookiesDataBase, CookiesStatusEnum __all__ = ("Cookies", "CookiesDataBase", "CookiesStatusEnum") - - -class CookiesStatusEnum(int, enum.Enum): - STATUS_SUCCESS = 0 - INVALID_COOKIES = 1 - TOO_MANY_REQUESTS = 2 - - -class Cookies(SQLModel): - __table_args__ = ( - Index("index_user_account", "user_id", "account_id", unique=True), - dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"), - ) - id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True)) - user_id: int = Field( - sa_column=Column(BigInteger()), - ) - account_id: int = Field( - default=None, - sa_column=Column( - BigInteger(), - ), - ) - data: Optional[Dict[str, str]] = Field(sa_column=Column(JSON)) - status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum))) - region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum))) - is_share: Optional[bool] = Field(sa_column=Column(Boolean)) - - -class CookiesDataBase(Cookies, table=True): - __tablename__ = "cookies" diff --git a/core/services/cookies/repositories.py b/core/services/cookies/repositories.py index b0cc689..97e8993 100644 --- a/core/services/cookies/repositories.py +++ b/core/services/cookies/repositories.py @@ -1,55 +1,3 @@ -from typing import Optional, List - -from sqlmodel import select - -from core.base_service import BaseService -from core.basemodel import RegionEnum -from core.dependence.database import Database -from core.services.cookies.models import CookiesDataBase as Cookies -from core.sqlmodel.session import AsyncSession +from gram_core.services.cookies.repositories import CookiesRepository __all__ = ("CookiesRepository",) - - -class CookiesRepository(BaseService.Component): - def __init__(self, database: Database): - self.engine = database.engine - - async def get( - self, - user_id: int, - account_id: Optional[int] = None, - region: Optional[RegionEnum] = None, - ) -> Optional[Cookies]: - async with AsyncSession(self.engine) as session: - statement = select(Cookies).where(Cookies.user_id == user_id) - if account_id is not None: - statement = statement.where(Cookies.account_id == account_id) - if region is not None: - statement = statement.where(Cookies.region == region) - results = await session.exec(statement) - return results.first() - - async def add(self, cookies: Cookies) -> None: - async with AsyncSession(self.engine) as session: - session.add(cookies) - await session.commit() - - async def update(self, cookies: Cookies) -> Cookies: - async with AsyncSession(self.engine) as session: - session.add(cookies) - await session.commit() - await session.refresh(cookies) - return cookies - - async def delete(self, cookies: Cookies) -> None: - async with AsyncSession(self.engine) as session: - await session.delete(cookies) - await session.commit() - - async def get_all_by_region(self, region: RegionEnum) -> List[Cookies]: - async with AsyncSession(self.engine) as session: - statement = select(Cookies).where(Cookies.region == region) - results = await session.exec(statement) - cookies = results.all() - return cookies diff --git a/core/services/cookies/services.py b/core/services/cookies/services.py index 143e5b9..2aaad2e 100644 --- a/core/services/cookies/services.py +++ b/core/services/cookies/services.py @@ -1,159 +1,80 @@ -from typing import List, Optional +from gram_core.base_service import BaseService +from gram_core.basemodel import RegionEnum +from gram_core.services.cookies.error import CookieServiceError +from gram_core.services.cookies.models import CookiesStatusEnum, CookiesDataBase as Cookies +from gram_core.services.cookies.services import ( + CookiesService, + PublicCookiesService as BasePublicCookiesService, + NeedContinue, +) from simnet import StarRailClient, Region, Game -from simnet.errors import InvalidCookies, BadRequest as SimnetBadRequest, TooManyRequests +from simnet.errors import InvalidCookies, TooManyRequests, BadRequest as SimnetBadRequest -from core.base_service import BaseService -from core.basemodel import RegionEnum -from core.services.cookies.cache import PublicCookiesCache -from core.services.cookies.error import CookieServiceError, TooManyRequestPublicCookies -from core.services.cookies.models import CookiesDataBase as Cookies, CookiesStatusEnum -from core.services.cookies.repositories import CookiesRepository from utils.log import logger __all__ = ("CookiesService", "PublicCookiesService") -class CookiesService(BaseService): - def __init__(self, cookies_repository: CookiesRepository) -> None: - self._repository: CookiesRepository = cookies_repository - - async def update(self, cookies: Cookies): - await self._repository.update(cookies) - - async def add(self, cookies: Cookies): - await self._repository.add(cookies) - - async def get( - self, - user_id: int, - account_id: Optional[int] = None, - region: Optional[RegionEnum] = None, - ) -> Optional[Cookies]: - return await self._repository.get(user_id, account_id, region) - - async def delete(self, cookies: Cookies) -> None: - return await self._repository.delete(cookies) - - -class PublicCookiesService(BaseService): - def __init__(self, cookies_repository: CookiesRepository, public_cookies_cache: PublicCookiesCache): - self._cache = public_cookies_cache - self._repository: CookiesRepository = cookies_repository - self.count: int = 0 - self.user_times_limiter = 3 * 3 - - async def initialize(self) -> None: - logger.info("正在初始化公共Cookies池") - await self.refresh() - logger.success("刷新公共Cookies池成功") - - async def refresh(self): - """刷新公共Cookies 定时任务 - :return: - """ - user_list: List[int] = [] - cookies_list = await self._repository.get_all_by_region(RegionEnum.HYPERION) # 从数据库获取2 - for cookies in cookies_list: - if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS: - user_list.append(cookies.user_id) - if len(user_list) > 0: - add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HYPERION) - logger.info("国服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count) - user_list.clear() - cookies_list = await self._repository.get_all_by_region(RegionEnum.HOYOLAB) - for cookies in cookies_list: - if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS: - user_list.append(cookies.user_id) - if len(user_list) > 0: - add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HOYOLAB) - logger.info("国际服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count) - - async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL): - """获取公共Cookies - :param user_id: 用户ID - :param region: 注册的服务器 - :return: - """ - user_times = await self._cache.incr_by_user_times(user_id) - if int(user_times) > self.user_times_limiter: - logger.warning("用户 %s 使用公共Cookies次数已经到达上限", user_id) - raise TooManyRequestPublicCookies(user_id) - while True: - public_id, count = await self._cache.get_public_cookies(region) - cookies = await self._repository.get(public_id, region=region) - if cookies is None: - await self._cache.delete_public_cookies(public_id, region) - continue - if region == RegionEnum.HYPERION: - client = StarRailClient(cookies=cookies.data, region=Region.CHINESE) - elif region == RegionEnum.HOYOLAB: - client = StarRailClient(cookies=cookies.data, region=Region.OVERSEAS, lang="zh-cn") +class PublicCookiesService(BaseService, BasePublicCookiesService): + async def check_public_cookie(self, region: RegionEnum, cookies: Cookies, public_id: int): + if region == RegionEnum.HYPERION: + client = StarRailClient(cookies=cookies.data, region=Region.CHINESE) + elif region == RegionEnum.HOYOLAB: + client = StarRailClient(cookies=cookies.data, region=Region.OVERSEAS, lang="zh-cn") + else: + raise CookieServiceError + try: + if client.account_id is None: + raise RuntimeError("account_id not found") + record_cards = await client.get_record_cards() + for record_card in record_cards: + if record_card.game == Game.STARRAIL: + await client.get_starrail_user(record_card.uid) + break else: - raise CookieServiceError - try: - if client.account_id is None: - raise RuntimeError("account_id not found") - record_cards = await client.get_record_cards() - for record_card in record_cards: - if record_card.game == Game.STARRAIL: - await client.get_starrail_user(record_card.uid) + accounts = await client.get_game_accounts() + for account in accounts: + if account.game == Game.STARRAIL: + await client.get_starrail_user(account.uid) break - else: - accounts = await client.get_game_accounts() - for account in accounts: - if account.game == Game.STARRAIL: - await client.get_starrail_user(account.uid) - break - except InvalidCookies as exc: - if exc.ret_code in (10001, -100): - logger.warning("用户 [%s] Cookies无效", public_id) - elif exc.ret_code == 10103: - logger.warning("用户 [%s] Cookies有效,但没有绑定到游戏帐户", public_id) - else: - logger.warning("Cookies无效 ") - logger.exception(exc) + except InvalidCookies as exc: + if exc.ret_code in (10001, -100): + logger.warning("用户 [%s] Cookies无效", public_id) + elif exc.ret_code == 10103: + logger.warning("用户 [%s] Cookies有效,但没有绑定到游戏帐户", public_id) + else: + logger.warning("Cookies无效 ") + logger.exception(exc) + cookies.status = CookiesStatusEnum.INVALID_COOKIES + await self._repository.update(cookies) + await self._cache.delete_public_cookies(cookies.user_id, region) + raise NeedContinue + except TooManyRequests: + logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id) + cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS + await self._repository.update(cookies) + await self._cache.delete_public_cookies(cookies.user_id, region) + raise NeedContinue + except SimnetBadRequest as exc: + if "invalid content type" in exc.message: + raise exc + if exc.ret_code == 1034: + logger.warning("用户 [%s] 触发验证", public_id) + else: + logger.warning("用户 [%s] 获取账号信息发生错误,错误信息为", public_id) + logger.exception(exc) + await self._cache.delete_public_cookies(cookies.user_id, region) + raise NeedContinue + except RuntimeError as exc: + if "account_id not found" in str(exc): cookies.status = CookiesStatusEnum.INVALID_COOKIES await self._repository.update(cookies) await self._cache.delete_public_cookies(cookies.user_id, region) - continue - except TooManyRequests: - logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id) - cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS - await self._repository.update(cookies) - await self._cache.delete_public_cookies(cookies.user_id, region) - continue - except SimnetBadRequest as exc: - if "invalid content type" in exc.message: - raise exc - if exc.ret_code == 1034: - logger.warning("用户 [%s] 触发验证", public_id) - else: - logger.warning("用户 [%s] 获取账号信息发生错误,错误信息为", public_id) - logger.exception(exc) - await self._cache.delete_public_cookies(cookies.user_id, region) - continue - except RuntimeError as exc: - if "account_id not found" in str(exc): - cookies.status = CookiesStatusEnum.INVALID_COOKIES - await self._repository.update(cookies) - await self._cache.delete_public_cookies(cookies.user_id, region) - continue - raise exc - except Exception as exc: - await self._cache.delete_public_cookies(cookies.user_id, region) - raise exc - finally: - await client.shutdown() - logger.info("用户 user_id[%s] 请求用户 user_id[%s] 的公共Cookies 该Cookies使用次数为%s次 ", user_id, public_id, count) - return cookies - - async def undo(self, user_id: int, cookies: Optional[Cookies] = None, status: Optional[CookiesStatusEnum] = None): - await self._cache.incr_by_user_times(user_id, -1) - if cookies is not None and status is not None: - cookies.status = status - await self._repository.update(cookies) - await self._cache.delete_public_cookies(cookies.user_id, cookies.region) - logger.info("用户 user_id[%s] 反馈用户 user_id[%s] 的Cookies状态为 %s", user_id, cookies.user_id, status.name) - else: - logger.info("用户 user_id[%s] 撤销一次公共Cookies计数", user_id) + raise NeedContinue + raise exc + except Exception as exc: + await self._cache.delete_public_cookies(cookies.user_id, region) + raise exc + finally: + await client.shutdown() diff --git a/core/services/devices/models.py b/core/services/devices/models.py index a7a9676..9ecd615 100644 --- a/core/services/devices/models.py +++ b/core/services/devices/models.py @@ -1,23 +1,3 @@ -from typing import Optional - -from sqlmodel import SQLModel, Field, Column, Integer, BigInteger +from gram_core.services.devices.models import Devices, DevicesDataBase __all__ = ("Devices", "DevicesDataBase") - - -class Devices(SQLModel): - __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") - id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True)) - account_id: int = Field( - default=None, - sa_column=Column( - BigInteger(), - ), - ) - device_id: str = Field() - device_fp: str = Field() - device_name: Optional[str] = Field(default=None) - - -class DevicesDataBase(Devices, table=True): - __tablename__ = "devices" diff --git a/core/services/devices/repositories.py b/core/services/devices/repositories.py index 23153b7..cc45b13 100644 --- a/core/services/devices/repositories.py +++ b/core/services/devices/repositories.py @@ -1,41 +1,3 @@ -from typing import Optional - -from sqlmodel import select - -from core.base_service import BaseService -from core.dependence.database import Database -from core.services.devices.models import DevicesDataBase as Devices -from core.sqlmodel.session import AsyncSession +from gram_core.services.devices.repositories import DevicesRepository __all__ = ("DevicesRepository",) - - -class DevicesRepository(BaseService.Component): - def __init__(self, database: Database): - self.engine = database.engine - - async def get( - self, - account_id: int, - ) -> Optional[Devices]: - async with AsyncSession(self.engine) as session: - statement = select(Devices).where(Devices.account_id == account_id) - results = await session.exec(statement) - return results.first() - - async def add(self, devices: Devices) -> None: - async with AsyncSession(self.engine) as session: - session.add(devices) - await session.commit() - - async def update(self, devices: Devices) -> Devices: - async with AsyncSession(self.engine) as session: - session.add(devices) - await session.commit() - await session.refresh(devices) - return devices - - async def delete(self, devices: Devices) -> None: - async with AsyncSession(self.engine) as session: - await session.delete(devices) - await session.commit() diff --git a/core/services/devices/services.py b/core/services/devices/services.py index 76bd7fc..f11d888 100644 --- a/core/services/devices/services.py +++ b/core/services/devices/services.py @@ -1,25 +1,3 @@ -from typing import Optional +from gram_core.services.devices.services import DevicesService -from core.base_service import BaseService -from core.services.devices.repositories import DevicesRepository -from core.services.devices.models import DevicesDataBase as Devices - - -class DevicesService(BaseService): - def __init__(self, devices_repository: DevicesRepository) -> None: - self._repository: DevicesRepository = devices_repository - - async def update(self, devices: Devices): - await self._repository.update(devices) - - async def add(self, devices: Devices): - await self._repository.add(devices) - - async def get( - self, - account_id: int, - ) -> Optional[Devices]: - return await self._repository.get(account_id) - - async def delete(self, devices: Devices) -> None: - return await self._repository.delete(devices) +__all__ = ("DevicesService",) diff --git a/core/services/players/error.py b/core/services/players/error.py index 623bed8..ec94f63 100644 --- a/core/services/players/error.py +++ b/core/services/players/error.py @@ -1,2 +1,3 @@ -class PlayerNotFoundError(Exception): - pass +from gram_core.services.players.error import PlayerNotFoundError + +__all__ = ("PlayerNotFoundError",) diff --git a/core/services/players/models.py b/core/services/players/models.py index 32aedd2..6bd9ebf 100644 --- a/core/services/players/models.py +++ b/core/services/players/models.py @@ -1,96 +1,3 @@ -from datetime import datetime -from typing import Optional - -from pydantic import BaseModel, BaseSettings -from sqlalchemy import TypeDecorator -from sqlmodel import Boolean, Column, Enum, Field, SQLModel, Integer, Index, BigInteger, VARCHAR, func, DateTime - -from core.basemodel import RegionEnum - -try: - import ujson as jsonlib -except ImportError: - import json as jsonlib +from gram_core.services.players.models import Player, PlayersDataBase, PlayerInfo, PlayerInfoSQLModel __all__ = ("Player", "PlayersDataBase", "PlayerInfo", "PlayerInfoSQLModel") - - -class Player(SQLModel): - __table_args__ = ( - Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True), - dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"), - ) - id: Optional[int] = Field( - default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) - ) - user_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) - account_id: int = Field(default=None, primary_key=True, sa_column=Column(BigInteger())) - player_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) - region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum))) - is_chosen: Optional[bool] = Field(sa_column=Column(Boolean)) - - -class PlayersDataBase(Player, table=True): - __tablename__ = "players" - - -class ExtraPlayerInfo(BaseModel): - class Config(BaseSettings.Config): - json_loads = jsonlib.loads - json_dumps = jsonlib.dumps - - waifu_id: Optional[int] = None - - -class ExtraPlayerType(TypeDecorator): # pylint: disable=W0223 - impl = VARCHAR(length=521) - - cache_ok = True - - def process_bind_param(self, value, dialect): - """ - :param value: ExtraPlayerInfo | obj | None - :param dialect: - :return: - """ - if value is not None: - if isinstance(value, ExtraPlayerInfo): - return value.json() - raise TypeError - return value - - def process_result_value(self, value, dialect): - """ - :param value: str | obj | None - :param dialect: - :return: - """ - if value is not None: - return ExtraPlayerInfo.parse_raw(value) - return None - - -class PlayerInfo(SQLModel): - __table_args__ = ( - Index("index_user_account_player", "user_id", "player_id", unique=True), - dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"), - ) - id: Optional[int] = Field( - default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) - ) - user_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) - player_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) - nickname: Optional[str] = Field() - signature: Optional[str] = Field() - hand_image: Optional[int] = Field() - name_card: Optional[int] = Field() - extra_data: Optional[ExtraPlayerInfo] = Field(sa_column=Column(ExtraPlayerType)) - create_time: Optional[datetime] = Field( - sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102 - ) - last_save_time: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102 - is_update: Optional[bool] = Field(sa_column=Column(Boolean)) - - -class PlayerInfoSQLModel(PlayerInfo, table=True): - __tablename__ = "players_info" diff --git a/core/services/players/repositories.py b/core/services/players/repositories.py index e88221d..5f5a79c 100644 --- a/core/services/players/repositories.py +++ b/core/services/players/repositories.py @@ -1,110 +1,3 @@ -from typing import List, Optional - -from sqlmodel import select, delete - -from core.base_service import BaseService -from core.basemodel import RegionEnum -from core.dependence.database import Database -from core.services.players.models import PlayerInfoSQLModel -from core.services.players.models import PlayersDataBase as Player -from core.sqlmodel.session import AsyncSession +from gram_core.services.players.repositories import PlayersRepository, PlayerInfoRepository __all__ = ("PlayersRepository", "PlayerInfoRepository") - - -class PlayersRepository(BaseService.Component): - def __init__(self, database: Database): - self.engine = database.engine - - async def get( - self, - user_id: int, - player_id: Optional[int] = None, - account_id: Optional[int] = None, - region: Optional[RegionEnum] = None, - is_chosen: Optional[bool] = None, - ) -> Optional[Player]: - async with AsyncSession(self.engine) as session: - statement = select(Player).where(Player.user_id == user_id) - if player_id is not None: - statement = statement.where(Player.player_id == player_id) - if account_id is not None: - statement = statement.where(Player.account_id == account_id) - if region is not None: - statement = statement.where(Player.region == region) - if is_chosen is not None: - statement = statement.where(Player.is_chosen == is_chosen) - results = await session.exec(statement) - return results.first() - - async def add(self, player: Player) -> None: - async with AsyncSession(self.engine) as session: - session.add(player) - await session.commit() - await session.refresh(player) - - async def delete(self, player: Player) -> None: - async with AsyncSession(self.engine) as session: - await session.delete(player) - await session.commit() - - async def update(self, player: Player) -> None: - async with AsyncSession(self.engine) as session: - session.add(player) - await session.commit() - await session.refresh(player) - - async def get_all_by_user_id(self, user_id: int) -> List[Player]: - async with AsyncSession(self.engine) as session: - statement = select(Player).where(Player.user_id == user_id) - results = await session.exec(statement) - players = results.all() - return players - - -class PlayerInfoRepository(BaseService.Component): - def __init__(self, database: Database): - self.engine = database.engine - - async def get( - self, - user_id: int, - player_id: int, - ) -> Optional[PlayerInfoSQLModel]: - async with AsyncSession(self.engine) as session: - statement = ( - select(PlayerInfoSQLModel) - .where(PlayerInfoSQLModel.player_id == player_id) - .where(PlayerInfoSQLModel.user_id == user_id) - ) - results = await session.exec(statement) - return results.first() - - async def add(self, player: PlayerInfoSQLModel) -> None: - async with AsyncSession(self.engine) as session: - session.add(player) - await session.commit() - - async def delete(self, player: PlayerInfoSQLModel) -> None: - async with AsyncSession(self.engine) as session: - await session.delete(player) - await session.commit() - - async def delete_by_id( - self, - user_id: int, - player_id: int, - ) -> None: - async with AsyncSession(self.engine) as session: - statement = ( - delete(PlayerInfoSQLModel) - .where(PlayerInfoSQLModel.player_id == player_id) - .where(PlayerInfoSQLModel.user_id == user_id) - ) - await session.execute(statement) - - async def update(self, player: PlayerInfoSQLModel) -> None: - async with AsyncSession(self.engine) as session: - session.add(player) - await session.commit() - await session.refresh(player) diff --git a/core/services/players/services.py b/core/services/players/services.py index 8e3bc41..98d77b3 100644 --- a/core/services/players/services.py +++ b/core/services/players/services.py @@ -1,52 +1,18 @@ from datetime import datetime, timedelta -from typing import List, Optional +from typing import Optional from core.base_service import BaseService -from core.basemodel import RegionEnum from core.dependence.redisdb import RedisDB from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel, PlayerInfo -from core.services.players.repositories import PlayersRepository, PlayerInfoRepository +from core.services.players.repositories import PlayerInfoRepository from modules.apihelper.client.components.player_cards import PlayerCards, PlayerBaseInfo from utils.log import logger +from gram_core.services.players import PlayersService + __all__ = ("PlayersService", "PlayerInfoService") -class PlayersService(BaseService): - def __init__(self, players_repository: PlayersRepository) -> None: - self._repository = players_repository - - async def get( - self, - user_id: int, - player_id: Optional[int] = None, - account_id: Optional[int] = None, - region: Optional[RegionEnum] = None, - is_chosen: Optional[bool] = None, - ) -> Optional[Player]: - return await self._repository.get(user_id, player_id, account_id, region, is_chosen) - - async def get_player(self, user_id: int, region: Optional[RegionEnum] = None) -> Optional[Player]: - return await self._repository.get(user_id, region=region, is_chosen=True) - - async def add(self, player: Player) -> None: - await self._repository.add(player) - - async def update(self, player: Player) -> None: - await self._repository.update(player) - - async def get_all_by_user_id(self, user_id: int) -> List[Player]: - return await self._repository.get_all_by_user_id(user_id) - - async def remove_all_by_user_id(self, user_id: int): - players = await self._repository.get_all_by_user_id(user_id) - for player in players: - await self._repository.delete(player) - - async def delete(self, player: Player): - await self._repository.delete(player) - - class PlayerInfoService(BaseService): def __init__(self, redis: RedisDB, players_info_repository: PlayerInfoRepository): self.cache = redis.client diff --git a/core/services/task/models.py b/core/services/task/models.py index d8a3cd0..ba0f769 100644 --- a/core/services/task/models.py +++ b/core/services/task/models.py @@ -1,44 +1,3 @@ -import enum -from datetime import datetime -from typing import Optional, Dict, Any - -from sqlalchemy import func, BigInteger, JSON -from sqlmodel import Column, DateTime, Enum, Field, SQLModel, Integer +from gram_core.services.task.models import Task, TaskStatusEnum, TaskTypeEnum __all__ = ("Task", "TaskStatusEnum", "TaskTypeEnum") - - -class TaskStatusEnum(int, enum.Enum): - STATUS_SUCCESS = 0 # 任务执行成功 - INVALID_COOKIES = 1 # Cookie无效 - ALREADY_CLAIMED = 2 # 已经获取奖励 - NEED_CHALLENGE = 3 # 需要验证码 - GENSHIN_EXCEPTION = 4 # API异常 - TIMEOUT_ERROR = 5 # 请求超时 - BAD_REQUEST = 6 # 请求失败 - FORBIDDEN = 7 # 这错误一般为通知失败 机器人被用户BAN - - -class TaskTypeEnum(int, enum.Enum): - SIGN = 0 # 签到 - RESIN = 1 # 开拓力 - REALM = 2 # 洞天宝钱 - EXPEDITION = 3 # 委托 - TRANSFORMER = 4 # 参量质变仪 - CARD = 5 # 生日画片 - - -class Task(SQLModel, table=True): - __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") - id: Optional[int] = Field( - default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) - ) - user_id: int = Field(primary_key=True, sa_column=Column(BigInteger(), index=True)) - chat_id: Optional[int] = Field(default=None, sa_column=Column(BigInteger())) - time_created: Optional[datetime] = Field( - sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102 - ) - time_updated: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102 - type: TaskTypeEnum = Field(primary_key=True, sa_column=Column(Enum(TaskTypeEnum))) - status: Optional[TaskStatusEnum] = Field(sa_column=Column(Enum(TaskStatusEnum))) - data: Optional[Dict[str, Any]] = Field(sa_column=Column(JSON)) diff --git a/core/services/task/repositories.py b/core/services/task/repositories.py index c509836..728225e 100644 --- a/core/services/task/repositories.py +++ b/core/services/task/repositories.py @@ -1,50 +1,3 @@ -from typing import List, Optional - -from sqlmodel import select - -from core.base_service import BaseService -from core.dependence.database import Database -from core.services.task.models import Task, TaskTypeEnum -from core.sqlmodel.session import AsyncSession +from gram_core.services.task.repositories import TaskRepository __all__ = ("TaskRepository",) - - -class TaskRepository(BaseService.Component): - def __init__(self, database: Database): - self.engine = database.engine - - async def add(self, task: Task): - async with AsyncSession(self.engine) as session: - session.add(task) - await session.commit() - - async def remove(self, task: Task): - async with AsyncSession(self.engine) as session: - await session.delete(task) - await session.commit() - - async def update(self, task: Task) -> Task: - async with AsyncSession(self.engine) as session: - session.add(task) - await session.commit() - await session.refresh(task) - return task - - async def get_by_user_id(self, user_id: int, task_type: TaskTypeEnum) -> Optional[Task]: - async with AsyncSession(self.engine) as session: - statement = select(Task).where(Task.user_id == user_id).where(Task.type == task_type) - results = await session.exec(statement) - return results.first() - - async def get_by_chat_id(self, chat_id: int, task_type: TaskTypeEnum) -> Optional[List[Task]]: - async with AsyncSession(self.engine) as session: - statement = select(Task).where(Task.chat_id == chat_id).where(Task.type == task_type) - results = await session.exec(statement) - return results.all() - - async def get_all(self, task_type: TaskTypeEnum) -> List[Task]: - async with AsyncSession(self.engine) as session: - query = select(Task).where(Task.type == task_type) - results = await session.exec(query) - return results.all() diff --git a/core/services/task/services.py b/core/services/task/services.py index 2f19307..01add98 100644 --- a/core/services/task/services.py +++ b/core/services/task/services.py @@ -1,9 +1,10 @@ -import datetime -from typing import Optional, Dict, Any - -from core.base_service import BaseService -from core.services.task.models import Task, TaskTypeEnum -from core.services.task.repositories import TaskRepository +from gram_core.services.task.services import ( + TaskServices, + SignServices, + TaskCardServices, + TaskResinServices, + TaskExpeditionServices, +) __all__ = [ "TaskServices", @@ -12,168 +13,3 @@ __all__ = [ "TaskResinServices", "TaskExpeditionServices", ] - - -class TaskServices(BaseService): - TASK_TYPE: TaskTypeEnum - - def __init__(self, task_repository: TaskRepository) -> None: - self._repository: TaskRepository = task_repository - - async def add(self, task: Task): - return await self._repository.add(task) - - async def remove(self, task: Task): - return await self._repository.remove(task) - - async def update(self, task: Task): - task.time_updated = datetime.datetime.now() - return await self._repository.update(task) - - async def get_by_user_id(self, user_id: int): - return await self._repository.get_by_user_id(user_id, self.TASK_TYPE) - - async def get_all(self): - return await self._repository.get_all(self.TASK_TYPE) - - def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None): - return Task( - user_id=user_id, - chat_id=chat_id, - time_created=datetime.datetime.now(), - status=status, - type=self.TASK_TYPE, - data=data, - ) - - -class SignServices(BaseService): - TASK_TYPE = TaskTypeEnum.SIGN - - def __init__(self, task_repository: TaskRepository) -> None: - self._repository: TaskRepository = task_repository - - async def add(self, task: Task): - return await self._repository.add(task) - - async def remove(self, task: Task): - return await self._repository.remove(task) - - async def update(self, task: Task): - task.time_updated = datetime.datetime.now() - return await self._repository.update(task) - - async def get_by_user_id(self, user_id: int): - return await self._repository.get_by_user_id(user_id, self.TASK_TYPE) - - async def get_all(self): - return await self._repository.get_all(self.TASK_TYPE) - - def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None): - return Task( - user_id=user_id, - chat_id=chat_id, - time_created=datetime.datetime.now(), - status=status, - type=self.TASK_TYPE, - data=data, - ) - - -class TaskCardServices(BaseService): - TASK_TYPE = TaskTypeEnum.CARD - - def __init__(self, task_repository: TaskRepository) -> None: - self._repository: TaskRepository = task_repository - - async def add(self, task: Task): - return await self._repository.add(task) - - async def remove(self, task: Task): - return await self._repository.remove(task) - - async def update(self, task: Task): - task.time_updated = datetime.datetime.now() - return await self._repository.update(task) - - async def get_by_user_id(self, user_id: int): - return await self._repository.get_by_user_id(user_id, self.TASK_TYPE) - - async def get_all(self): - return await self._repository.get_all(self.TASK_TYPE) - - def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None): - return Task( - user_id=user_id, - chat_id=chat_id, - time_created=datetime.datetime.now(), - status=status, - type=self.TASK_TYPE, - data=data, - ) - - -class TaskResinServices(BaseService): - TASK_TYPE = TaskTypeEnum.RESIN - - def __init__(self, task_repository: TaskRepository) -> None: - self._repository: TaskRepository = task_repository - - async def add(self, task: Task): - return await self._repository.add(task) - - async def remove(self, task: Task): - return await self._repository.remove(task) - - async def update(self, task: Task): - task.time_updated = datetime.datetime.now() - return await self._repository.update(task) - - async def get_by_user_id(self, user_id: int): - return await self._repository.get_by_user_id(user_id, self.TASK_TYPE) - - async def get_all(self): - return await self._repository.get_all(self.TASK_TYPE) - - def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None): - return Task( - user_id=user_id, - chat_id=chat_id, - time_created=datetime.datetime.now(), - status=status, - type=self.TASK_TYPE, - data=data, - ) - - -class TaskExpeditionServices(BaseService): - TASK_TYPE = TaskTypeEnum.EXPEDITION - - def __init__(self, task_repository: TaskRepository) -> None: - self._repository: TaskRepository = task_repository - - async def add(self, task: Task): - return await self._repository.add(task) - - async def remove(self, task: Task): - return await self._repository.remove(task) - - async def update(self, task: Task): - task.time_updated = datetime.datetime.now() - return await self._repository.update(task) - - async def get_by_user_id(self, user_id: int): - return await self._repository.get_by_user_id(user_id, self.TASK_TYPE) - - async def get_all(self): - return await self._repository.get_all(self.TASK_TYPE) - - def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None): - return Task( - user_id=user_id, - chat_id=chat_id, - time_created=datetime.datetime.now(), - status=status, - type=self.TASK_TYPE, - data=data, - ) diff --git a/core/services/template/cache.py b/core/services/template/cache.py index 1adadbd..bc9047a 100644 --- a/core/services/template/cache.py +++ b/core/services/template/cache.py @@ -1,58 +1,3 @@ -import gzip -import pickle # nosec B403 -from hashlib import sha256 -from typing import Any, Optional - -from core.base_service import BaseService -from core.dependence.redisdb import RedisDB +from gram_core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache __all__ = ["TemplatePreviewCache", "HtmlToFileIdCache"] - - -class TemplatePreviewCache(BaseService.Component): - """暂存渲染模板的数据用于预览""" - - def __init__(self, redis: RedisDB): - self.client = redis.client - self.qname = "bot:template:preview" - - async def get_data(self, key: str) -> Any: - data = await self.client.get(self.cache_key(key)) - if data: - # skipcq: BAN-B301 - return pickle.loads(gzip.decompress(data)) # nosec B301 - - async def set_data(self, key: str, data: Any, ttl: int = 8 * 60 * 60): - ck = self.cache_key(key) - await self.client.set(ck, gzip.compress(pickle.dumps(data))) - if ttl != -1: - await self.client.expire(ck, ttl) - - def cache_key(self, key: str) -> str: - return f"{self.qname}:{key}" - - -class HtmlToFileIdCache(BaseService.Component): - """html to file_id 的缓存""" - - def __init__(self, redis: RedisDB): - self.client = redis.client - self.qname = "bot:template:html-to-file-id" - - async def get_data(self, html: str, file_type: str) -> Optional[str]: - data = await self.client.get(self.cache_key(html, file_type)) - if data: - return data.decode() - - async def set_data(self, html: str, file_type: str, file_id: str, ttl: int = 24 * 60 * 60): - ck = self.cache_key(html, file_type) - await self.client.set(ck, file_id) - if ttl != -1: - await self.client.expire(ck, ttl) - - async def delete_data(self, html: str, file_type: str) -> bool: - return await self.client.delete(self.cache_key(html, file_type)) - - def cache_key(self, html: str, file_type: str) -> str: - key = sha256(html.encode()).hexdigest() - return f"{self.qname}:{file_type}:{key}" diff --git a/core/services/template/error.py b/core/services/template/error.py index 197e06c..4460bd9 100644 --- a/core/services/template/error.py +++ b/core/services/template/error.py @@ -1,14 +1,8 @@ -class TemplateException(Exception): - pass +from gram_core.services.template.error import ( + ErrorFileType, + FileIdNotFound, + QuerySelectorNotFound, + TemplateException, +) - -class QuerySelectorNotFound(TemplateException): - pass - - -class ErrorFileType(TemplateException): - pass - - -class FileIdNotFound(TemplateException): - pass +__all__ = ("TemplateException", "QuerySelectorNotFound", "ErrorFileType", "FileIdNotFound") diff --git a/core/services/template/models.py b/core/services/template/models.py index a5737cd..bc929e4 100644 --- a/core/services/template/models.py +++ b/core/services/template/models.py @@ -1,146 +1,3 @@ -from enum import Enum -from typing import List, Optional, Union - -from telegram import InputMediaDocument, InputMediaPhoto, Message -from telegram.error import BadRequest - -from core.services.template.cache import HtmlToFileIdCache -from core.services.template.error import ErrorFileType, FileIdNotFound +from gram_core.services.template.models import FileType, RenderResult, RenderGroupResult __all__ = ["FileType", "RenderResult", "RenderGroupResult"] - - -class FileType(Enum): - PHOTO = 1 - DOCUMENT = 2 - - @staticmethod - def media_type(file_type: "FileType"): - """对应的 Telegram media 类型""" - if file_type == FileType.PHOTO: - return InputMediaPhoto - if file_type == FileType.DOCUMENT: - return InputMediaDocument - raise ErrorFileType - - -class RenderResult: - """渲染结果""" - - def __init__( - self, - html: str, - photo: Union[bytes, str], - file_type: FileType, - cache: HtmlToFileIdCache, - ttl: int = 24 * 60 * 60, - caption: Optional[str] = None, - parse_mode: Optional[str] = None, - filename: Optional[str] = None, - ): - """ - `html`: str 渲染生成的 html - `photo`: Union[bytes, str] 渲染生成的图片。bytes 表示是图片,str 则为 file_id - """ - self.caption = caption - self.parse_mode = parse_mode - self.filename = filename - self.html = html - self.photo = photo - self.file_type = file_type - self._cache = cache - self.ttl = ttl - - async def reply_photo(self, message: Message, *args, **kwargs): - """是 `message.reply_photo` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用""" - if self.file_type != FileType.PHOTO: - raise ErrorFileType - - try: - reply = await message.reply_photo(photo=self.photo, *args, **kwargs) - except BadRequest as exc: - if "Wrong file identifier" in exc.message and isinstance(self.photo, str): - await self._cache.delete_data(self.html, self.file_type.name) - raise BadRequest(message="Wrong file identifier specified") - raise exc - - await self.cache_file_id(reply) - - return reply - - async def reply_document(self, message: Message, *args, **kwargs): - """是 `message.reply_document` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用""" - if self.file_type != FileType.DOCUMENT: - raise ErrorFileType - - try: - reply = await message.reply_document(document=self.photo, *args, **kwargs) - except BadRequest as exc: - if "Wrong file identifier" in exc.message and isinstance(self.photo, str): - await self._cache.delete_data(self.html, self.file_type.name) - raise BadRequest(message="Wrong file identifier specified") - raise exc - - await self.cache_file_id(reply) - - return reply - - async def edit_media(self, message: Message, *args, **kwargs): - """是 `message.edit_media` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用""" - if self.file_type != FileType.PHOTO: - raise ErrorFileType - - media = InputMediaPhoto( - media=self.photo, caption=self.caption, parse_mode=self.parse_mode, filename=self.filename - ) - - try: - edit_media = await message.edit_media(media, *args, **kwargs) - except BadRequest as exc: - if "Wrong file identifier" in exc.message and isinstance(self.photo, str): - await self._cache.delete_data(self.html, self.file_type.name) - raise BadRequest(message="Wrong file identifier specified") - raise exc - - await self.cache_file_id(edit_media) - - return edit_media - - async def cache_file_id(self, reply: Message): - """缓存 telegram 返回的 file_id""" - if self.is_file_id(): - return - - if self.file_type == FileType.PHOTO and reply.photo: - file_id = reply.photo[0].file_id - elif self.file_type == FileType.DOCUMENT and reply.document: - file_id = reply.document.file_id - else: - raise FileIdNotFound - await self._cache.set_data(self.html, self.file_type.name, file_id, self.ttl) - - def is_file_id(self) -> bool: - return isinstance(self.photo, str) - - -class RenderGroupResult: - def __init__(self, results: List[RenderResult]): - self.results = results - - async def reply_media_group(self, message: Message, *args, **kwargs): - """是 `message.reply_media_group` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用""" - - reply = await message.reply_media_group( - media=[ - FileType.media_type(result.file_type)( - media=result.photo, caption=result.caption, parse_mode=result.parse_mode, filename=result.filename - ) - for result in self.results - ], - *args, - **kwargs, - ) - - for index, value in enumerate(reply): - result = self.results[index] - await result.cache_file_id(value) diff --git a/core/services/template/services.py b/core/services/template/services.py index 07ae1e1..4eed693 100644 --- a/core/services/template/services.py +++ b/core/services/template/services.py @@ -1,207 +1,3 @@ -import asyncio -from typing import Optional -from urllib.parse import urlencode, urljoin, urlsplit -from uuid import uuid4 - -from fastapi import FastAPI, HTTPException -from fastapi.responses import FileResponse, HTMLResponse -from fastapi.staticfiles import StaticFiles -from jinja2 import Environment, FileSystemLoader, Template -from playwright.async_api import ViewportSize - -from core.application import Application -from core.base_service import BaseService -from core.config import config as application_config -from core.dependence.aiobrowser import AioBrowser -from core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache -from core.services.template.error import QuerySelectorNotFound -from core.services.template.models import FileType, RenderResult -from utils.const import PROJECT_ROOT -from utils.log import logger +from gram_core.services.template.services import TemplateService, TemplatePreviewer __all__ = ("TemplateService", "TemplatePreviewer") - - -class TemplateService(BaseService): - def __init__( - self, - app: Application, - browser: AioBrowser, - html_to_file_id_cache: HtmlToFileIdCache, - preview_cache: TemplatePreviewCache, - template_dir: str = "resources", - ): - self._browser = browser - self.template_dir = PROJECT_ROOT / template_dir - - self._jinja2_env = Environment( - loader=FileSystemLoader(template_dir), - enable_async=True, - autoescape=True, - auto_reload=application_config.debug, - ) - self.using_preview = application_config.debug and application_config.webserver.enable - - if self.using_preview: - self.previewer = TemplatePreviewer(self, preview_cache, app.web_app) - - self.html_to_file_id_cache = html_to_file_id_cache - - def get_template(self, template_name: str) -> Template: - return self._jinja2_env.get_template(template_name) - - async def render_async(self, template_name: str, template_data: dict) -> str: - """模板渲染 - :param template_name: 模板文件名 - :param template_data: 模板数据 - """ - loop = asyncio.get_event_loop() - start_time = loop.time() - template = self.get_template(template_name) - html = await template.render_async(**template_data) - logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time)) - return html - - async def render( - self, - template_name: str, - template_data: dict, - viewport: Optional[ViewportSize] = None, - full_page: bool = True, - evaluate: Optional[str] = None, - query_selector: Optional[str] = None, - file_type: FileType = FileType.PHOTO, - ttl: int = 24 * 60 * 60, - caption: Optional[str] = None, - parse_mode: Optional[str] = None, - filename: Optional[str] = None, - ) -> RenderResult: - """模板渲染成图片 - :param template_name: 模板文件名 - :param template_data: 模板数据 - :param viewport: 截图大小 - :param full_page: 是否长截图 - :param evaluate: 页面加载后运行的 js - :param query_selector: 截图选择器 - :param file_type: 缓存的文件类型 - :param ttl: 缓存时间 - :param caption: 图片描述 - :param parse_mode: 图片描述解析模式 - :param filename: 文件名字 - :return: - """ - loop = asyncio.get_event_loop() - start_time = loop.time() - template = self.get_template(template_name) - - if self.using_preview: - preview_url = await self.previewer.get_preview_url(template_name, template_data) - logger.debug("调试模板 URL: \n%s", preview_url) - - html = await template.render_async(**template_data) - logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time)) - - file_id = await self.html_to_file_id_cache.get_data(html, file_type.name) - if file_id and not application_config.debug: - logger.debug("%s 命中缓存,返回 file_id[%s]", template_name, file_id) - return RenderResult( - html=html, - photo=file_id, - file_type=file_type, - cache=self.html_to_file_id_cache, - ttl=ttl, - caption=caption, - parse_mode=parse_mode, - filename=filename, - ) - - browser = await self._browser.get_browser() - start_time = loop.time() - page = await browser.new_page(viewport=viewport) - uri = (PROJECT_ROOT / template.filename).as_uri() - await page.goto(uri) - await page.set_content(html, wait_until="networkidle") - if evaluate: - await page.evaluate(evaluate) - clip = None - if query_selector: - try: - card = await page.query_selector(query_selector) - if not card: - raise QuerySelectorNotFound - clip = await card.bounding_box() - if not clip: - raise QuerySelectorNotFound - except QuerySelectorNotFound: - logger.warning("未找到 %s 元素", query_selector) - png_data = await page.screenshot(clip=clip, full_page=full_page) - await page.close() - logger.debug("%s 图片渲染使用了 %s", template_name, str(loop.time() - start_time)) - return RenderResult( - html=html, - photo=png_data, - file_type=file_type, - cache=self.html_to_file_id_cache, - ttl=ttl, - caption=caption, - parse_mode=parse_mode, - filename=filename, - ) - - -class TemplatePreviewer(BaseService, load=application_config.webserver.enable and application_config.debug): - def __init__( - self, - template_service: TemplateService, - cache: TemplatePreviewCache, - web_app: FastAPI, - ): - self.web_app = web_app - self.template_service = template_service - self.cache = cache - self.register_routes() - - async def get_preview_url(self, template: str, data: dict): - """获取预览 URL""" - components = urlsplit(application_config.webserver.url) - path = urljoin("/preview/", template) - query = {} - - # 如果有数据,暂存在 redis 中 - if data: - key = str(uuid4()) - await self.cache.set_data(key, data) - query["key"] = key - - # noinspection PyProtectedMember - return components._replace(path=path, query=urlencode(query)).geturl() - - def register_routes(self): - """注册预览用到的路由""" - - @self.web_app.get("/preview/{path:path}") - async def preview_template(path: str, key: Optional[str] = None): # pylint: disable=W0612 - # 如果是 /preview/ 开头的静态文件,直接返回内容。比如使用相对链接 ../ 引入的静态资源 - if not path.endswith((".html", ".jinja2")): - full_path = self.template_service.template_dir / path - if not full_path.is_file(): - raise HTTPException(status_code=404, detail=f"Template '{path}' not found") - return FileResponse(full_path) - - # 取回暂存的渲染数据 - data = await self.cache.get_data(key) if key else {} - if key and data is None: - raise HTTPException(status_code=404, detail=f"Template data {key} not found") - - # 渲染 jinja2 模板 - html = await self.template_service.render_async(path, data) - # 将本地 URL file:// 修改为 HTTP url,因为浏览器内不允许加载本地文件 - # file:///project_dir/cache/image.jpg => /cache/image.jpg - html = html.replace(PROJECT_ROOT.as_uri(), "") - return HTMLResponse(html) - - # 其他静态资源 - for name in ["cache", "resources"]: - directory = PROJECT_ROOT / name - directory.mkdir(exist_ok=True) - self.web_app.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name) diff --git a/core/services/users/cache.py b/core/services/users/cache.py index 121e024..52ded7c 100644 --- a/core/services/users/cache.py +++ b/core/services/users/cache.py @@ -1,24 +1,3 @@ -from typing import List - -from core.base_service import BaseService -from core.dependence.redisdb import RedisDB +from gram_core.services.users.cache import UserAdminCache __all__ = ("UserAdminCache",) - - -class UserAdminCache(BaseService.Component): - def __init__(self, redis: RedisDB): - self.client = redis.client - self.qname = "users:admin" - - async def ismember(self, user_id: int) -> bool: - return await self.client.sismember(self.qname, user_id) - - async def get_all(self) -> List[int]: - return [int(str_data) for str_data in await self.client.smembers(self.qname)] - - async def set(self, user_id: int) -> bool: - return await self.client.sadd(self.qname, user_id) - - async def remove(self, user_id: int) -> bool: - return await self.client.srem(self.qname, user_id) diff --git a/core/services/users/models.py b/core/services/users/models.py index 5d5fa1c..e5cdc61 100644 --- a/core/services/users/models.py +++ b/core/services/users/models.py @@ -1,34 +1,7 @@ -import enum -from datetime import datetime -from typing import Optional - -from sqlmodel import SQLModel, Field, DateTime, Column, Enum, BigInteger, Integer +from gram_core.services.users.models import User, UserDataBase, PermissionsEnum __all__ = ( "User", "UserDataBase", "PermissionsEnum", ) - - -class PermissionsEnum(int, enum.Enum): - OWNER = 1 - ADMIN = 2 - PUBLIC = 3 - - -class User(SQLModel): - __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") - id: Optional[int] = Field( - default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) - ) - user_id: int = Field(unique=True, sa_column=Column(BigInteger())) - permissions: Optional[PermissionsEnum] = Field(sa_column=Column(Enum(PermissionsEnum))) - locale: Optional[str] = Field() - ban_end_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True))) - ban_start_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True))) - is_banned: Optional[int] = Field() - - -class UserDataBase(User, table=True): - __tablename__ = "users" diff --git a/core/services/users/repositories.py b/core/services/users/repositories.py index 64bb090..18fb418 100644 --- a/core/services/users/repositories.py +++ b/core/services/users/repositories.py @@ -1,44 +1,3 @@ -from typing import Optional, List - -from sqlmodel import select - -from core.base_service import BaseService -from core.dependence.database import Database -from core.services.users.models import UserDataBase as User -from core.sqlmodel.session import AsyncSession +from gram_core.services.users.repositories import UserRepository __all__ = ("UserRepository",) - - -class UserRepository(BaseService.Component): - def __init__(self, database: Database): - self.engine = database.engine - - async def get_by_user_id(self, user_id: int) -> Optional[User]: - async with AsyncSession(self.engine) as session: - statement = select(User).where(User.user_id == user_id) - results = await session.exec(statement) - return results.first() - - async def add(self, user: User): - async with AsyncSession(self.engine) as session: - session.add(user) - await session.commit() - - async def update(self, user: User) -> User: - async with AsyncSession(self.engine) as session: - session.add(user) - await session.commit() - await session.refresh(user) - return user - - async def remove(self, user: User): - async with AsyncSession(self.engine) as session: - await session.delete(user) - await session.commit() - - async def get_all(self) -> List[User]: - async with AsyncSession(self.engine) as session: - statement = select(User) - results = await session.exec(statement) - return results.all() diff --git a/core/services/users/services.py b/core/services/users/services.py index 84461cd..692dc49 100644 --- a/core/services/users/services.py +++ b/core/services/users/services.py @@ -1,83 +1,3 @@ -from typing import List, Optional - -from core.base_service import BaseService -from core.config import config -from core.services.users.cache import UserAdminCache -from core.services.users.models import PermissionsEnum, UserDataBase as User -from core.services.users.repositories import UserRepository +from gram_core.services.users.services import UserService, UserAdminService __all__ = ("UserService", "UserAdminService") - -from utils.log import logger - - -class UserService(BaseService): - def __init__(self, user_repository: UserRepository) -> None: - self._repository: UserRepository = user_repository - - async def get_user_by_id(self, user_id: int) -> Optional[User]: - """从数据库获取用户信息 - :param user_id:用户ID - :return: User - """ - return await self._repository.get_by_user_id(user_id) - - async def remove(self, user: User): - return await self._repository.remove(user) - - async def update_user(self, user: User): - return await self._repository.add(user) - - -class UserAdminService(BaseService): - def __init__(self, user_repository: UserRepository, cache: UserAdminCache): - self.user_repository = user_repository - self._cache = cache - - async def initialize(self): - owner = config.owner - if owner: - user = await self.user_repository.get_by_user_id(owner) - if user: - if user.permissions != PermissionsEnum.OWNER: - user.permissions = PermissionsEnum.OWNER - await self._cache.set(user.user_id) - await self.user_repository.update(user) - else: - user = User(user_id=owner, permissions=PermissionsEnum.OWNER) - await self._cache.set(user.user_id) - await self.user_repository.add(user) - else: - logger.warning("检测到未配置Bot所有者 会导无法正常使用管理员权限") - users = await self.user_repository.get_all() - for user in users: - await self._cache.set(user.user_id) - - async def is_admin(self, user_id: int) -> bool: - return await self._cache.ismember(user_id) - - async def get_admin_list(self) -> List[int]: - return await self._cache.get_all() - - async def add_admin(self, user_id: int) -> bool: - user = await self.user_repository.get_by_user_id(user_id) - if user: - if user.permissions == PermissionsEnum.OWNER: - return False - if user.permissions != PermissionsEnum.ADMIN: - user.permissions = PermissionsEnum.ADMIN - await self.user_repository.update(user) - else: - user = User(user_id=user_id, permissions=PermissionsEnum.ADMIN) - await self.user_repository.add(user) - return await self._cache.set(user_id) - - async def delete_admin(self, user_id: int) -> bool: - user = await self.user_repository.get_by_user_id(user_id) - if user: - if user.permissions == PermissionsEnum.OWNER: - return True # 假装移除成功 - user.permissions = PermissionsEnum.PUBLIC - await self.user_repository.update(user) - return await self._cache.remove(user.user_id) - return False diff --git a/core/override/__init__.py b/core/sqlmodel/__init__.py similarity index 100% rename from core/override/__init__.py rename to core/sqlmodel/__init__.py diff --git a/core/sqlmodel/session.py b/core/sqlmodel/session.py index 88e4d3d..92df4a9 100644 --- a/core/sqlmodel/session.py +++ b/core/sqlmodel/session.py @@ -1,118 +1,3 @@ -from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload - -from sqlalchemy import util -from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession -from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine -from sqlalchemy.sql.base import Executable as _Executable -from sqlmodel.engine.result import Result, ScalarResult -from sqlmodel.orm.session import Session -from sqlmodel.sql.base import Executable -from sqlmodel.sql.expression import Select, SelectOfScalar -from typing_extensions import Literal - -_TSelectParam = TypeVar("_TSelectParam") +from gram_core.sqlmodel.session import AsyncSession __all__ = ("AsyncSession",) - - -class AsyncSession(_AsyncSession): # pylint: disable=W0223 - sync_session_class = Session - sync_session: Session - - def __init__( - self, - bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, - binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, - sync_session_class: Type[Session] = Session, - **kw: Any, - ): - super().__init__( - bind=bind, - binds=binds, - sync_session_class=sync_session_class, - **kw, - ) - - @overload - async def exec( - self, - statement: Select[_TSelectParam], - *, - params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, - **kw: Any, - ) -> Result[_TSelectParam]: - ... - - @overload - async def exec( - self, - statement: SelectOfScalar[_TSelectParam], - *, - params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, - **kw: Any, - ) -> ScalarResult[_TSelectParam]: - ... - - async def exec( - self, - statement: Union[ - Select[_TSelectParam], - SelectOfScalar[_TSelectParam], - Executable[_TSelectParam], - ], - *, - params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, - **kw: Any, - ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]: - results = super().execute( - statement, - params=params, - execution_options=execution_options, - bind_arguments=bind_arguments, - **kw, - ) - if isinstance(statement, SelectOfScalar): - return (await results).scalars() # type: ignore - return await results # type: ignore - - async def execute( # pylint: disable=W0221 - self, - statement: _Executable, - params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, - **kw: Any, - ) -> Result[Any]: - return await super().execute( # type: ignore - statement=statement, - params=params, - execution_options=execution_options, - bind_arguments=bind_arguments, - **kw, - ) - - async def get( # pylint: disable=W0221 - self, - entity: Type[_TSelectParam], - ident: Any, - options: Optional[Sequence[Any]] = None, - populate_existing: bool = False, - with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, - identity_token: Optional[Any] = None, - execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT, - ) -> Optional[_TSelectParam]: - return await super().get( - entity=entity, - ident=ident, - options=options, - populate_existing=populate_existing, - with_for_update=with_for_update, - identity_token=identity_token, - execution_options=execution_options, - ) diff --git a/poetry.lock b/poetry.lock index af5f9cf..d291f14 100644 --- a/poetry.lock +++ b/poetry.lock @@ -606,6 +606,21 @@ files = [ [package.dependencies] gitdb = ">=4.0.1,<5" +[[package]] +name = "gram-core" +version = "0.1.0" +description = "telegram robot base core." +optional = false +python-versions = ">=3.8" +files = [] +develop = false + +[package.source] +type = "git" +url = "https://github.com/PaiGramTeam/GramCore.git" +reference = "HEAD" +resolved_reference = "7fb5d4c0e01731e6901829fe317be19023c2c4c7" + [[package]] name = "greenlet" version = "1.1.3" @@ -2316,4 +2331,4 @@ test = ["flaky", "pytest", "pytest-asyncio"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "8378e2149d042db3c75d422e969b658c255654e517c32bdb60b6dafef01e071e" +content-hash = "847b0e55ac501bd6126b5370fc08c018401147d8dd5c68c3216833b64b1213ea" diff --git a/pyproject.toml b/pyproject.toml index e8b2e51..c1ecfd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ pillow = "^10.0.0" playwright = "^1.27.1" aiosqlite = { extras = ["sqlite"], version = "^0.19.0" } simnet = { git = "https://github.com/PaiGramTeam/SIMNet" } +gram-core = {git = "https://github.com/PaiGramTeam/GramCore.git"} [tool.poetry.extras] pyro = ["Pyrogram", "TgCrypto"] diff --git a/requirements.txt b/requirements.txt index d7d0d6d..c4835be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,6 +26,7 @@ fastapi==0.99.1 ; python_version >= "3.8" and python_version < "4.0" flaky==3.7.0 ; python_version >= "3.8" and python_version < "4.0" gitdb==4.0.10 ; python_version >= "3.8" and python_version < "4.0" gitpython==3.1.32 ; python_version >= "3.8" and python_version < "4.0" +gram-core @ git+https://github.com/PaiGramTeam/GramCore.git@main ; python_version >= "3.8" and python_version < "4.0" greenlet==1.1.3 ; python_version >= "3.8" and python_version < "4.0" h11==0.14.0 ; python_version >= "3.8" and python_version < "4.0" httpcore==0.17.3 ; python_version >= "3.8" and python_version < "4.0" diff --git a/run.py b/run.py index ce733e3..357bf63 100644 --- a/run.py +++ b/run.py @@ -19,7 +19,7 @@ def run(): def main(): - from core.builtins.reloader import Reloader + from gram_core.builtins.reloader import Reloader from core.config import config if config.auto_reload: # 是否启动重载器