mirror of
https://github.com/PaiGramTeam/PamGram.git
synced 2024-11-21 13:48:19 +00:00
♻️ separate core code
Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
This commit is contained in:
parent
b13df44e19
commit
b1141c65b8
@ -1,289 +1,5 @@
|
|||||||
"""BOT"""
|
"""BOT"""
|
||||||
import asyncio
|
|
||||||
import signal
|
|
||||||
from functools import wraps
|
|
||||||
from signal import SIGABRT, SIGINT, SIGTERM, signal as signal_func
|
|
||||||
from ssl import SSLZeroReturnError
|
|
||||||
from typing import Callable, List, Optional, TYPE_CHECKING, TypeVar
|
|
||||||
|
|
||||||
import pytz
|
from gram_core.application import Application
|
||||||
import uvicorn
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from telegram import Bot, Update
|
|
||||||
from telegram.error import NetworkError, TelegramError, TimedOut
|
|
||||||
from telegram.ext import (
|
|
||||||
Application as TelegramApplication,
|
|
||||||
ApplicationBuilder as TelegramApplicationBuilder,
|
|
||||||
Defaults,
|
|
||||||
JobQueue,
|
|
||||||
)
|
|
||||||
from typing_extensions import ParamSpec
|
|
||||||
from uvicorn import Server
|
|
||||||
|
|
||||||
from core.config import config as application_config
|
|
||||||
from core.handler.limiterhandler import LimiterHandler
|
|
||||||
from core.manager import Managers
|
|
||||||
from core.override.telegram import HTTPXRequest
|
|
||||||
from core.ratelimiter import RateLimiter
|
|
||||||
from utils.const import WRAPPER_ASSIGNMENTS
|
|
||||||
from utils.log import logger
|
|
||||||
from utils.models.signal import Singleton
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from asyncio import Task
|
|
||||||
from types import FrameType
|
|
||||||
|
|
||||||
__all__ = ("Application",)
|
__all__ = ("Application",)
|
||||||
|
|
||||||
R = TypeVar("R")
|
|
||||||
T = TypeVar("T")
|
|
||||||
P = ParamSpec("P")
|
|
||||||
|
|
||||||
|
|
||||||
class Application(Singleton):
|
|
||||||
"""Application"""
|
|
||||||
|
|
||||||
_web_server_task: Optional["Task"] = None
|
|
||||||
|
|
||||||
_startup_funcs: List[Callable] = []
|
|
||||||
_shutdown_funcs: List[Callable] = []
|
|
||||||
|
|
||||||
def __init__(self, managers: "Managers", telegram: "TelegramApplication", web_server: "Server") -> None:
|
|
||||||
self._running = False
|
|
||||||
self.managers = managers
|
|
||||||
self.telegram = telegram
|
|
||||||
self.web_server = web_server
|
|
||||||
self.managers.set_application(application=self) # 给 managers 设置 application
|
|
||||||
self.managers.build_executor("Application")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def build(cls):
|
|
||||||
managers = Managers()
|
|
||||||
telegram = (
|
|
||||||
TelegramApplicationBuilder()
|
|
||||||
.get_updates_read_timeout(application_config.update_read_timeout)
|
|
||||||
.get_updates_write_timeout(application_config.update_write_timeout)
|
|
||||||
.get_updates_connect_timeout(application_config.update_connect_timeout)
|
|
||||||
.get_updates_pool_timeout(application_config.update_pool_timeout)
|
|
||||||
.defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai")))
|
|
||||||
.token(application_config.bot_token)
|
|
||||||
.request(
|
|
||||||
HTTPXRequest(
|
|
||||||
connection_pool_size=application_config.connection_pool_size,
|
|
||||||
proxy_url=application_config.proxy_url,
|
|
||||||
read_timeout=application_config.read_timeout,
|
|
||||||
write_timeout=application_config.write_timeout,
|
|
||||||
connect_timeout=application_config.connect_timeout,
|
|
||||||
pool_timeout=application_config.pool_timeout,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.rate_limiter(RateLimiter())
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
web_server = Server(
|
|
||||||
uvicorn.Config(
|
|
||||||
app=FastAPI(debug=application_config.debug),
|
|
||||||
port=application_config.webserver.port,
|
|
||||||
host=application_config.webserver.host,
|
|
||||||
log_config=None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return cls(managers, telegram, web_server)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def running(self) -> bool:
|
|
||||||
"""bot 是否正在运行"""
|
|
||||||
with self._lock:
|
|
||||||
return self._running
|
|
||||||
|
|
||||||
@property
|
|
||||||
def web_app(self) -> FastAPI:
|
|
||||||
"""fastapi app"""
|
|
||||||
return self.web_server.config.app
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bot(self) -> Optional[Bot]:
|
|
||||||
return self.telegram.bot
|
|
||||||
|
|
||||||
@property
|
|
||||||
def job_queue(self) -> Optional[JobQueue]:
|
|
||||||
return self.telegram.job_queue
|
|
||||||
|
|
||||||
async def _on_startup(self) -> None:
|
|
||||||
for func in self._startup_funcs:
|
|
||||||
await self.managers.executor(func, block=getattr(func, "block", False))
|
|
||||||
|
|
||||||
async def _on_shutdown(self) -> None:
|
|
||||||
for func in self._shutdown_funcs:
|
|
||||||
await self.managers.executor(func, block=getattr(func, "block", False))
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
"""BOT 初始化"""
|
|
||||||
self.telegram.add_handler(LimiterHandler(limit_time=10), group=-1) # 启用入口洪水限制
|
|
||||||
await self.managers.start_dependency() # 启动基础服务
|
|
||||||
await self.managers.init_components() # 实例化组件
|
|
||||||
await self.managers.start_services() # 启动其他服务
|
|
||||||
await self.managers.install_plugins() # 安装插件
|
|
||||||
|
|
||||||
async def shutdown(self):
|
|
||||||
"""BOT 关闭"""
|
|
||||||
await self.managers.uninstall_plugins() # 卸载插件
|
|
||||||
await self.managers.stop_services() # 终止其他服务
|
|
||||||
await self.managers.stop_dependency() # 终止基础服务
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
"""启动 BOT"""
|
|
||||||
logger.info("正在启动 BOT 中...")
|
|
||||||
|
|
||||||
def error_callback(exc: TelegramError) -> None:
|
|
||||||
"""错误信息回调"""
|
|
||||||
self.telegram.create_task(self.telegram.process_error(error=exc, update=None))
|
|
||||||
|
|
||||||
await self.telegram.initialize()
|
|
||||||
logger.info("[blue]Telegram[/] 初始化成功", extra={"markup": True})
|
|
||||||
|
|
||||||
if application_config.webserver.enable: # 如果使用 web app
|
|
||||||
server_config = self.web_server.config
|
|
||||||
server_config.setup_event_loop()
|
|
||||||
if not server_config.loaded:
|
|
||||||
server_config.load()
|
|
||||||
self.web_server.lifespan = server_config.lifespan_class(server_config)
|
|
||||||
try:
|
|
||||||
await self.web_server.startup()
|
|
||||||
except OSError as e:
|
|
||||||
if e.errno == 10048:
|
|
||||||
logger.error("Web Server 端口被占用:%s", e)
|
|
||||||
logger.error("Web Server 启动失败,正在退出")
|
|
||||||
raise SystemExit from None
|
|
||||||
|
|
||||||
if self.web_server.should_exit:
|
|
||||||
logger.error("Web Server 启动失败,正在退出")
|
|
||||||
raise SystemExit from None
|
|
||||||
logger.success("Web Server 启动成功")
|
|
||||||
|
|
||||||
self._web_server_task = asyncio.create_task(self.web_server.main_loop())
|
|
||||||
|
|
||||||
for _ in range(5): # 连接至 telegram 服务器
|
|
||||||
try:
|
|
||||||
await self.telegram.updater.start_polling(
|
|
||||||
error_callback=error_callback, allowed_updates=Update.ALL_TYPES
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except TimedOut:
|
|
||||||
logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
|
|
||||||
continue
|
|
||||||
except NetworkError as e:
|
|
||||||
logger.exception()
|
|
||||||
if isinstance(e, SSLZeroReturnError):
|
|
||||||
logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
|
||||||
else:
|
|
||||||
logger.error("网络连接出现问题, 请检查您的网络状况.")
|
|
||||||
raise SystemExit from e
|
|
||||||
|
|
||||||
await self.initialize()
|
|
||||||
logger.success("BOT 初始化成功")
|
|
||||||
logger.debug("BOT 开始启动")
|
|
||||||
|
|
||||||
await self._on_startup()
|
|
||||||
await self.telegram.start()
|
|
||||||
self._running = True
|
|
||||||
logger.success("BOT 启动成功")
|
|
||||||
|
|
||||||
def stop_signal_handler(self, signum: int):
|
|
||||||
"""终止信号处理"""
|
|
||||||
signals = {k: v for v, k in signal.__dict__.items() if v.startswith("SIG") and not v.startswith("SIG_")}
|
|
||||||
logger.debug("接收到了终止信号 %s 正在退出...", signals[signum])
|
|
||||||
if self._web_server_task:
|
|
||||||
self._web_server_task.cancel()
|
|
||||||
|
|
||||||
async def idle(self) -> None:
|
|
||||||
"""在接收到中止信号之前,堵塞loop"""
|
|
||||||
|
|
||||||
task = None
|
|
||||||
|
|
||||||
def stop_handler(signum: int, _: "FrameType") -> None:
|
|
||||||
self.stop_signal_handler(signum)
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
for s in (SIGINT, SIGTERM, SIGABRT):
|
|
||||||
signal_func(s, stop_handler)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
task = asyncio.create_task(asyncio.sleep(600))
|
|
||||||
|
|
||||||
try:
|
|
||||||
await task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
break
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
"""关闭"""
|
|
||||||
logger.info("BOT 正在关闭")
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
await self._on_shutdown()
|
|
||||||
|
|
||||||
if self.telegram.updater.running:
|
|
||||||
await self.telegram.updater.stop()
|
|
||||||
|
|
||||||
await self.shutdown()
|
|
||||||
|
|
||||||
if self.telegram.running:
|
|
||||||
await self.telegram.stop()
|
|
||||||
|
|
||||||
await self.telegram.shutdown()
|
|
||||||
if self.web_server is not None:
|
|
||||||
try:
|
|
||||||
await self.web_server.shutdown()
|
|
||||||
logger.info("Web Server 已经关闭")
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger.success("BOT 关闭成功")
|
|
||||||
|
|
||||||
def launch(self) -> None:
|
|
||||||
"""启动"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(self.start())
|
|
||||||
loop.run_until_complete(self.idle())
|
|
||||||
except (SystemExit, KeyboardInterrupt) as exc:
|
|
||||||
logger.debug("接收到了终止信号,BOT 即将关闭", exc_info=exc) # 接收到了终止信号
|
|
||||||
except NetworkError as e:
|
|
||||||
if isinstance(e, SSLZeroReturnError):
|
|
||||||
logger.critical("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
|
||||||
else:
|
|
||||||
logger.critical("网络连接出现问题, 请检查您的网络状况.")
|
|
||||||
except Exception as e:
|
|
||||||
logger.critical("遇到了未知错误: %s", {type(e)}, exc_info=e)
|
|
||||||
finally:
|
|
||||||
loop.run_until_complete(self.stop())
|
|
||||||
|
|
||||||
if application_config.reload:
|
|
||||||
raise SystemExit from None
|
|
||||||
|
|
||||||
def on_startup(self, func: Callable[P, R]) -> Callable[P, R]:
|
|
||||||
"""注册一个在 BOT 启动时执行的函数"""
|
|
||||||
|
|
||||||
if func not in self._startup_funcs:
|
|
||||||
self._startup_funcs.append(func)
|
|
||||||
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
|
||||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
def on_shutdown(self, func: Callable[P, R]) -> Callable[P, R]:
|
|
||||||
"""注册一个在 BOT 停止时执行的函数"""
|
|
||||||
|
|
||||||
if func not in self._shutdown_funcs:
|
|
||||||
self._shutdown_funcs.append(func)
|
|
||||||
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
|
||||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
@ -1,60 +1,3 @@
|
|||||||
from abc import ABC
|
from gram_core.base_service import BaseService, BaseServiceType, DependenceType, ComponentType, get_all_services
|
||||||
from itertools import chain
|
|
||||||
from typing import ClassVar, Iterable, Type, TypeVar
|
|
||||||
|
|
||||||
from typing_extensions import Self
|
|
||||||
|
|
||||||
from utils.helpers import isabstract
|
|
||||||
|
|
||||||
__all__ = ("BaseService", "BaseServiceType", "DependenceType", "ComponentType", "get_all_services")
|
__all__ = ("BaseService", "BaseServiceType", "DependenceType", "ComponentType", "get_all_services")
|
||||||
|
|
||||||
|
|
||||||
class _BaseService:
|
|
||||||
"""服务基类"""
|
|
||||||
|
|
||||||
_is_component: ClassVar[bool] = False
|
|
||||||
_is_dependence: ClassVar[bool] = False
|
|
||||||
|
|
||||||
def __init_subclass__(cls, load: bool = True, **kwargs):
|
|
||||||
cls.is_dependence = cls._is_dependence
|
|
||||||
cls.is_component = cls._is_component
|
|
||||||
cls.load = load
|
|
||||||
|
|
||||||
async def __aenter__(self) -> Self:
|
|
||||||
await self.initialize()
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
|
||||||
await self.shutdown()
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
"""Initialize resources used by this service"""
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
"""Stop & clear resources used by this service"""
|
|
||||||
|
|
||||||
|
|
||||||
class _Dependence(_BaseService, ABC):
|
|
||||||
_is_dependence: ClassVar[bool] = True
|
|
||||||
|
|
||||||
|
|
||||||
class _Component(_BaseService, ABC):
|
|
||||||
_is_component: ClassVar[bool] = True
|
|
||||||
|
|
||||||
|
|
||||||
class BaseService(_BaseService, ABC):
|
|
||||||
Dependence: Type[_BaseService] = _Dependence
|
|
||||||
Component: Type[_BaseService] = _Component
|
|
||||||
|
|
||||||
|
|
||||||
BaseServiceType = TypeVar("BaseServiceType", bound=_BaseService)
|
|
||||||
DependenceType = TypeVar("DependenceType", bound=_Dependence)
|
|
||||||
ComponentType = TypeVar("ComponentType", bound=_Component)
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
def get_all_services() -> Iterable[Type[_BaseService]]:
|
|
||||||
return filter(
|
|
||||||
lambda x: x.__name__[0] != "_" and x.load and not isabstract(x),
|
|
||||||
chain(BaseService.__subclasses__(), _Dependence.__subclasses__(), _Component.__subclasses__()),
|
|
||||||
)
|
|
||||||
|
@ -1,29 +1,3 @@
|
|||||||
import enum
|
from gram_core.basemodel import RegionEnum, Settings
|
||||||
|
|
||||||
try:
|
|
||||||
import ujson as jsonlib
|
|
||||||
except ImportError:
|
|
||||||
import json as jsonlib
|
|
||||||
|
|
||||||
from pydantic import BaseSettings
|
|
||||||
|
|
||||||
__all__ = ("RegionEnum", "Settings")
|
__all__ = ("RegionEnum", "Settings")
|
||||||
|
|
||||||
|
|
||||||
class RegionEnum(int, enum.Enum):
|
|
||||||
"""账号数据所在服务器"""
|
|
||||||
|
|
||||||
NULL = 0
|
|
||||||
HYPERION = 1 # 米忽悠国服 hyperion
|
|
||||||
HOYOLAB = 2 # 米忽悠国际服 hoyolab
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
def __new__(cls, *args, **kwargs):
|
|
||||||
cls.update_forward_refs()
|
|
||||||
return super(Settings, cls).__new__(cls) # pylint: disable=E1120
|
|
||||||
|
|
||||||
class Config(BaseSettings.Config):
|
|
||||||
case_sensitive = False
|
|
||||||
json_loads = jsonlib.loads
|
|
||||||
json_dumps = jsonlib.dumps
|
|
||||||
|
@ -1 +0,0 @@
|
|||||||
"""bot builtins"""
|
|
@ -1,38 +0,0 @@
|
|||||||
"""上下文管理"""
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from contextvars import ContextVar
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from telegram.ext import CallbackContext
|
|
||||||
from telegram import Update
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"CallbackContextCV",
|
|
||||||
"UpdateCV",
|
|
||||||
"handler_contexts",
|
|
||||||
"job_contexts",
|
|
||||||
]
|
|
||||||
|
|
||||||
CallbackContextCV: ContextVar["CallbackContext"] = ContextVar("TelegramContextCallback")
|
|
||||||
UpdateCV: ContextVar["Update"] = ContextVar("TelegramUpdate")
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def handler_contexts(update: "Update", context: "CallbackContext") -> None:
|
|
||||||
context_token = CallbackContextCV.set(context)
|
|
||||||
update_token = UpdateCV.set(update)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
CallbackContextCV.reset(context_token)
|
|
||||||
UpdateCV.reset(update_token)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def job_contexts(context: "CallbackContext") -> None:
|
|
||||||
token = CallbackContextCV.set(context)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
CallbackContextCV.reset(token)
|
|
@ -1,309 +0,0 @@
|
|||||||
"""参数分发器"""
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from asyncio import AbstractEventLoop
|
|
||||||
from functools import cached_property, lru_cache, partial, wraps
|
|
||||||
from inspect import Parameter, Signature
|
|
||||||
from itertools import chain
|
|
||||||
from types import GenericAlias, MethodType
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from arkowrapper import ArkoWrapper
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from telegram import Bot as TelegramBot, Chat, Message, Update, User
|
|
||||||
from telegram.ext import Application as TelegramApplication, CallbackContext, Job
|
|
||||||
from typing_extensions import ParamSpec
|
|
||||||
from uvicorn import Server
|
|
||||||
|
|
||||||
from core.application import Application
|
|
||||||
from utils.const import WRAPPER_ASSIGNMENTS
|
|
||||||
from utils.typedefs import R, T
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"catch",
|
|
||||||
"AbstractDispatcher",
|
|
||||||
"BaseDispatcher",
|
|
||||||
"HandlerDispatcher",
|
|
||||||
"JobDispatcher",
|
|
||||||
"dispatched",
|
|
||||||
)
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
|
|
||||||
TargetType = Union[Type, str, Callable[[Any], bool]]
|
|
||||||
|
|
||||||
_CATCH_TARGET_ATTR = "_catch_targets"
|
|
||||||
|
|
||||||
|
|
||||||
def catch(*targets: Union[str, Type]) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
|
||||||
def decorate(func: Callable[P, R]) -> Callable[P, R]:
|
|
||||||
setattr(func, _CATCH_TARGET_ATTR, targets)
|
|
||||||
|
|
||||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
|
||||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorate
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(64)
|
|
||||||
def get_signature(func: Union[type, Callable]) -> Signature:
|
|
||||||
if isinstance(func, type):
|
|
||||||
return inspect.signature(func.__init__)
|
|
||||||
return inspect.signature(func)
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractDispatcher(ABC):
|
|
||||||
"""参数分发器"""
|
|
||||||
|
|
||||||
IGNORED_ATTRS = []
|
|
||||||
|
|
||||||
_args: List[Any] = []
|
|
||||||
_kwargs: Dict[Union[str, Type], Any] = {}
|
|
||||||
_application: "Optional[Application]" = None
|
|
||||||
|
|
||||||
def set_application(self, application: "Application") -> None:
|
|
||||||
self._application = application
|
|
||||||
|
|
||||||
@property
|
|
||||||
def application(self) -> "Application":
|
|
||||||
if self._application is None:
|
|
||||||
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
|
||||||
return self._application
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
|
||||||
self._args = list(args)
|
|
||||||
self._kwargs = dict(kwargs)
|
|
||||||
|
|
||||||
for _, value in kwargs.items():
|
|
||||||
type_arg = type(value)
|
|
||||||
if type_arg != str:
|
|
||||||
self._kwargs[type_arg] = value
|
|
||||||
|
|
||||||
for arg in args:
|
|
||||||
type_arg = type(arg)
|
|
||||||
if type_arg != str:
|
|
||||||
self._kwargs[type_arg] = arg
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def catch_funcs(self) -> List[MethodType]:
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
return list(
|
|
||||||
ArkoWrapper(dir(self))
|
|
||||||
.filter(lambda x: not x.startswith("_"))
|
|
||||||
.filter(
|
|
||||||
lambda x: x not in self.IGNORED_ATTRS + ["dispatch", "catch_funcs", "catch_func_map", "dispatch_funcs"]
|
|
||||||
)
|
|
||||||
.map(lambda x: getattr(self, x))
|
|
||||||
.filter(lambda x: isinstance(x, MethodType))
|
|
||||||
.filter(lambda x: hasattr(x, "_catch_targets"))
|
|
||||||
)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def catch_func_map(self) -> Dict[Union[str, Type[T]], Callable[..., T]]:
|
|
||||||
result = {}
|
|
||||||
for catch_func in self.catch_funcs:
|
|
||||||
catch_targets = getattr(catch_func, _CATCH_TARGET_ATTR)
|
|
||||||
for catch_target in catch_targets:
|
|
||||||
result[catch_target] = catch_func
|
|
||||||
return result
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def dispatch_funcs(self) -> List[MethodType]:
|
|
||||||
return list(
|
|
||||||
ArkoWrapper(dir(self))
|
|
||||||
.filter(lambda x: x.startswith("dispatch_by_"))
|
|
||||||
.map(lambda x: getattr(self, x))
|
|
||||||
.filter(lambda x: isinstance(x, MethodType))
|
|
||||||
)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
|
||||||
"""默认的 dispatch 方法"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
|
|
||||||
"""使用 catch_func 获取并分配参数"""
|
|
||||||
|
|
||||||
def dispatch(self, func: Callable[P, R]) -> Callable[..., R]:
|
|
||||||
"""将参数分配给函数,从而合成一个无需参数即可执行的函数"""
|
|
||||||
params = {}
|
|
||||||
signature = get_signature(func)
|
|
||||||
parameters: Dict[str, Parameter] = dict(signature.parameters)
|
|
||||||
|
|
||||||
for name, parameter in list(parameters.items()):
|
|
||||||
parameter: Parameter
|
|
||||||
if any(
|
|
||||||
[
|
|
||||||
name == "self" and isinstance(func, (type, MethodType)),
|
|
||||||
parameter.kind in [Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL],
|
|
||||||
]
|
|
||||||
):
|
|
||||||
del parameters[name]
|
|
||||||
continue
|
|
||||||
|
|
||||||
for dispatch_func in self.dispatch_funcs:
|
|
||||||
parameters[name] = dispatch_func(parameter)
|
|
||||||
|
|
||||||
for name, parameter in parameters.items():
|
|
||||||
if parameter.default != Parameter.empty:
|
|
||||||
params[name] = parameter.default
|
|
||||||
else:
|
|
||||||
params[name] = None
|
|
||||||
|
|
||||||
return partial(func, **params)
|
|
||||||
|
|
||||||
@catch(Application)
|
|
||||||
def catch_application(self) -> Application:
|
|
||||||
return self.application
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDispatcher(AbstractDispatcher):
|
|
||||||
"""默认参数分发器"""
|
|
||||||
|
|
||||||
_instances: Sequence[Any]
|
|
||||||
|
|
||||||
def _get_kwargs(self) -> Dict[Type[T], T]:
|
|
||||||
result = self._get_default_kwargs()
|
|
||||||
result[AbstractDispatcher] = self
|
|
||||||
result.update(self._kwargs)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _get_default_kwargs(self) -> Dict[Type[T], T]:
|
|
||||||
application = self.application
|
|
||||||
_default_kwargs = {
|
|
||||||
FastAPI: application.web_app,
|
|
||||||
Server: application.web_server,
|
|
||||||
TelegramApplication: application.telegram,
|
|
||||||
TelegramBot: application.telegram.bot,
|
|
||||||
}
|
|
||||||
if not application.running:
|
|
||||||
for obj in chain(
|
|
||||||
application.managers.dependency,
|
|
||||||
application.managers.components,
|
|
||||||
application.managers.services,
|
|
||||||
application.managers.plugins,
|
|
||||||
):
|
|
||||||
_default_kwargs[type(obj)] = obj
|
|
||||||
return {k: v for k, v in _default_kwargs.items() if v is not None}
|
|
||||||
|
|
||||||
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
|
||||||
annotation = parameter.annotation
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
if isinstance(annotation, type) and (value := self._get_kwargs().get(annotation, None)) is not None:
|
|
||||||
parameter._default = value # pylint: disable=W0212
|
|
||||||
return parameter
|
|
||||||
|
|
||||||
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
|
|
||||||
annotation = parameter.annotation
|
|
||||||
if annotation != Any and isinstance(annotation, GenericAlias):
|
|
||||||
return parameter
|
|
||||||
|
|
||||||
catch_func = self.catch_func_map.get(annotation) or self.catch_func_map.get(parameter.name)
|
|
||||||
if catch_func is not None:
|
|
||||||
# noinspection PyUnresolvedReferences,PyProtectedMember
|
|
||||||
parameter._default = catch_func() # pylint: disable=W0212
|
|
||||||
return parameter
|
|
||||||
|
|
||||||
@catch(AbstractEventLoop)
|
|
||||||
def catch_loop(self) -> AbstractEventLoop:
|
|
||||||
return asyncio.get_event_loop()
|
|
||||||
|
|
||||||
|
|
||||||
class HandlerDispatcher(BaseDispatcher):
|
|
||||||
"""Handler 参数分发器"""
|
|
||||||
|
|
||||||
def __init__(self, update: Optional[Update] = None, context: Optional[CallbackContext] = None, **kwargs) -> None:
|
|
||||||
super().__init__(update=update, context=context, **kwargs)
|
|
||||||
self._update = update
|
|
||||||
self._context = context
|
|
||||||
|
|
||||||
def dispatch(
|
|
||||||
self, func: Callable[P, R], *, update: Optional[Update] = None, context: Optional[CallbackContext] = None
|
|
||||||
) -> Callable[..., R]:
|
|
||||||
self._update = update or self._update
|
|
||||||
self._context = context or self._context
|
|
||||||
if self._update is None:
|
|
||||||
from core.builtins.contexts import UpdateCV
|
|
||||||
|
|
||||||
self._update = UpdateCV.get()
|
|
||||||
if self._context is None:
|
|
||||||
from core.builtins.contexts import CallbackContextCV
|
|
||||||
|
|
||||||
self._context = CallbackContextCV.get()
|
|
||||||
return super().dispatch(func)
|
|
||||||
|
|
||||||
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
|
||||||
"""HandlerDispatcher 默认不使用 dispatch_by_default"""
|
|
||||||
return parameter
|
|
||||||
|
|
||||||
@catch(Update)
|
|
||||||
def catch_update(self) -> Update:
|
|
||||||
return self._update
|
|
||||||
|
|
||||||
@catch(CallbackContext)
|
|
||||||
def catch_context(self) -> CallbackContext:
|
|
||||||
return self._context
|
|
||||||
|
|
||||||
@catch(Message)
|
|
||||||
def catch_message(self) -> Message:
|
|
||||||
return self._update.effective_message
|
|
||||||
|
|
||||||
@catch(User)
|
|
||||||
def catch_user(self) -> User:
|
|
||||||
return self._update.effective_user
|
|
||||||
|
|
||||||
@catch(Chat)
|
|
||||||
def catch_chat(self) -> Chat:
|
|
||||||
return self._update.effective_chat
|
|
||||||
|
|
||||||
|
|
||||||
class JobDispatcher(BaseDispatcher):
|
|
||||||
"""Job 参数分发器"""
|
|
||||||
|
|
||||||
def __init__(self, context: Optional[CallbackContext] = None, **kwargs) -> None:
|
|
||||||
super().__init__(context=context, **kwargs)
|
|
||||||
self._context = context
|
|
||||||
|
|
||||||
def dispatch(self, func: Callable[P, R], *, context: Optional[CallbackContext] = None) -> Callable[..., R]:
|
|
||||||
self._context = context or self._context
|
|
||||||
if self._context is None:
|
|
||||||
from core.builtins.contexts import CallbackContextCV
|
|
||||||
|
|
||||||
self._context = CallbackContextCV.get()
|
|
||||||
return super().dispatch(func)
|
|
||||||
|
|
||||||
@catch("data")
|
|
||||||
def catch_data(self) -> Any:
|
|
||||||
return self._context.job.data
|
|
||||||
|
|
||||||
@catch(Job)
|
|
||||||
def catch_job(self) -> Job:
|
|
||||||
return self._context.job
|
|
||||||
|
|
||||||
@catch(CallbackContext)
|
|
||||||
def catch_context(self) -> CallbackContext:
|
|
||||||
return self._context
|
|
||||||
|
|
||||||
|
|
||||||
def dispatched(dispatcher: Type[AbstractDispatcher] = BaseDispatcher):
|
|
||||||
def decorate(func: Callable[P, R]) -> Callable[P, R]:
|
|
||||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
|
||||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
||||||
return dispatcher().dispatch(func)(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorate
|
|
@ -1,131 +0,0 @@
|
|||||||
"""执行器"""
|
|
||||||
import inspect
|
|
||||||
from functools import cached_property
|
|
||||||
from multiprocessing import RLock as Lock
|
|
||||||
from typing import Callable, ClassVar, Dict, Generic, Optional, TYPE_CHECKING, Type, TypeVar
|
|
||||||
|
|
||||||
from telegram import Update
|
|
||||||
from telegram.ext import CallbackContext
|
|
||||||
from typing_extensions import ParamSpec, Self
|
|
||||||
|
|
||||||
from core.builtins.contexts import handler_contexts, job_contexts
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.application import Application
|
|
||||||
from core.builtins.dispatcher import AbstractDispatcher, HandlerDispatcher
|
|
||||||
from multiprocessing.synchronize import RLock as LockType
|
|
||||||
|
|
||||||
__all__ = ("BaseExecutor", "Executor", "HandlerExecutor", "JobExecutor")
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
R = TypeVar("R")
|
|
||||||
P = ParamSpec("P")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseExecutor:
|
|
||||||
"""执行器
|
|
||||||
Args:
|
|
||||||
name(str): 该执行器的名称。执行器的名称是唯一的。
|
|
||||||
|
|
||||||
只支持执行只拥有 POSITIONAL_OR_KEYWORD 和 KEYWORD_ONLY 两种参数类型的函数
|
|
||||||
"""
|
|
||||||
|
|
||||||
_lock: ClassVar["LockType"] = Lock()
|
|
||||||
_instances: ClassVar[Dict[str, Self]] = {}
|
|
||||||
_application: "Optional[Application]" = None
|
|
||||||
|
|
||||||
def set_application(self, application: "Application") -> None:
|
|
||||||
self._application = application
|
|
||||||
|
|
||||||
@property
|
|
||||||
def application(self) -> "Application":
|
|
||||||
if self._application is None:
|
|
||||||
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
|
||||||
return self._application
|
|
||||||
|
|
||||||
def __new__(cls: Type[T], name: str, *args, **kwargs) -> T:
|
|
||||||
with cls._lock:
|
|
||||||
if (instance := cls._instances.get(name)) is None:
|
|
||||||
instance = object.__new__(cls)
|
|
||||||
instance.__init__(name, *args, **kwargs)
|
|
||||||
cls._instances.update({name: instance})
|
|
||||||
return instance
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def name(self) -> str:
|
|
||||||
"""当前执行器的名称"""
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
def __init__(self, name: str, dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
|
|
||||||
self._name = name
|
|
||||||
self._dispatcher = dispatcher
|
|
||||||
|
|
||||||
|
|
||||||
class Executor(BaseExecutor, Generic[P, R]):
|
|
||||||
async def __call__(
|
|
||||||
self,
|
|
||||||
target: Callable[P, R],
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> R:
|
|
||||||
dispatcher = self._dispatcher or dispatcher
|
|
||||||
dispatcher_instance = dispatcher(**kwargs)
|
|
||||||
dispatcher_instance.set_application(application=self.application)
|
|
||||||
dispatched_func = dispatcher_instance.dispatch(target) # 分发参数,组成新函数
|
|
||||||
|
|
||||||
# 执行
|
|
||||||
if inspect.iscoroutinefunction(target):
|
|
||||||
result = await dispatched_func()
|
|
||||||
else:
|
|
||||||
result = dispatched_func()
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class HandlerExecutor(BaseExecutor, Generic[P, R]):
|
|
||||||
"""Handler专用执行器"""
|
|
||||||
|
|
||||||
_callback: Callable[P, R]
|
|
||||||
_dispatcher: "HandlerDispatcher"
|
|
||||||
|
|
||||||
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["HandlerDispatcher"]] = None) -> None:
|
|
||||||
if dispatcher is None:
|
|
||||||
from core.builtins.dispatcher import HandlerDispatcher
|
|
||||||
|
|
||||||
dispatcher = HandlerDispatcher
|
|
||||||
super().__init__("handler", dispatcher)
|
|
||||||
self._callback = func
|
|
||||||
self._dispatcher = dispatcher()
|
|
||||||
|
|
||||||
def set_application(self, application: "Application") -> None:
|
|
||||||
self._application = application
|
|
||||||
if self._dispatcher is not None:
|
|
||||||
self._dispatcher.set_application(application)
|
|
||||||
|
|
||||||
async def __call__(self, update: Update, context: CallbackContext) -> R:
|
|
||||||
with handler_contexts(update, context):
|
|
||||||
dispatched_func = self._dispatcher.dispatch(self._callback, update=update, context=context)
|
|
||||||
return await dispatched_func()
|
|
||||||
|
|
||||||
|
|
||||||
class JobExecutor(BaseExecutor):
|
|
||||||
"""Job 专用执行器"""
|
|
||||||
|
|
||||||
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
|
|
||||||
if dispatcher is None:
|
|
||||||
from core.builtins.dispatcher import JobDispatcher
|
|
||||||
|
|
||||||
dispatcher = JobDispatcher
|
|
||||||
super().__init__("job", dispatcher)
|
|
||||||
self._callback = func
|
|
||||||
self._dispatcher = dispatcher()
|
|
||||||
|
|
||||||
def set_application(self, application: "Application") -> None:
|
|
||||||
self._application = application
|
|
||||||
if self._dispatcher is not None:
|
|
||||||
self._dispatcher.set_application(application)
|
|
||||||
|
|
||||||
async def __call__(self, context: CallbackContext) -> R:
|
|
||||||
with job_contexts(context):
|
|
||||||
dispatched_func = self._dispatcher.dispatch(self._callback, context=context)
|
|
||||||
return await dispatched_func()
|
|
@ -1,185 +0,0 @@
|
|||||||
import inspect
|
|
||||||
import multiprocessing
|
|
||||||
import os
|
|
||||||
import signal
|
|
||||||
import threading
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, Iterator, List, Optional, TYPE_CHECKING
|
|
||||||
|
|
||||||
from watchfiles import watch
|
|
||||||
|
|
||||||
from utils.const import HANDLED_SIGNALS, PROJECT_ROOT
|
|
||||||
from utils.log import logger
|
|
||||||
from utils.typedefs import StrOrPath
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from multiprocessing.process import BaseProcess
|
|
||||||
|
|
||||||
__all__ = ("Reloader",)
|
|
||||||
|
|
||||||
multiprocessing.allow_connection_pickling()
|
|
||||||
spawn = multiprocessing.get_context("spawn")
|
|
||||||
|
|
||||||
|
|
||||||
class FileFilter:
|
|
||||||
"""监控文件过滤"""
|
|
||||||
|
|
||||||
def __init__(self, includes: List[str], excludes: List[str]) -> None:
|
|
||||||
default_includes = ["*.py"]
|
|
||||||
self.includes = [default for default in default_includes if default not in excludes]
|
|
||||||
self.includes.extend(includes)
|
|
||||||
self.includes = list(set(self.includes))
|
|
||||||
|
|
||||||
default_excludes = [".*", ".py[cod]", ".sw.*", "~*", __file__]
|
|
||||||
self.excludes = [default for default in default_excludes if default not in includes]
|
|
||||||
self.exclude_dirs = []
|
|
||||||
for e in excludes:
|
|
||||||
p = Path(e)
|
|
||||||
try:
|
|
||||||
is_dir = p.is_dir()
|
|
||||||
except OSError:
|
|
||||||
is_dir = False
|
|
||||||
|
|
||||||
if is_dir:
|
|
||||||
self.exclude_dirs.append(p)
|
|
||||||
else:
|
|
||||||
self.excludes.append(e)
|
|
||||||
self.excludes = list(set(self.excludes))
|
|
||||||
|
|
||||||
def __call__(self, path: Path) -> bool:
|
|
||||||
for include_pattern in self.includes:
|
|
||||||
if path.match(include_pattern):
|
|
||||||
for exclude_dir in self.exclude_dirs:
|
|
||||||
if exclude_dir in path.parents:
|
|
||||||
return False
|
|
||||||
|
|
||||||
for exclude_pattern in self.excludes:
|
|
||||||
if path.match(exclude_pattern):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class Reloader:
|
|
||||||
_target: Callable[..., None]
|
|
||||||
_process: "BaseProcess"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def process(self) -> "BaseProcess":
|
|
||||||
return self._process
|
|
||||||
|
|
||||||
@property
|
|
||||||
def target(self) -> Callable[..., None]:
|
|
||||||
return self._target
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
target: Callable[..., None],
|
|
||||||
*,
|
|
||||||
reload_delay: float = 0.25,
|
|
||||||
reload_dirs: List[StrOrPath] = None,
|
|
||||||
reload_includes: List[str] = None,
|
|
||||||
reload_excludes: List[str] = None,
|
|
||||||
):
|
|
||||||
if inspect.iscoroutinefunction(target):
|
|
||||||
raise ValueError("不支持异步函数")
|
|
||||||
self._target = target
|
|
||||||
|
|
||||||
self.reload_delay = reload_delay
|
|
||||||
|
|
||||||
_reload_dirs = []
|
|
||||||
for reload_dir in reload_dirs or []:
|
|
||||||
_reload_dirs.append(PROJECT_ROOT.joinpath(Path(reload_dir)))
|
|
||||||
|
|
||||||
self.reload_dirs = []
|
|
||||||
for reload_dir in _reload_dirs:
|
|
||||||
append = True
|
|
||||||
for parent in reload_dir.parents:
|
|
||||||
if parent in _reload_dirs:
|
|
||||||
append = False
|
|
||||||
break
|
|
||||||
if append:
|
|
||||||
self.reload_dirs.append(reload_dir)
|
|
||||||
|
|
||||||
if not self.reload_dirs:
|
|
||||||
logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"})
|
|
||||||
|
|
||||||
self._should_exit = threading.Event()
|
|
||||||
|
|
||||||
frame = inspect.currentframe().f_back
|
|
||||||
|
|
||||||
self.watch_filter = FileFilter(reload_includes or [], (reload_excludes or []) + [frame.f_globals["__file__"]])
|
|
||||||
self.watcher = watch(
|
|
||||||
*self.reload_dirs,
|
|
||||||
watch_filter=None,
|
|
||||||
stop_event=self._should_exit,
|
|
||||||
yield_on_timeout=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_changes(self) -> Optional[List[Path]]:
|
|
||||||
if not self._process.is_alive():
|
|
||||||
logger.info("目标进程已经关闭", extra={"tag": "Reloader"})
|
|
||||||
self._should_exit.set()
|
|
||||||
try:
|
|
||||||
changes = next(self.watcher)
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
if changes:
|
|
||||||
unique_paths = {Path(c[1]) for c in changes}
|
|
||||||
return [p for p in unique_paths if self.watch_filter(p)]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Optional[List[Path]]]:
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self) -> Optional[List[Path]]:
|
|
||||||
return self.get_changes()
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
self.startup()
|
|
||||||
for changes in self:
|
|
||||||
if changes:
|
|
||||||
logger.warning(
|
|
||||||
"检测到文件 %s 发生改变, 正在重载...",
|
|
||||||
[str(c.relative_to(PROJECT_ROOT)).replace(os.sep, "/") for c in changes],
|
|
||||||
extra={"tag": "Reloader"},
|
|
||||||
)
|
|
||||||
self.restart()
|
|
||||||
|
|
||||||
self.shutdown()
|
|
||||||
|
|
||||||
def signal_handler(self, *_) -> None:
|
|
||||||
"""当接收到结束信号量时"""
|
|
||||||
self._process.join(3)
|
|
||||||
if self._process.is_alive():
|
|
||||||
self._process.terminate()
|
|
||||||
self._process.join()
|
|
||||||
self._should_exit.set()
|
|
||||||
|
|
||||||
def startup(self) -> None:
|
|
||||||
"""启动进程"""
|
|
||||||
logger.info("目标进程正在启动", extra={"tag": "Reloader"})
|
|
||||||
|
|
||||||
for sig in HANDLED_SIGNALS:
|
|
||||||
signal.signal(sig, self.signal_handler)
|
|
||||||
|
|
||||||
self._process = spawn.Process(target=self._target)
|
|
||||||
self._process.start()
|
|
||||||
logger.success("目标进程启动成功", extra={"tag": "Reloader"})
|
|
||||||
|
|
||||||
def restart(self) -> None:
|
|
||||||
"""重启进程"""
|
|
||||||
self._process.terminate()
|
|
||||||
self._process.join(10)
|
|
||||||
|
|
||||||
self._process = spawn.Process(target=self._target)
|
|
||||||
self._process.start()
|
|
||||||
logger.info("目标进程已经重载", extra={"tag": "Reloader"})
|
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
|
||||||
"""关闭进程"""
|
|
||||||
self._process.terminate()
|
|
||||||
self._process.join(10)
|
|
||||||
|
|
||||||
logger.info("重载器已经关闭", extra={"tag": "Reloader"})
|
|
160
core/config.py
160
core/config.py
@ -1,161 +1,3 @@
|
|||||||
from enum import Enum
|
from gram_core.config import ApplicationConfig, config, JoinGroups
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
import dotenv
|
|
||||||
from pydantic import AnyUrl, Field
|
|
||||||
|
|
||||||
from core.basemodel import Settings
|
|
||||||
from utils.const import PROJECT_ROOT
|
|
||||||
from utils.typedefs import NaturalNumber
|
|
||||||
|
|
||||||
__all__ = ("ApplicationConfig", "config", "JoinGroups")
|
__all__ = ("ApplicationConfig", "config", "JoinGroups")
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
class JoinGroups(str, Enum):
|
|
||||||
NO_ALLOW = "NO_ALLOW"
|
|
||||||
ALLOW_AUTH_USER = "ALLOW_AUTH_USER"
|
|
||||||
ALLOW_USER = "ALLOW_USER"
|
|
||||||
ALLOW_ALL = "ALLOW_ALL"
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseConfig(Settings):
|
|
||||||
driver_name: str = "mysql+asyncmy"
|
|
||||||
host: Optional[str] = None
|
|
||||||
port: Optional[int] = None
|
|
||||||
username: Optional[str] = None
|
|
||||||
password: Optional[str] = None
|
|
||||||
database: Optional[str] = None
|
|
||||||
|
|
||||||
class Config(Settings.Config):
|
|
||||||
env_prefix = "db_"
|
|
||||||
|
|
||||||
|
|
||||||
class RedisConfig(Settings):
|
|
||||||
host: str = "127.0.0.1"
|
|
||||||
port: int = 6379
|
|
||||||
database: int = Field(default=0, env="redis_db")
|
|
||||||
password: Optional[str] = None
|
|
||||||
|
|
||||||
class Config(Settings.Config):
|
|
||||||
env_prefix = "redis_"
|
|
||||||
|
|
||||||
|
|
||||||
class LoggerConfig(Settings):
|
|
||||||
name: str = "PaiGram"
|
|
||||||
width: Optional[int] = None
|
|
||||||
time_format: str = "[%Y-%m-%d %X]"
|
|
||||||
traceback_max_frames: int = 20
|
|
||||||
path: Path = PROJECT_ROOT / "logs"
|
|
||||||
render_keywords: List[str] = ["BOT"]
|
|
||||||
locals_max_length: int = 10
|
|
||||||
locals_max_string: int = 80
|
|
||||||
locals_max_depth: Optional[NaturalNumber] = None
|
|
||||||
filtered_names: List[str] = ["uvicorn"]
|
|
||||||
|
|
||||||
class Config(Settings.Config):
|
|
||||||
env_prefix = "logger_"
|
|
||||||
|
|
||||||
|
|
||||||
class MTProtoConfig(Settings):
|
|
||||||
api_id: Optional[int] = None
|
|
||||||
api_hash: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class WebServerConfig(Settings):
|
|
||||||
enable: bool = False
|
|
||||||
"""是否启用WebServer"""
|
|
||||||
|
|
||||||
url: AnyUrl = "http://localhost:8080"
|
|
||||||
host: str = "localhost"
|
|
||||||
port: int = 8080
|
|
||||||
|
|
||||||
class Config(Settings.Config):
|
|
||||||
env_prefix = "web_"
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorConfig(Settings):
|
|
||||||
pb_url: str = ""
|
|
||||||
pb_sunset: int = 43200
|
|
||||||
pb_max_lines: int = 1000
|
|
||||||
sentry_dsn: str = ""
|
|
||||||
notification_chat_id: Optional[str] = None
|
|
||||||
|
|
||||||
class Config(Settings.Config):
|
|
||||||
env_prefix = "error_"
|
|
||||||
|
|
||||||
|
|
||||||
class ReloadConfig(Settings):
|
|
||||||
delay: float = 0.25
|
|
||||||
dirs: List[str] = []
|
|
||||||
include: List[str] = []
|
|
||||||
exclude: List[str] = []
|
|
||||||
|
|
||||||
class Config(Settings.Config):
|
|
||||||
env_prefix = "reload_"
|
|
||||||
|
|
||||||
|
|
||||||
class NoticeConfig(Settings):
|
|
||||||
user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!"
|
|
||||||
|
|
||||||
class Config(Settings.Config):
|
|
||||||
env_prefix = "notice_"
|
|
||||||
|
|
||||||
|
|
||||||
class ApplicationConfig(Settings):
|
|
||||||
debug: bool = False
|
|
||||||
"""debug 开关"""
|
|
||||||
retry: int = 5
|
|
||||||
"""重试次数"""
|
|
||||||
auto_reload: bool = False
|
|
||||||
"""自动重载"""
|
|
||||||
|
|
||||||
proxy_url: Optional[AnyUrl] = None
|
|
||||||
"""代理链接"""
|
|
||||||
upload_bbs_host: Optional[AnyUrl] = "https://upload-bbs.miyoushe.com"
|
|
||||||
|
|
||||||
bot_token: str = ""
|
|
||||||
"""BOT的token"""
|
|
||||||
|
|
||||||
owner: Optional[int] = None
|
|
||||||
|
|
||||||
channels: List[int] = []
|
|
||||||
"""文章推送群组"""
|
|
||||||
|
|
||||||
verify_groups: List[Union[int, str]] = []
|
|
||||||
"""启用群验证功能的群组"""
|
|
||||||
join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW
|
|
||||||
"""是否允许机器人被邀请到其它群组"""
|
|
||||||
|
|
||||||
timeout: int = 10
|
|
||||||
connection_pool_size: int = 256
|
|
||||||
read_timeout: Optional[float] = None
|
|
||||||
write_timeout: Optional[float] = None
|
|
||||||
connect_timeout: Optional[float] = None
|
|
||||||
pool_timeout: Optional[float] = None
|
|
||||||
update_read_timeout: Optional[float] = None
|
|
||||||
update_write_timeout: Optional[float] = None
|
|
||||||
update_connect_timeout: Optional[float] = None
|
|
||||||
update_pool_timeout: Optional[float] = None
|
|
||||||
|
|
||||||
genshin_ttl: Optional[int] = None
|
|
||||||
|
|
||||||
enka_network_api_agent: str = ""
|
|
||||||
pass_challenge_api: str = ""
|
|
||||||
pass_challenge_app_key: str = ""
|
|
||||||
pass_challenge_user_web: str = ""
|
|
||||||
|
|
||||||
reload: ReloadConfig = ReloadConfig()
|
|
||||||
database: DatabaseConfig = DatabaseConfig()
|
|
||||||
logger: LoggerConfig = LoggerConfig()
|
|
||||||
webserver: WebServerConfig = WebServerConfig()
|
|
||||||
redis: RedisConfig = RedisConfig()
|
|
||||||
mtproto: MTProtoConfig = MTProtoConfig()
|
|
||||||
error: ErrorConfig = ErrorConfig()
|
|
||||||
notice: NoticeConfig = NoticeConfig()
|
|
||||||
|
|
||||||
|
|
||||||
ApplicationConfig.update_forward_refs()
|
|
||||||
config = ApplicationConfig()
|
|
||||||
|
@ -1,56 +1,3 @@
|
|||||||
from typing import Optional, TYPE_CHECKING
|
from gram_core.dependence.aiobrowser import AioBrowser
|
||||||
|
|
||||||
from playwright.async_api import Error, async_playwright
|
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from playwright.async_api import Playwright as AsyncPlaywright, Browser
|
|
||||||
|
|
||||||
__all__ = ("AioBrowser",)
|
__all__ = ("AioBrowser",)
|
||||||
|
|
||||||
|
|
||||||
class AioBrowser(BaseService.Dependence):
|
|
||||||
@property
|
|
||||||
def browser(self):
|
|
||||||
return self._browser
|
|
||||||
|
|
||||||
def __init__(self, loop=None):
|
|
||||||
self._browser: Optional["Browser"] = None
|
|
||||||
self._playwright: Optional["AsyncPlaywright"] = None
|
|
||||||
self._loop = loop
|
|
||||||
|
|
||||||
async def get_browser(self):
|
|
||||||
if self._browser is None:
|
|
||||||
await self.initialize()
|
|
||||||
return self._browser
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
if self._playwright is None:
|
|
||||||
logger.info("正在尝试启动 [blue]Playwright[/]", extra={"markup": True})
|
|
||||||
self._playwright = await async_playwright().start()
|
|
||||||
logger.success("[blue]Playwright[/] 启动成功", extra={"markup": True})
|
|
||||||
if self._browser is None:
|
|
||||||
logger.info("正在尝试启动 [blue]Browser[/]", extra={"markup": True})
|
|
||||||
try:
|
|
||||||
self._browser = await self._playwright.chromium.launch(timeout=5000)
|
|
||||||
logger.success("[blue]Browser[/] 启动成功", extra={"markup": True})
|
|
||||||
except Error as err:
|
|
||||||
if "playwright install" in str(err):
|
|
||||||
logger.error(
|
|
||||||
"检查到 [blue]playwright[/] 刚刚安装或者未升级\n"
|
|
||||||
"请运行以下命令下载新浏览器\n"
|
|
||||||
"[blue bold]playwright install chromium[/]",
|
|
||||||
extra={"markup": True},
|
|
||||||
)
|
|
||||||
raise RuntimeError("检查到 playwright 刚刚安装或者未升级\n请运行以下命令下载新浏览器\nplaywright install chromium")
|
|
||||||
raise err
|
|
||||||
|
|
||||||
return self._browser
|
|
||||||
|
|
||||||
async def shutdown(self):
|
|
||||||
if self._browser is not None:
|
|
||||||
await self._browser.close()
|
|
||||||
if self._playwright is not None:
|
|
||||||
self._playwright.stop()
|
|
||||||
|
@ -1,16 +0,0 @@
|
|||||||
from asyncio import AbstractEventLoop
|
|
||||||
|
|
||||||
from playwright.async_api import Browser, Playwright as AsyncPlaywright
|
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
|
|
||||||
__all__ = ("AioBrowser",)
|
|
||||||
|
|
||||||
class AioBrowser(BaseService.Dependence):
|
|
||||||
_browser: Browser | None
|
|
||||||
_playwright: AsyncPlaywright | None
|
|
||||||
_loop: AbstractEventLoop
|
|
||||||
|
|
||||||
@property
|
|
||||||
def browser(self) -> Browser | None: ...
|
|
||||||
async def get_browser(self) -> Browser: ...
|
|
@ -1,51 +1,3 @@
|
|||||||
import contextlib
|
from gram_core.dependence.database import Database
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sqlalchemy.engine import URL
|
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from typing_extensions import Self
|
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.config import ApplicationConfig
|
|
||||||
from core.sqlmodel.session import AsyncSession
|
|
||||||
|
|
||||||
__all__ = ("Database",)
|
__all__ = ("Database",)
|
||||||
|
|
||||||
|
|
||||||
class Database(BaseService.Dependence):
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: ApplicationConfig) -> Self:
|
|
||||||
return cls(**config.database.dict())
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
driver_name: str,
|
|
||||||
host: Optional[str] = None,
|
|
||||||
port: Optional[int] = None,
|
|
||||||
username: Optional[str] = None,
|
|
||||||
password: Optional[str] = None,
|
|
||||||
database: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self.database = database # skipcq: PTC-W0052
|
|
||||||
self.password = password
|
|
||||||
self.username = username
|
|
||||||
self.port = port
|
|
||||||
self.host = host
|
|
||||||
self.url = URL.create(
|
|
||||||
driver_name,
|
|
||||||
username=self.username,
|
|
||||||
password=self.password,
|
|
||||||
host=self.host,
|
|
||||||
port=self.port,
|
|
||||||
database=self.database,
|
|
||||||
)
|
|
||||||
self.engine = create_async_engine(self.url)
|
|
||||||
self.Session = sessionmaker(bind=self.engine, class_=AsyncSession)
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
|
||||||
async def session(self) -> AsyncSession:
|
|
||||||
yield self.Session()
|
|
||||||
|
|
||||||
async def shutdown(self):
|
|
||||||
self.Session.close_all()
|
|
||||||
|
@ -1,67 +1,3 @@
|
|||||||
import os
|
from gram_core.dependence.mtproto import MTProto
|
||||||
from typing import Optional
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import aiofiles
|
__all__ = ("MTProto",)
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.config import config as bot_config
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
try:
|
|
||||||
from pyrogram import Client
|
|
||||||
from pyrogram.session import session
|
|
||||||
|
|
||||||
session.log.debug = lambda *args, **kwargs: None # 关闭日记
|
|
||||||
PYROGRAM_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
Client = None
|
|
||||||
session = None
|
|
||||||
PYROGRAM_AVAILABLE = False
|
|
||||||
|
|
||||||
|
|
||||||
class MTProto(BaseService.Dependence):
|
|
||||||
async def get_session(self):
|
|
||||||
async with aiofiles.open(self.session_path, mode="r") as f:
|
|
||||||
return await f.read()
|
|
||||||
|
|
||||||
async def set_session(self, b: str):
|
|
||||||
async with aiofiles.open(self.session_path, mode="w+") as f:
|
|
||||||
await f.write(b)
|
|
||||||
|
|
||||||
def session_exists(self):
|
|
||||||
return os.path.exists(self.session_path)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.name = "paigram"
|
|
||||||
current_dir = os.getcwd()
|
|
||||||
self.session_path = os.path.join(current_dir, "paigram.session")
|
|
||||||
self.client: Optional[Client] = None
|
|
||||||
self.proxy: Optional[dict] = None
|
|
||||||
http_proxy = os.environ.get("HTTP_PROXY")
|
|
||||||
if http_proxy is not None:
|
|
||||||
http_proxy_url = urlparse(http_proxy)
|
|
||||||
self.proxy = {"scheme": "http", "hostname": http_proxy_url.hostname, "port": http_proxy_url.port}
|
|
||||||
|
|
||||||
async def initialize(self): # pylint: disable=W0221
|
|
||||||
if not PYROGRAM_AVAILABLE:
|
|
||||||
logger.info("MTProto 服务需要的 pyrogram 模块未导入 本次服务 client 为 None")
|
|
||||||
return
|
|
||||||
if bot_config.mtproto.api_id is None:
|
|
||||||
logger.info("MTProto 服务需要的 api_id 未配置 本次服务 client 为 None")
|
|
||||||
return
|
|
||||||
if bot_config.mtproto.api_hash is None:
|
|
||||||
logger.info("MTProto 服务需要的 api_hash 未配置 本次服务 client 为 None")
|
|
||||||
return
|
|
||||||
self.client = Client(
|
|
||||||
api_id=bot_config.mtproto.api_id,
|
|
||||||
api_hash=bot_config.mtproto.api_hash,
|
|
||||||
name=self.name,
|
|
||||||
bot_token=bot_config.bot_token,
|
|
||||||
proxy=self.proxy,
|
|
||||||
)
|
|
||||||
await self.client.start()
|
|
||||||
|
|
||||||
async def shutdown(self): # pylint: disable=W0221
|
|
||||||
if self.client is not None:
|
|
||||||
await self.client.stop(block=False)
|
|
||||||
|
@ -1,31 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
from typing import TypedDict
|
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
|
|
||||||
try:
|
|
||||||
from pyrogram import Client
|
|
||||||
from pyrogram.session import session
|
|
||||||
|
|
||||||
PYROGRAM_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
Client = None
|
|
||||||
session = None
|
|
||||||
PYROGRAM_AVAILABLE = False
|
|
||||||
|
|
||||||
__all__ = ("MTProto",)
|
|
||||||
|
|
||||||
class _ProxyType(TypedDict):
|
|
||||||
scheme: str
|
|
||||||
hostname: str | None
|
|
||||||
port: int | None
|
|
||||||
|
|
||||||
class MTProto(BaseService.Dependence):
|
|
||||||
name: str
|
|
||||||
session_path: str
|
|
||||||
client: Client | None
|
|
||||||
proxy: _ProxyType | None
|
|
||||||
|
|
||||||
async def get_session(self) -> str: ...
|
|
||||||
async def set_session(self, b: str) -> None: ...
|
|
||||||
def session_exists(self) -> bool: ...
|
|
@ -1,50 +1,3 @@
|
|||||||
from typing import Optional, Union
|
from gram_core.dependence.redisdb import RedisDB
|
||||||
|
|
||||||
import fakeredis.aioredis
|
|
||||||
from redis import asyncio as aioredis
|
|
||||||
from redis.exceptions import ConnectionError as RedisConnectionError, TimeoutError as RedisTimeoutError
|
|
||||||
from typing_extensions import Self
|
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.config import ApplicationConfig
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
__all__ = ["RedisDB"]
|
__all__ = ["RedisDB"]
|
||||||
|
|
||||||
|
|
||||||
class RedisDB(BaseService.Dependence):
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: ApplicationConfig) -> Self:
|
|
||||||
return cls(**config.redis.dict())
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, host: str = "127.0.0.1", port: int = 6379, database: Union[str, int] = 0, password: Optional[str] = None
|
|
||||||
):
|
|
||||||
self.client = aioredis.Redis(host=host, port=port, db=database, password=password)
|
|
||||||
self.ttl = 600
|
|
||||||
|
|
||||||
async def ping(self):
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
if await self.client.ping():
|
|
||||||
logger.info("连接 [red]Redis[/] 成功", extra={"markup": True})
|
|
||||||
else:
|
|
||||||
logger.info("连接 [red]Redis[/] 失败", extra={"markup": True})
|
|
||||||
raise RuntimeError("连接 Redis 失败")
|
|
||||||
|
|
||||||
async def start_fake_redis(self):
|
|
||||||
self.client = fakeredis.aioredis.FakeRedis()
|
|
||||||
await self.ping()
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
logger.info("正在尝试建立与 [red]Redis[/] 连接", extra={"markup": True})
|
|
||||||
try:
|
|
||||||
await self.ping()
|
|
||||||
except (RedisTimeoutError, RedisConnectionError) as exc:
|
|
||||||
if isinstance(exc, RedisTimeoutError):
|
|
||||||
logger.warning("连接 [red]Redis[/] 超时,使用 [red]fakeredis[/] 模拟", extra={"markup": True})
|
|
||||||
if isinstance(exc, RedisConnectionError):
|
|
||||||
logger.warning("连接 [red]Redis[/] 失败,使用 [red]fakeredis[/] 模拟", extra={"markup": True})
|
|
||||||
await self.start_fake_redis()
|
|
||||||
|
|
||||||
async def shutdown(self):
|
|
||||||
await self.client.close()
|
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
"""此模块包含核心模块的错误的基类"""
|
"""此模块包含核心模块的错误的基类"""
|
||||||
from typing import Union
|
from gram_core.error import ServiceNotFoundError
|
||||||
|
|
||||||
|
__all__ = ("ServiceNotFoundError",)
|
||||||
class ServiceNotFoundError(Exception):
|
|
||||||
def __init__(self, name: Union[str, type]):
|
|
||||||
super().__init__(f"No service named '{name if isinstance(name, str) else name.__name__}'")
|
|
||||||
|
@ -1,59 +1,3 @@
|
|||||||
import asyncio
|
from gram_core.handler.adminhandler import AdminHandler
|
||||||
from typing import TypeVar, TYPE_CHECKING, Any, Optional
|
|
||||||
|
|
||||||
from telegram import Update
|
__all__ = ("AdminHandler",)
|
||||||
from telegram.ext import ApplicationHandlerStop, BaseHandler
|
|
||||||
|
|
||||||
from core.error import ServiceNotFoundError
|
|
||||||
from core.services.users.services import UserAdminService
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.application import Application
|
|
||||||
from telegram.ext import Application as TelegramApplication
|
|
||||||
|
|
||||||
RT = TypeVar("RT")
|
|
||||||
UT = TypeVar("UT")
|
|
||||||
|
|
||||||
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
|
||||||
|
|
||||||
|
|
||||||
class AdminHandler(BaseHandler[Update, CCT]):
|
|
||||||
_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
def __init__(self, handler: BaseHandler[Update, CCT], application: "Application") -> None:
|
|
||||||
self.handler = handler
|
|
||||||
self.application = application
|
|
||||||
self.user_service: Optional["UserAdminService"] = None
|
|
||||||
super().__init__(self.handler.callback, self.handler.block)
|
|
||||||
|
|
||||||
def check_update(self, update: object) -> bool:
|
|
||||||
if not isinstance(update, Update):
|
|
||||||
return False
|
|
||||||
return self.handler.check_update(update)
|
|
||||||
|
|
||||||
async def _user_service(self) -> "UserAdminService":
|
|
||||||
async with self._lock:
|
|
||||||
if self.user_service is not None:
|
|
||||||
return self.user_service
|
|
||||||
user_service: UserAdminService = self.application.managers.services_map.get(UserAdminService, None)
|
|
||||||
if user_service is None:
|
|
||||||
raise ServiceNotFoundError("UserAdminService")
|
|
||||||
self.user_service = user_service
|
|
||||||
return self.user_service
|
|
||||||
|
|
||||||
async def handle_update(
|
|
||||||
self,
|
|
||||||
update: "UT",
|
|
||||||
application: "TelegramApplication[Any, CCT, Any, Any, Any, Any]",
|
|
||||||
check_result: Any,
|
|
||||||
context: "CCT",
|
|
||||||
) -> RT:
|
|
||||||
user_service = await self._user_service()
|
|
||||||
user = update.effective_user
|
|
||||||
if await user_service.is_admin(user.id):
|
|
||||||
return await self.handler.handle_update(update, application, check_result, context)
|
|
||||||
message = update.effective_message
|
|
||||||
logger.warning("用户 %s[%s] 触发尝试调用Admin命令但权限不足", user.full_name, user.id)
|
|
||||||
await message.reply_text("权限不足")
|
|
||||||
raise ApplicationHandlerStop
|
|
||||||
|
@ -1,62 +1,3 @@
|
|||||||
import asyncio
|
from gram_core.handler.callbackqueryhandler import CallbackQueryHandler, OverlappingException, OverlappingContext
|
||||||
from contextlib import AbstractAsyncContextManager
|
|
||||||
from types import TracebackType
|
|
||||||
from typing import TypeVar, TYPE_CHECKING, Any, Optional, Type
|
|
||||||
|
|
||||||
from telegram.ext import CallbackQueryHandler as BaseCallbackQueryHandler, ApplicationHandlerStop
|
__all__ = ("CallbackQueryHandler", "OverlappingException", "OverlappingContext")
|
||||||
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from telegram.ext import Application
|
|
||||||
|
|
||||||
RT = TypeVar("RT")
|
|
||||||
UT = TypeVar("UT")
|
|
||||||
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
|
||||||
|
|
||||||
|
|
||||||
class OverlappingException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class OverlappingContext(AbstractAsyncContextManager):
|
|
||||||
_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
def __init__(self, context: "CCT"):
|
|
||||||
self.context = context
|
|
||||||
|
|
||||||
async def __aenter__(self) -> None:
|
|
||||||
async with self._lock:
|
|
||||||
flag = self.context.user_data.get("overlapping", False)
|
|
||||||
if flag:
|
|
||||||
raise OverlappingException
|
|
||||||
self.context.user_data["overlapping"] = True
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def __aexit__(
|
|
||||||
self,
|
|
||||||
exc_type: Optional[Type[BaseException]],
|
|
||||||
exc: Optional[BaseException],
|
|
||||||
tb: Optional[TracebackType],
|
|
||||||
) -> None:
|
|
||||||
async with self._lock:
|
|
||||||
del self.context.user_data["overlapping"]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class CallbackQueryHandler(BaseCallbackQueryHandler):
|
|
||||||
async def handle_update(
|
|
||||||
self,
|
|
||||||
update: "UT",
|
|
||||||
application: "Application[Any, CCT, Any, Any, Any, Any]",
|
|
||||||
check_result: Any,
|
|
||||||
context: "CCT",
|
|
||||||
) -> RT:
|
|
||||||
self.collect_additional_context(context, update, application, check_result)
|
|
||||||
try:
|
|
||||||
async with OverlappingContext(context):
|
|
||||||
return await self.callback(update, context)
|
|
||||||
except OverlappingException as exc:
|
|
||||||
user = update.effective_user
|
|
||||||
logger.warning("用户 %s[%s] 触发 overlapping 该次命令已忽略", user.full_name, user.id)
|
|
||||||
raise ApplicationHandlerStop from exc
|
|
||||||
|
@ -1,71 +1,3 @@
|
|||||||
import asyncio
|
from gram_core.handler.limiterhandler import LimiterHandler
|
||||||
from typing import TypeVar, Optional
|
|
||||||
|
|
||||||
from telegram import Update
|
__all__ = ("LimiterHandler",)
|
||||||
from telegram.ext import ContextTypes, ApplicationHandlerStop, TypeHandler
|
|
||||||
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
UT = TypeVar("UT")
|
|
||||||
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
|
||||||
|
|
||||||
|
|
||||||
class LimiterHandler(TypeHandler[UT, CCT]):
|
|
||||||
_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, max_rate: float = 5, time_period: float = 10, amount: float = 1, limit_time: Optional[float] = None
|
|
||||||
):
|
|
||||||
"""Limiter Handler 通过
|
|
||||||
`Leaky bucket algorithm <https://en.wikipedia.org/wiki/Leaky_bucket>`_
|
|
||||||
实现对用户的输入的精确控制
|
|
||||||
|
|
||||||
输入超过一定速率后,代码会抛出
|
|
||||||
:class:`telegram.ext.ApplicationHandlerStop`
|
|
||||||
异常并在一段时间内防止用户执行任何其他操作
|
|
||||||
|
|
||||||
:param max_rate: 在抛出异常之前最多允许 频率/秒 的速度
|
|
||||||
:param time_period: 在限制速率的时间段的持续时间
|
|
||||||
:param amount: 提供的容量
|
|
||||||
:param limit_time: 限制时间 如果不提供限制时间为 max_rate / time_period * amount
|
|
||||||
"""
|
|
||||||
self.max_rate = max_rate
|
|
||||||
self.amount = amount
|
|
||||||
self._rate_per_sec = max_rate / time_period
|
|
||||||
self.limit_time = limit_time
|
|
||||||
super().__init__(Update, self.limiter_callback)
|
|
||||||
|
|
||||||
async def limiter_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
|
||||||
if update.inline_query is not None:
|
|
||||||
return
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
async with self._lock:
|
|
||||||
time = loop.time()
|
|
||||||
user_data = context.user_data
|
|
||||||
if user_data is None:
|
|
||||||
return
|
|
||||||
user_limit_time = user_data.get("limit_time")
|
|
||||||
if user_limit_time is not None:
|
|
||||||
if time >= user_limit_time:
|
|
||||||
del user_data["limit_time"]
|
|
||||||
else:
|
|
||||||
raise ApplicationHandlerStop
|
|
||||||
last_task_time = user_data.get("last_task_time", 0)
|
|
||||||
if last_task_time:
|
|
||||||
task_level = user_data.get("task_level", 0)
|
|
||||||
elapsed = time - last_task_time
|
|
||||||
decrement = elapsed * self._rate_per_sec
|
|
||||||
task_level = max(task_level - decrement, 0)
|
|
||||||
user_data["task_level"] = task_level
|
|
||||||
if not task_level + self.amount <= self.max_rate:
|
|
||||||
if self.limit_time:
|
|
||||||
limit_time = self.limit_time
|
|
||||||
else:
|
|
||||||
limit_time = 1 / self._rate_per_sec * self.amount
|
|
||||||
user_data["limit_time"] = time + limit_time
|
|
||||||
user = update.effective_user
|
|
||||||
logger.warning("用户 %s[%s] 触发洪水限制 已被限制 %s 秒", user.full_name, user.id, limit_time)
|
|
||||||
raise ApplicationHandlerStop
|
|
||||||
user_data["last_task_time"] = time
|
|
||||||
task_level = user_data.get("task_level", 0)
|
|
||||||
user_data["task_level"] = task_level + self.amount
|
|
||||||
|
286
core/manager.py
286
core/manager.py
@ -1,286 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from importlib import import_module
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar
|
|
||||||
|
|
||||||
from arkowrapper import ArkoWrapper
|
|
||||||
from async_timeout import timeout
|
|
||||||
from typing_extensions import ParamSpec
|
|
||||||
|
|
||||||
from core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services
|
|
||||||
from core.config import config as bot_config
|
|
||||||
from utils.const import PLUGIN_DIR, PROJECT_ROOT
|
|
||||||
from utils.helpers import gen_pkg
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.application import Application
|
|
||||||
from core.plugin import PluginType
|
|
||||||
from core.builtins.executor import Executor
|
|
||||||
|
|
||||||
__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers")
|
|
||||||
|
|
||||||
R = TypeVar("R")
|
|
||||||
T = TypeVar("T")
|
|
||||||
P = ParamSpec("P")
|
|
||||||
|
|
||||||
|
|
||||||
def _load_module(path: Path) -> None:
|
|
||||||
for pkg in gen_pkg(path):
|
|
||||||
try:
|
|
||||||
logger.debug('正在导入 "%s"', pkg)
|
|
||||||
import_module(pkg)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(
|
|
||||||
'在导入 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
|
|
||||||
)
|
|
||||||
raise SystemExit from e
|
|
||||||
|
|
||||||
|
|
||||||
class Manager(Generic[T]):
|
|
||||||
"""生命周期控制基类"""
|
|
||||||
|
|
||||||
_executor: Optional["Executor"] = None
|
|
||||||
_lib: Dict[Type[T], T] = {}
|
|
||||||
_application: "Optional[Application]" = None
|
|
||||||
|
|
||||||
def set_application(self, application: "Application") -> None:
|
|
||||||
self._application = application
|
|
||||||
|
|
||||||
@property
|
|
||||||
def application(self) -> "Application":
|
|
||||||
if self._application is None:
|
|
||||||
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
|
||||||
return self._application
|
|
||||||
|
|
||||||
@property
|
|
||||||
def executor(self) -> "Executor":
|
|
||||||
"""执行器"""
|
|
||||||
if self._executor is None:
|
|
||||||
raise RuntimeError(f"No executor was set for this {self.__class__.__name__}.")
|
|
||||||
return self._executor
|
|
||||||
|
|
||||||
def build_executor(self, name: str):
|
|
||||||
from core.builtins.executor import Executor
|
|
||||||
from core.builtins.dispatcher import BaseDispatcher
|
|
||||||
|
|
||||||
self._executor = Executor(name, dispatcher=BaseDispatcher)
|
|
||||||
self._executor.set_application(self.application)
|
|
||||||
|
|
||||||
|
|
||||||
class DependenceManager(Manager[DependenceType]):
|
|
||||||
"""基础依赖管理"""
|
|
||||||
|
|
||||||
_dependency: Dict[Type[DependenceType], DependenceType] = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dependency(self) -> List[DependenceType]:
|
|
||||||
return list(self._dependency.values())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dependency_map(self) -> Dict[Type[DependenceType], DependenceType]:
|
|
||||||
return self._dependency
|
|
||||||
|
|
||||||
async def start_dependency(self) -> None:
|
|
||||||
_load_module(PROJECT_ROOT / "core/dependence")
|
|
||||||
|
|
||||||
for dependence in filter(lambda x: x.is_dependence, get_all_services()):
|
|
||||||
dependence: Type[DependenceType]
|
|
||||||
instance: DependenceType
|
|
||||||
try:
|
|
||||||
if hasattr(dependence, "from_config"): # 如果有 from_config 方法
|
|
||||||
instance = dependence.from_config(bot_config) # 用 from_config 实例化服务
|
|
||||||
else:
|
|
||||||
instance = await self.executor(dependence)
|
|
||||||
|
|
||||||
await instance.initialize()
|
|
||||||
logger.success('基础服务 "%s" 启动成功', dependence.__name__)
|
|
||||||
|
|
||||||
self._lib[dependence] = instance
|
|
||||||
self._dependency[dependence] = instance
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception('基础服务 "%s" 初始化失败,BOT 将自动关闭', dependence.__name__)
|
|
||||||
raise SystemExit from e
|
|
||||||
|
|
||||||
async def stop_dependency(self) -> None:
|
|
||||||
async def task(d):
|
|
||||||
try:
|
|
||||||
async with timeout(5):
|
|
||||||
await d.shutdown()
|
|
||||||
logger.debug('基础服务 "%s" 关闭成功', d.__class__.__name__)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning('基础服务 "%s" 关闭超时', d.__class__.__name__)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error('基础服务 "%s" 关闭错误', d.__class__.__name__, exc_info=e)
|
|
||||||
|
|
||||||
tasks = []
|
|
||||||
for dependence in self._dependency.values():
|
|
||||||
tasks.append(asyncio.create_task(task(dependence)))
|
|
||||||
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
|
|
||||||
class ComponentManager(Manager[ComponentType]):
|
|
||||||
"""组件管理"""
|
|
||||||
|
|
||||||
_components: Dict[Type[ComponentType], ComponentType] = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def components(self) -> List[ComponentType]:
|
|
||||||
return list(self._components.values())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def components_map(self) -> Dict[Type[ComponentType], ComponentType]:
|
|
||||||
return self._components
|
|
||||||
|
|
||||||
async def init_components(self):
|
|
||||||
for path in filter(
|
|
||||||
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
|
|
||||||
):
|
|
||||||
_load_module(path)
|
|
||||||
components = ArkoWrapper(get_all_services()).filter(lambda x: x.is_component)
|
|
||||||
retry_times = 0
|
|
||||||
max_retry_times = len(components)
|
|
||||||
while components:
|
|
||||||
start_len = len(components)
|
|
||||||
for component in list(components):
|
|
||||||
component: Type[ComponentType]
|
|
||||||
instance: ComponentType
|
|
||||||
try:
|
|
||||||
instance = await self.executor(component)
|
|
||||||
self._lib[component] = instance
|
|
||||||
self._components[component] = instance
|
|
||||||
components = components.remove(component)
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.debug('组件 "%s" 初始化失败: [red]%s[/]', component.__name__, e, extra={"markup": True})
|
|
||||||
end_len = len(list(components))
|
|
||||||
if start_len == end_len:
|
|
||||||
retry_times += 1
|
|
||||||
|
|
||||||
if retry_times == max_retry_times and components:
|
|
||||||
for component in components:
|
|
||||||
logger.error('组件 "%s" 初始化失败', component.__name__)
|
|
||||||
raise SystemExit
|
|
||||||
|
|
||||||
|
|
||||||
class ServiceManager(Manager[BaseServiceType]):
|
|
||||||
"""服务控制类"""
|
|
||||||
|
|
||||||
_services: Dict[Type[BaseServiceType], BaseServiceType] = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def services(self) -> List[BaseServiceType]:
|
|
||||||
return list(self._services.values())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def services_map(self) -> Dict[Type[BaseServiceType], BaseServiceType]:
|
|
||||||
return self._services
|
|
||||||
|
|
||||||
async def _initialize_service(self, target: Type[BaseServiceType]) -> BaseServiceType:
|
|
||||||
instance: BaseServiceType
|
|
||||||
try:
|
|
||||||
if hasattr(target, "from_config"): # 如果有 from_config 方法
|
|
||||||
instance = target.from_config(bot_config) # 用 from_config 实例化服务
|
|
||||||
else:
|
|
||||||
instance = await self.executor(target)
|
|
||||||
|
|
||||||
await instance.initialize()
|
|
||||||
logger.success('服务 "%s" 启动成功', target.__name__)
|
|
||||||
|
|
||||||
return instance
|
|
||||||
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.exception('服务 "%s" 初始化失败,BOT 将自动关闭', target.__name__)
|
|
||||||
raise SystemExit from e
|
|
||||||
|
|
||||||
async def start_services(self) -> None:
|
|
||||||
for path in filter(
|
|
||||||
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
|
|
||||||
):
|
|
||||||
_load_module(path)
|
|
||||||
|
|
||||||
for service in filter(lambda x: not x.is_component and not x.is_dependence, get_all_services()): # 遍历所有服务类
|
|
||||||
instance = await self._initialize_service(service)
|
|
||||||
|
|
||||||
self._lib[service] = instance
|
|
||||||
self._services[service] = instance
|
|
||||||
|
|
||||||
async def stop_services(self) -> None:
|
|
||||||
"""关闭服务"""
|
|
||||||
if not self._services:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def task(s):
|
|
||||||
try:
|
|
||||||
async with timeout(5):
|
|
||||||
await s.shutdown()
|
|
||||||
logger.success('服务 "%s" 关闭成功', s.__class__.__name__)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning('服务 "%s" 关闭超时', s.__class__.__name__)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning('服务 "%s" 关闭失败', s.__class__.__name__, exc_info=e)
|
|
||||||
|
|
||||||
logger.info("正在关闭服务")
|
|
||||||
tasks = []
|
|
||||||
for service in self._services.values():
|
|
||||||
tasks.append(asyncio.create_task(task(service)))
|
|
||||||
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
|
|
||||||
class PluginManager(Manager["PluginType"]):
|
|
||||||
"""插件管理"""
|
|
||||||
|
|
||||||
_plugins: Dict[Type["PluginType"], "PluginType"] = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def plugins(self) -> List["PluginType"]:
|
|
||||||
"""所有已经加载的插件"""
|
|
||||||
return list(self._plugins.values())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def plugins_map(self) -> Dict[Type["PluginType"], "PluginType"]:
|
|
||||||
return self._plugins
|
|
||||||
|
|
||||||
async def install_plugins(self) -> None:
|
|
||||||
"""安装所有插件"""
|
|
||||||
from core.plugin import get_all_plugins
|
|
||||||
|
|
||||||
for path in filter(lambda x: x.is_dir(), PLUGIN_DIR.iterdir()):
|
|
||||||
_load_module(path)
|
|
||||||
|
|
||||||
for plugin in get_all_plugins():
|
|
||||||
plugin: Type["PluginType"]
|
|
||||||
|
|
||||||
try:
|
|
||||||
instance: "PluginType" = await self.executor(plugin)
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.error('插件 "%s" 初始化失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._plugins[plugin] = instance
|
|
||||||
|
|
||||||
if self._application is not None:
|
|
||||||
instance.set_application(self._application)
|
|
||||||
|
|
||||||
await asyncio.create_task(self.plugin_install_task(plugin, instance))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def plugin_install_task(plugin: Type["PluginType"], instance: "PluginType"):
|
|
||||||
try:
|
|
||||||
await instance.install()
|
|
||||||
logger.success('插件 "%s" 安装成功', f"{plugin.__module__}.{plugin.__name__}")
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.error('插件 "%s" 安装失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
|
||||||
|
|
||||||
async def uninstall_plugins(self) -> None:
|
|
||||||
for plugin in self._plugins.values():
|
|
||||||
try:
|
|
||||||
await plugin.uninstall()
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.error('插件 "%s" 卸载失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
|
||||||
|
|
||||||
|
|
||||||
class Managers(DependenceManager, ComponentManager, ServiceManager, PluginManager):
|
|
||||||
"""BOT 除自身外的生命周期管理类"""
|
|
@ -1,117 +0,0 @@
|
|||||||
"""重写 telegram.request.HTTPXRequest 使其使用 ujson 库进行 json 序列化"""
|
|
||||||
from typing import Any, AsyncIterable, Optional
|
|
||||||
|
|
||||||
import httpcore
|
|
||||||
from httpx import (
|
|
||||||
AsyncByteStream,
|
|
||||||
AsyncHTTPTransport as DefaultAsyncHTTPTransport,
|
|
||||||
Limits,
|
|
||||||
Response as DefaultResponse,
|
|
||||||
Timeout,
|
|
||||||
)
|
|
||||||
from telegram.request import HTTPXRequest as DefaultHTTPXRequest
|
|
||||||
|
|
||||||
try:
|
|
||||||
import ujson as jsonlib
|
|
||||||
except ImportError:
|
|
||||||
import json as jsonlib
|
|
||||||
|
|
||||||
__all__ = ("HTTPXRequest",)
|
|
||||||
|
|
||||||
|
|
||||||
class Response(DefaultResponse):
|
|
||||||
def json(self, **kwargs: Any) -> Any:
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
from httpx._utils import guess_json_utf
|
|
||||||
|
|
||||||
if self.charset_encoding is None and self.content and len(self.content) > 3:
|
|
||||||
encoding = guess_json_utf(self.content)
|
|
||||||
if encoding is not None:
|
|
||||||
return jsonlib.loads(self.content.decode(encoding), **kwargs)
|
|
||||||
return jsonlib.loads(self.text, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
class AsyncHTTPTransport(DefaultAsyncHTTPTransport):
|
|
||||||
async def handle_async_request(self, request) -> Response:
|
|
||||||
from httpx._transports.default import (
|
|
||||||
map_httpcore_exceptions,
|
|
||||||
AsyncResponseStream,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(request.stream, AsyncByteStream):
|
|
||||||
raise AssertionError
|
|
||||||
|
|
||||||
req = httpcore.Request(
|
|
||||||
method=request.method,
|
|
||||||
url=httpcore.URL(
|
|
||||||
scheme=request.url.raw_scheme,
|
|
||||||
host=request.url.raw_host,
|
|
||||||
port=request.url.port,
|
|
||||||
target=request.url.raw_path,
|
|
||||||
),
|
|
||||||
headers=request.headers.raw,
|
|
||||||
content=request.stream,
|
|
||||||
extensions=request.extensions,
|
|
||||||
)
|
|
||||||
with map_httpcore_exceptions():
|
|
||||||
resp = await self._pool.handle_async_request(req)
|
|
||||||
|
|
||||||
if not isinstance(resp.stream, AsyncIterable):
|
|
||||||
raise AssertionError
|
|
||||||
|
|
||||||
return Response(
|
|
||||||
status_code=resp.status,
|
|
||||||
headers=resp.headers,
|
|
||||||
stream=AsyncResponseStream(resp.stream),
|
|
||||||
extensions=resp.extensions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPXRequest(DefaultHTTPXRequest):
|
|
||||||
def __init__( # pylint: disable=W0231
|
|
||||||
self,
|
|
||||||
connection_pool_size: int = 1,
|
|
||||||
proxy_url: str = None,
|
|
||||||
read_timeout: Optional[float] = 5.0,
|
|
||||||
write_timeout: Optional[float] = 5.0,
|
|
||||||
connect_timeout: Optional[float] = 5.0,
|
|
||||||
pool_timeout: Optional[float] = 1.0,
|
|
||||||
http_version: str = "1.1",
|
|
||||||
):
|
|
||||||
self._http_version = http_version
|
|
||||||
timeout = Timeout(
|
|
||||||
connect=connect_timeout,
|
|
||||||
read=read_timeout,
|
|
||||||
write=write_timeout,
|
|
||||||
pool=pool_timeout,
|
|
||||||
)
|
|
||||||
limits = Limits(
|
|
||||||
max_connections=connection_pool_size,
|
|
||||||
max_keepalive_connections=connection_pool_size,
|
|
||||||
)
|
|
||||||
if http_version not in ("1.1", "2"):
|
|
||||||
raise ValueError("`http_version` must be either '1.1' or '2'.")
|
|
||||||
http1 = http_version == "1.1"
|
|
||||||
self._client_kwargs = dict(
|
|
||||||
timeout=timeout,
|
|
||||||
proxies=proxy_url,
|
|
||||||
limits=limits,
|
|
||||||
transport=AsyncHTTPTransport(limits=limits),
|
|
||||||
http1=http1,
|
|
||||||
http2=not http1,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._client = self._build_client()
|
|
||||||
except ImportError as exc:
|
|
||||||
if "httpx[http2]" not in str(exc) and "httpx[socks]" not in str(exc):
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
if "httpx[socks]" in str(exc):
|
|
||||||
raise RuntimeError(
|
|
||||||
"To use Socks5 proxies, PTB must be installed via `pip install " "python-telegram-bot[socks]`."
|
|
||||||
) from exc
|
|
||||||
raise RuntimeError(
|
|
||||||
"To use HTTP/2, PTB must be installed via `pip install " "python-telegram-bot[http2]`."
|
|
||||||
) from exc
|
|
@ -1,8 +1,8 @@
|
|||||||
"""插件"""
|
"""插件"""
|
||||||
|
|
||||||
from core.plugin._handler import conversation, error_handler, handler
|
from gram_core.plugin._handler import conversation, error_handler, handler
|
||||||
from core.plugin._job import TimeType, job
|
from gram_core.plugin._job import TimeType, job
|
||||||
from core.plugin._plugin import Plugin, PluginType, get_all_plugins
|
from gram_core.plugin._plugin import Plugin, PluginType, get_all_plugins
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"Plugin",
|
"Plugin",
|
||||||
|
@ -1,178 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional, Union, TYPE_CHECKING
|
|
||||||
|
|
||||||
import aiofiles
|
|
||||||
import httpx
|
|
||||||
from httpx import UnsupportedProtocol
|
|
||||||
from telegram import Chat, Message, ReplyKeyboardRemove, Update
|
|
||||||
from telegram.error import Forbidden, NetworkError
|
|
||||||
from telegram.ext import CallbackContext, ConversationHandler, Job
|
|
||||||
|
|
||||||
from core.dependence.redisdb import RedisDB
|
|
||||||
from core.plugin._handler import conversation, handler
|
|
||||||
from utils.const import CACHE_DIR, REQUEST_HEADERS
|
|
||||||
from utils.error import UrlResourcesNotFoundError
|
|
||||||
from utils.helpers import sha1
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.application import Application
|
|
||||||
|
|
||||||
try:
|
|
||||||
import ujson as json
|
|
||||||
except ImportError:
|
|
||||||
import json
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"PluginFuncs",
|
|
||||||
"ConversationFuncs",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PluginFuncs:
|
|
||||||
_application: "Optional[Application]" = None
|
|
||||||
|
|
||||||
def set_application(self, application: "Application") -> None:
|
|
||||||
self._application = application
|
|
||||||
|
|
||||||
@property
|
|
||||||
def application(self) -> "Application":
|
|
||||||
if self._application is None:
|
|
||||||
raise RuntimeError("No application was set for this PluginManager.")
|
|
||||||
return self._application
|
|
||||||
|
|
||||||
async def _delete_message(self, context: CallbackContext) -> None:
|
|
||||||
job = context.job
|
|
||||||
message_id = job.data
|
|
||||||
chat_info = f"chat_id[{job.chat_id}]"
|
|
||||||
|
|
||||||
try:
|
|
||||||
chat = await self.get_chat(job.chat_id)
|
|
||||||
full_name = chat.full_name
|
|
||||||
if full_name:
|
|
||||||
chat_info = f"{full_name}[{chat.id}]"
|
|
||||||
else:
|
|
||||||
chat_info = f"{chat.title}[{chat.id}]"
|
|
||||||
except (NetworkError, Forbidden) as exc:
|
|
||||||
logger.warning("获取 chat info 失败 %s", exc.message)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("获取 chat info 消息失败 %s", str(exc))
|
|
||||||
|
|
||||||
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
|
|
||||||
except NetworkError as exc:
|
|
||||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
|
||||||
except Forbidden as exc:
|
|
||||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc_info=exc)
|
|
||||||
|
|
||||||
async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, expire: int = 86400) -> Chat:
|
|
||||||
application = self.application
|
|
||||||
redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None)
|
|
||||||
|
|
||||||
if not redis_db:
|
|
||||||
return await application.bot.get_chat(chat_id)
|
|
||||||
|
|
||||||
qname = f"bot:chat:{chat_id}"
|
|
||||||
|
|
||||||
data = await redis_db.client.get(qname)
|
|
||||||
if data:
|
|
||||||
json_data = json.loads(data)
|
|
||||||
return Chat.de_json(json_data, application.telegram.bot)
|
|
||||||
|
|
||||||
chat_info = await application.telegram.bot.get_chat(chat_id)
|
|
||||||
await redis_db.client.set(qname, chat_info.to_json(), ex=expire)
|
|
||||||
return chat_info
|
|
||||||
|
|
||||||
def add_delete_message_job(
|
|
||||||
self,
|
|
||||||
message: Optional[Union[int, Message]] = None,
|
|
||||||
*,
|
|
||||||
delay: int = 60,
|
|
||||||
name: Optional[str] = None,
|
|
||||||
chat: Optional[Union[int, Chat]] = None,
|
|
||||||
context: Optional[CallbackContext] = None,
|
|
||||||
) -> Job:
|
|
||||||
"""延迟删除消息"""
|
|
||||||
|
|
||||||
if isinstance(message, Message):
|
|
||||||
if chat is None:
|
|
||||||
chat = message.chat_id
|
|
||||||
message = message.id
|
|
||||||
|
|
||||||
chat = chat.id if isinstance(chat, Chat) else chat
|
|
||||||
|
|
||||||
job_queue = self.application.job_queue or context.job_queue
|
|
||||||
|
|
||||||
if job_queue is None or chat is None:
|
|
||||||
raise RuntimeError
|
|
||||||
|
|
||||||
return job_queue.run_once(
|
|
||||||
callback=self._delete_message,
|
|
||||||
when=delay,
|
|
||||||
data=message,
|
|
||||||
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
|
|
||||||
chat_id=chat,
|
|
||||||
job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def download_resource(url: str, return_path: bool = False) -> str:
|
|
||||||
url_sha1 = sha1(url) # url 的 hash 值
|
|
||||||
pathed_url = Path(url)
|
|
||||||
|
|
||||||
file_name = url_sha1 + pathed_url.suffix
|
|
||||||
file_path = CACHE_DIR.joinpath(file_name)
|
|
||||||
|
|
||||||
if not file_path.exists(): # 若文件不存在,则下载
|
|
||||||
async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client:
|
|
||||||
try:
|
|
||||||
response = await client.get(url)
|
|
||||||
except UnsupportedProtocol:
|
|
||||||
logger.error("链接不支持 url[%s]", url)
|
|
||||||
return ""
|
|
||||||
|
|
||||||
if response.is_error:
|
|
||||||
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
|
|
||||||
raise UrlResourcesNotFoundError(url)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code)
|
|
||||||
raise UrlResourcesNotFoundError(url)
|
|
||||||
|
|
||||||
async with aiofiles.open(file_path, mode="wb") as f:
|
|
||||||
await f.write(response.content)
|
|
||||||
|
|
||||||
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
|
|
||||||
|
|
||||||
return file_path if return_path else Path(file_path).as_uri()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_args(context: CallbackContext) -> List[str]:
|
|
||||||
args = context.args
|
|
||||||
match = context.match
|
|
||||||
|
|
||||||
if args is None:
|
|
||||||
if match is not None and (command := match.groups()[0]):
|
|
||||||
temp = []
|
|
||||||
command_parts = command.split(" ")
|
|
||||||
for command_part in command_parts:
|
|
||||||
if command_part:
|
|
||||||
temp.append(command_part)
|
|
||||||
return temp
|
|
||||||
return []
|
|
||||||
if len(args) >= 1:
|
|
||||||
return args
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationFuncs:
|
|
||||||
@conversation.fallback
|
|
||||||
@handler.command(command="cancel", block=False)
|
|
||||||
async def cancel(self, update: Update, _) -> int:
|
|
||||||
await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove())
|
|
||||||
return ConversationHandler.END
|
|
@ -1,380 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from functools import wraps
|
|
||||||
from importlib import import_module
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
ClassVar,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Pattern,
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
from telegram._utils.defaultvalue import DEFAULT_TRUE
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
from telegram._utils.types import DVInput
|
|
||||||
from telegram.ext import BaseHandler
|
|
||||||
from telegram.ext.filters import BaseFilter
|
|
||||||
from typing_extensions import ParamSpec
|
|
||||||
|
|
||||||
from core.handler.callbackqueryhandler import CallbackQueryHandler
|
|
||||||
from utils.const import WRAPPER_ASSIGNMENTS as _WRAPPER_ASSIGNMENTS
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.builtins.dispatcher import AbstractDispatcher
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"handler",
|
|
||||||
"conversation",
|
|
||||||
"ConversationDataType",
|
|
||||||
"ConversationData",
|
|
||||||
"HandlerData",
|
|
||||||
"ErrorHandlerData",
|
|
||||||
"error_handler",
|
|
||||||
)
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
T = TypeVar("T")
|
|
||||||
R = TypeVar("R")
|
|
||||||
UT = TypeVar("UT")
|
|
||||||
|
|
||||||
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
|
|
||||||
HandlerCls = Type[HandlerType]
|
|
||||||
|
|
||||||
Module = import_module("telegram.ext")
|
|
||||||
|
|
||||||
HANDLER_DATA_ATTR_NAME = "_handler_datas"
|
|
||||||
"""用于储存生成 handler 时所需要的参数(例如 block)的属性名"""
|
|
||||||
|
|
||||||
ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
|
|
||||||
|
|
||||||
CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
|
|
||||||
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名"""
|
|
||||||
|
|
||||||
WRAPPER_ASSIGNMENTS = list(
|
|
||||||
set(
|
|
||||||
_WRAPPER_ASSIGNMENTS
|
|
||||||
+ [
|
|
||||||
HANDLER_DATA_ATTR_NAME,
|
|
||||||
ERROR_HANDLER_ATTR_NAME,
|
|
||||||
CONVERSATION_HANDLER_ATTR_NAME,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(init=True)
|
|
||||||
class HandlerData:
|
|
||||||
type: Type[HandlerType]
|
|
||||||
admin: bool
|
|
||||||
kwargs: Dict[str, Any]
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class _Handler:
|
|
||||||
_type: Type["HandlerType"]
|
|
||||||
|
|
||||||
kwargs: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs) -> None:
|
|
||||||
"""用于获取 python-telegram-bot 中对应的 handler class"""
|
|
||||||
|
|
||||||
handler_name = f"{cls.__name__.strip('_')}Handler"
|
|
||||||
|
|
||||||
if handler_name == "CallbackQueryHandler":
|
|
||||||
cls._type = CallbackQueryHandler
|
|
||||||
return
|
|
||||||
|
|
||||||
cls._type = getattr(Module, handler_name, None)
|
|
||||||
|
|
||||||
def __init__(self, admin: bool = False, dispatcher: Optional[Type["AbstractDispatcher"]] = None, **kwargs) -> None:
|
|
||||||
self.dispatcher = dispatcher
|
|
||||||
self.admin = admin
|
|
||||||
self.kwargs = kwargs
|
|
||||||
|
|
||||||
def __call__(self, func: Callable[P, R]) -> Callable[P, R]:
|
|
||||||
"""decorator实现,从 func 生成 Handler"""
|
|
||||||
|
|
||||||
handler_datas = getattr(func, HANDLER_DATA_ATTR_NAME, [])
|
|
||||||
handler_datas.append(
|
|
||||||
HandlerData(type=self._type, admin=self.admin, kwargs=self.kwargs, dispatcher=self.dispatcher)
|
|
||||||
)
|
|
||||||
setattr(func, HANDLER_DATA_ATTR_NAME, handler_datas)
|
|
||||||
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
class _CallbackQuery(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None,
|
|
||||||
*,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
admin: bool = False,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super(_CallbackQuery, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _ChatJoinRequest(_Handler):
|
|
||||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
|
||||||
super(_ChatJoinRequest, self).__init__(block=block, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _ChatMember(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
chat_member_types: int = -1,
|
|
||||||
*,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super().__init__(chat_member_types=chat_member_types, block=block, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _ChosenInlineResult(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
*,
|
|
||||||
pattern: Union[str, Pattern] = None,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super().__init__(block=block, pattern=pattern, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _Command(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
command: Union[str, List[str]],
|
|
||||||
filters: "BaseFilter" = None,
|
|
||||||
*,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
admin: bool = False,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super(_Command, self).__init__(
|
|
||||||
command=command, filters=filters, block=block, admin=admin, dispatcher=dispatcher
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _InlineQuery(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
pattern: Union[str, Pattern] = None,
|
|
||||||
chat_types: List[str] = None,
|
|
||||||
*,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super(_InlineQuery, self).__init__(pattern=pattern, block=block, chat_types=chat_types, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _Message(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
filters: BaseFilter,
|
|
||||||
*,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
admin: bool = False,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
) -> None:
|
|
||||||
super(_Message, self).__init__(filters=filters, block=block, admin=admin, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _PollAnswer(_Handler):
|
|
||||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
|
||||||
super(_PollAnswer, self).__init__(block=block, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _Poll(_Handler):
|
|
||||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
|
||||||
super(_Poll, self).__init__(block=block, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _PreCheckoutQuery(_Handler):
|
|
||||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
|
||||||
super(_PreCheckoutQuery, self).__init__(block=block, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _Prefix(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
prefix: str,
|
|
||||||
command: str,
|
|
||||||
filters: BaseFilter = None,
|
|
||||||
*,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super(_Prefix, self).__init__(
|
|
||||||
prefix=prefix, command=command, filters=filters, block=block, dispatcher=dispatcher
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _ShippingQuery(_Handler):
|
|
||||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
|
||||||
super(_ShippingQuery, self).__init__(block=block, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _StringCommand(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
command: str,
|
|
||||||
*,
|
|
||||||
admin: bool = False,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super(_StringCommand, self).__init__(command=command, block=block, admin=admin, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _StringRegex(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
pattern: Union[str, Pattern],
|
|
||||||
*,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
admin: bool = False,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super(_StringRegex, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
class _Type(_Handler):
|
|
||||||
# noinspection PyShadowingBuiltins
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
type: Type[UT], # pylint: disable=W0622
|
|
||||||
strict: bool = False,
|
|
||||||
*,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
): # pylint: disable=redefined-builtin
|
|
||||||
super(_Type, self).__init__(type=type, strict=strict, block=block, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
|
||||||
class handler(_Handler):
|
|
||||||
callback_query = _CallbackQuery
|
|
||||||
chat_join_request = _ChatJoinRequest
|
|
||||||
chat_member = _ChatMember
|
|
||||||
chosen_inline_result = _ChosenInlineResult
|
|
||||||
command = _Command
|
|
||||||
inline_query = _InlineQuery
|
|
||||||
message = _Message
|
|
||||||
poll_answer = _PollAnswer
|
|
||||||
pool = _Poll
|
|
||||||
pre_checkout_query = _PreCheckoutQuery
|
|
||||||
prefix = _Prefix
|
|
||||||
shipping_query = _ShippingQuery
|
|
||||||
string_command = _StringCommand
|
|
||||||
string_regex = _StringRegex
|
|
||||||
type = _Type
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
handler_type: Union[Callable[P, "HandlerType"], Type["HandlerType"]],
|
|
||||||
*,
|
|
||||||
admin: bool = False,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
**kwargs: P.kwargs,
|
|
||||||
) -> None:
|
|
||||||
self._type = handler_type
|
|
||||||
super().__init__(admin=admin, dispatcher=dispatcher, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationDataType(Enum):
|
|
||||||
"""conversation handler 的类型"""
|
|
||||||
|
|
||||||
Entry = "entry"
|
|
||||||
State = "state"
|
|
||||||
Fallback = "fallback"
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationData(BaseModel):
|
|
||||||
"""用于储存 conversation handler 的数据"""
|
|
||||||
|
|
||||||
type: ConversationDataType
|
|
||||||
state: Optional[Any] = None
|
|
||||||
|
|
||||||
|
|
||||||
class _ConversationType:
|
|
||||||
_type: ClassVar[ConversationDataType]
|
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs) -> None:
|
|
||||||
cls._type = ConversationDataType(cls.__name__.lstrip("_").lower())
|
|
||||||
|
|
||||||
|
|
||||||
def _entry(func: Callable[P, R]) -> Callable[P, R]:
|
|
||||||
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Entry))
|
|
||||||
|
|
||||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
|
||||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapped
|
|
||||||
|
|
||||||
|
|
||||||
class _State(_ConversationType):
|
|
||||||
def __init__(self, state: Any) -> None:
|
|
||||||
self.state = state
|
|
||||||
|
|
||||||
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
|
|
||||||
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=self._type, state=self.state))
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
def _fallback(func: Callable[P, R]) -> Callable[P, R]:
|
|
||||||
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Fallback))
|
|
||||||
|
|
||||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
|
||||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapped
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
|
||||||
class conversation(_Handler):
|
|
||||||
entry_point = _entry
|
|
||||||
state = _State
|
|
||||||
fallback = _fallback
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(init=True)
|
|
||||||
class ErrorHandlerData:
|
|
||||||
block: bool
|
|
||||||
func: Optional[Callable] = None
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
|
||||||
class error_handler:
|
|
||||||
_func: Callable[P, R]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
block: bool = DEFAULT_TRUE,
|
|
||||||
):
|
|
||||||
self._block = block
|
|
||||||
|
|
||||||
def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
|
|
||||||
self._func = func
|
|
||||||
wraps(func, assigned=WRAPPER_ASSIGNMENTS)(self)
|
|
||||||
|
|
||||||
handler_datas = getattr(func, ERROR_HANDLER_ATTR_NAME, [])
|
|
||||||
handler_datas.append(ErrorHandlerData(block=self._block))
|
|
||||||
setattr(self._func, ERROR_HANDLER_ATTR_NAME, handler_datas)
|
|
||||||
|
|
||||||
return self._func
|
|
@ -1,173 +0,0 @@
|
|||||||
"""插件"""
|
|
||||||
import datetime
|
|
||||||
import re
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
from telegram._utils.types import JSONDict
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
from telegram.ext._utils.types import JobCallback
|
|
||||||
from typing_extensions import ParamSpec
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.builtins.dispatcher import AbstractDispatcher
|
|
||||||
|
|
||||||
__all__ = ["TimeType", "job", "JobData"]
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
T = TypeVar("T")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time]
|
|
||||||
|
|
||||||
_JOB_ATTR_NAME = "_job_data"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(init=True)
|
|
||||||
class JobData:
|
|
||||||
name: str
|
|
||||||
data: Any
|
|
||||||
chat_id: int
|
|
||||||
user_id: int
|
|
||||||
type: str
|
|
||||||
job_kwargs: JSONDict = field(default_factory=dict)
|
|
||||||
kwargs: JSONDict = field(default_factory=dict)
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class _Job:
|
|
||||||
kwargs: Dict = {}
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str = None,
|
|
||||||
data: object = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
*,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.name = name
|
|
||||||
self.data = data
|
|
||||||
self.chat_id = chat_id
|
|
||||||
self.user_id = user_id
|
|
||||||
self.job_kwargs = {} if job_kwargs is None else job_kwargs
|
|
||||||
self.kwargs = kwargs
|
|
||||||
if dispatcher is None:
|
|
||||||
from core.builtins.dispatcher import JobDispatcher
|
|
||||||
|
|
||||||
dispatcher = JobDispatcher
|
|
||||||
|
|
||||||
self.dispatcher = dispatcher
|
|
||||||
|
|
||||||
def __call__(self, func: JobCallback) -> JobCallback:
|
|
||||||
data = JobData(
|
|
||||||
name=self.name,
|
|
||||||
data=self.data,
|
|
||||||
chat_id=self.chat_id,
|
|
||||||
user_id=self.user_id,
|
|
||||||
job_kwargs=self.job_kwargs,
|
|
||||||
kwargs=self.kwargs,
|
|
||||||
type=re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"),
|
|
||||||
dispatcher=self.dispatcher,
|
|
||||||
)
|
|
||||||
if hasattr(func, _JOB_ATTR_NAME):
|
|
||||||
job_datas = getattr(func, _JOB_ATTR_NAME)
|
|
||||||
job_datas.append(data)
|
|
||||||
setattr(func, _JOB_ATTR_NAME, job_datas)
|
|
||||||
else:
|
|
||||||
setattr(func, _JOB_ATTR_NAME, [data])
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
class _RunOnce(_Job):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
when: TimeType,
|
|
||||||
data: object = None,
|
|
||||||
name: str = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
*,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when)
|
|
||||||
|
|
||||||
|
|
||||||
class _RunRepeating(_Job):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
interval: Union[float, datetime.timedelta],
|
|
||||||
first: TimeType = None,
|
|
||||||
last: TimeType = None,
|
|
||||||
data: object = None,
|
|
||||||
name: str = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
*,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, interval=interval, first=first, last=last
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _RunMonthly(_Job):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
when: datetime.time,
|
|
||||||
day: int,
|
|
||||||
data: object = None,
|
|
||||||
name: str = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
*,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when, day=day)
|
|
||||||
|
|
||||||
|
|
||||||
class _RunDaily(_Job):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
time: datetime.time,
|
|
||||||
days: Tuple[int, ...] = tuple(range(7)),
|
|
||||||
data: object = None,
|
|
||||||
name: str = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
*,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, time=time, days=days)
|
|
||||||
|
|
||||||
|
|
||||||
class _RunCustom(_Job):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
data: object = None,
|
|
||||||
name: str = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
*,
|
|
||||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
|
||||||
):
|
|
||||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher)
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
|
||||||
class job:
|
|
||||||
run_once = _RunOnce
|
|
||||||
run_repeating = _RunRepeating
|
|
||||||
run_monthly = _RunMonthly
|
|
||||||
run_daily = _RunDaily
|
|
||||||
run_custom = _RunCustom
|
|
@ -1,314 +0,0 @@
|
|||||||
"""插件"""
|
|
||||||
import asyncio
|
|
||||||
from abc import ABC
|
|
||||||
from dataclasses import asdict
|
|
||||||
from datetime import timedelta
|
|
||||||
from functools import partial, wraps
|
|
||||||
from itertools import chain
|
|
||||||
from multiprocessing import RLock as Lock
|
|
||||||
from types import MethodType
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
ClassVar,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from telegram.ext import BaseHandler, ConversationHandler, Job, TypeHandler
|
|
||||||
from typing_extensions import ParamSpec
|
|
||||||
|
|
||||||
from core.handler.adminhandler import AdminHandler
|
|
||||||
from core.plugin._funcs import ConversationFuncs, PluginFuncs
|
|
||||||
from core.plugin._handler import ConversationDataType
|
|
||||||
from utils.const import WRAPPER_ASSIGNMENTS
|
|
||||||
from utils.helpers import isabstract
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.application import Application
|
|
||||||
from core.plugin._handler import ConversationData, HandlerData, ErrorHandlerData
|
|
||||||
from core.plugin._job import JobData
|
|
||||||
from multiprocessing.synchronize import RLock as LockType
|
|
||||||
|
|
||||||
__all__ = ("Plugin", "PluginType", "get_all_plugins")
|
|
||||||
|
|
||||||
wraps = partial(wraps, assigned=WRAPPER_ASSIGNMENTS)
|
|
||||||
P = ParamSpec("P")
|
|
||||||
T = TypeVar("T")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
|
|
||||||
|
|
||||||
_HANDLER_DATA_ATTR_NAME = "_handler_datas"
|
|
||||||
"""用于储存生成 handler 时所需要的参数(例如 block)的属性名"""
|
|
||||||
|
|
||||||
_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
|
|
||||||
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名"""
|
|
||||||
|
|
||||||
_ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
|
|
||||||
|
|
||||||
_JOB_ATTR_NAME = "_job_data"
|
|
||||||
|
|
||||||
_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"]
|
|
||||||
|
|
||||||
|
|
||||||
class _Plugin(PluginFuncs):
|
|
||||||
"""插件"""
|
|
||||||
|
|
||||||
_lock: ClassVar["LockType"] = Lock()
|
|
||||||
_asyncio_lock: ClassVar["LockType"] = asyncio.Lock()
|
|
||||||
_installed: bool = False
|
|
||||||
|
|
||||||
_handlers: Optional[List[HandlerType]] = None
|
|
||||||
_error_handlers: Optional[List["ErrorHandlerData"]] = None
|
|
||||||
_jobs: Optional[List[Job]] = None
|
|
||||||
_application: "Optional[Application]" = None
|
|
||||||
|
|
||||||
def set_application(self, application: "Application") -> None:
|
|
||||||
self._application = application
|
|
||||||
|
|
||||||
@property
|
|
||||||
def application(self) -> "Application":
|
|
||||||
if self._application is None:
|
|
||||||
raise RuntimeError("No application was set for this Plugin.")
|
|
||||||
return self._application
|
|
||||||
|
|
||||||
@property
|
|
||||||
def handlers(self) -> List[HandlerType]:
|
|
||||||
"""该插件的所有 handler"""
|
|
||||||
with self._lock:
|
|
||||||
if self._handlers is None:
|
|
||||||
self._handlers = []
|
|
||||||
|
|
||||||
for attr in dir(self):
|
|
||||||
if (
|
|
||||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
|
||||||
and isinstance(func := getattr(self, attr), MethodType)
|
|
||||||
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
|
|
||||||
):
|
|
||||||
for data in datas:
|
|
||||||
data: "HandlerData"
|
|
||||||
if data.admin:
|
|
||||||
self._handlers.append(
|
|
||||||
AdminHandler(
|
|
||||||
handler=data.type(
|
|
||||||
callback=func,
|
|
||||||
**data.kwargs,
|
|
||||||
),
|
|
||||||
application=self.application,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._handlers.append(
|
|
||||||
data.type(
|
|
||||||
callback=func,
|
|
||||||
**data.kwargs,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return self._handlers
|
|
||||||
|
|
||||||
@property
|
|
||||||
def error_handlers(self) -> List["ErrorHandlerData"]:
|
|
||||||
with self._lock:
|
|
||||||
if self._error_handlers is None:
|
|
||||||
self._error_handlers = []
|
|
||||||
for attr in dir(self):
|
|
||||||
if (
|
|
||||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
|
||||||
and isinstance(func := getattr(self, attr), MethodType)
|
|
||||||
and (datas := getattr(func, _ERROR_HANDLER_ATTR_NAME, []))
|
|
||||||
):
|
|
||||||
for data in datas:
|
|
||||||
data: "ErrorHandlerData"
|
|
||||||
data.func = func
|
|
||||||
self._error_handlers.append(data)
|
|
||||||
|
|
||||||
return self._error_handlers
|
|
||||||
|
|
||||||
def _install_jobs(self) -> None:
|
|
||||||
if self._jobs is None:
|
|
||||||
self._jobs = []
|
|
||||||
for attr in dir(self):
|
|
||||||
# noinspection PyUnboundLocalVariable
|
|
||||||
if (
|
|
||||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
|
||||||
and isinstance(func := getattr(self, attr), MethodType)
|
|
||||||
and (datas := getattr(func, _JOB_ATTR_NAME, []))
|
|
||||||
):
|
|
||||||
for data in datas:
|
|
||||||
data: "JobData"
|
|
||||||
self._jobs.append(
|
|
||||||
getattr(self.application.telegram.job_queue, data.type)(
|
|
||||||
callback=func,
|
|
||||||
**data.kwargs,
|
|
||||||
**{
|
|
||||||
key: value
|
|
||||||
for key, value in asdict(data).items()
|
|
||||||
if key not in ["type", "kwargs", "dispatcher"]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def jobs(self) -> List[Job]:
|
|
||||||
with self._lock:
|
|
||||||
if self._jobs is None:
|
|
||||||
self._jobs = []
|
|
||||||
self._install_jobs()
|
|
||||||
return self._jobs
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
"""初始化插件"""
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
"""销毁插件"""
|
|
||||||
|
|
||||||
async def install(self) -> None:
|
|
||||||
"""安装"""
|
|
||||||
group = id(self)
|
|
||||||
if not self._installed:
|
|
||||||
await self.initialize()
|
|
||||||
# initialize 必须先执行 如果出现异常不会执行 add_handler 以免出现问题
|
|
||||||
async with self._asyncio_lock:
|
|
||||||
self._install_jobs()
|
|
||||||
|
|
||||||
for h in self.handlers:
|
|
||||||
if not isinstance(h, TypeHandler):
|
|
||||||
self.application.telegram.add_handler(h, group)
|
|
||||||
else:
|
|
||||||
self.application.telegram.add_handler(h, -1)
|
|
||||||
|
|
||||||
for h in self.error_handlers:
|
|
||||||
self.application.telegram.add_error_handler(h.func, h.block)
|
|
||||||
self._installed = True
|
|
||||||
|
|
||||||
async def uninstall(self) -> None:
|
|
||||||
"""卸载"""
|
|
||||||
group = id(self)
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
if self._installed:
|
|
||||||
if group in self.application.telegram.handlers:
|
|
||||||
del self.application.telegram.handlers[id(self)]
|
|
||||||
|
|
||||||
for h in self.handlers:
|
|
||||||
if isinstance(h, TypeHandler):
|
|
||||||
self.application.telegram.remove_handler(h, -1)
|
|
||||||
for h in self.error_handlers:
|
|
||||||
self.application.telegram.remove_error_handler(h.func)
|
|
||||||
|
|
||||||
for j in self.application.telegram.job_queue.jobs():
|
|
||||||
j.schedule_removal()
|
|
||||||
await self.shutdown()
|
|
||||||
self._installed = False
|
|
||||||
|
|
||||||
async def reload(self) -> None:
|
|
||||||
await self.uninstall()
|
|
||||||
await self.install()
|
|
||||||
|
|
||||||
|
|
||||||
class _Conversation(_Plugin, ConversationFuncs, ABC):
|
|
||||||
"""Conversation类"""
|
|
||||||
|
|
||||||
# noinspection SpellCheckingInspection
|
|
||||||
class Config(BaseModel):
|
|
||||||
allow_reentry: bool = False
|
|
||||||
per_chat: bool = True
|
|
||||||
per_user: bool = True
|
|
||||||
per_message: bool = False
|
|
||||||
conversation_timeout: Optional[Union[float, timedelta]] = None
|
|
||||||
name: Optional[str] = None
|
|
||||||
map_to_parent: Optional[Dict[object, object]] = None
|
|
||||||
block: bool = False
|
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
|
||||||
cls._conversation_kwargs = kwargs
|
|
||||||
super(_Conversation, cls).__init_subclass__()
|
|
||||||
return cls
|
|
||||||
|
|
||||||
@property
|
|
||||||
def handlers(self) -> List[HandlerType]:
|
|
||||||
with self._lock:
|
|
||||||
if self._handlers is None:
|
|
||||||
self._handlers = []
|
|
||||||
|
|
||||||
entry_points: List[HandlerType] = []
|
|
||||||
states: Dict[Any, List[HandlerType]] = {}
|
|
||||||
fallbacks: List[HandlerType] = []
|
|
||||||
for attr in dir(self):
|
|
||||||
if (
|
|
||||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
|
||||||
and (func := getattr(self, attr, None)) is not None
|
|
||||||
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
|
|
||||||
):
|
|
||||||
conversation_data: "ConversationData"
|
|
||||||
|
|
||||||
handlers: List[HandlerType] = []
|
|
||||||
for data in datas:
|
|
||||||
if data.admin:
|
|
||||||
handlers.append(
|
|
||||||
AdminHandler(
|
|
||||||
handler=data.type(
|
|
||||||
callback=func,
|
|
||||||
**data.kwargs,
|
|
||||||
),
|
|
||||||
application=self.application,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
handlers.append(
|
|
||||||
data.type(
|
|
||||||
callback=func,
|
|
||||||
**data.kwargs,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if conversation_data := getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None):
|
|
||||||
if (_type := conversation_data.type) is ConversationDataType.Entry:
|
|
||||||
entry_points.extend(handlers)
|
|
||||||
elif _type is ConversationDataType.State:
|
|
||||||
if conversation_data.state in states:
|
|
||||||
states[conversation_data.state].extend(handlers)
|
|
||||||
else:
|
|
||||||
states[conversation_data.state] = handlers
|
|
||||||
elif _type is ConversationDataType.Fallback:
|
|
||||||
fallbacks.extend(handlers)
|
|
||||||
else:
|
|
||||||
self._handlers.extend(handlers)
|
|
||||||
else:
|
|
||||||
self._handlers.extend(handlers)
|
|
||||||
if entry_points and states and fallbacks:
|
|
||||||
kwargs = self._conversation_kwargs
|
|
||||||
kwargs.update(self.Config().dict())
|
|
||||||
self._handlers.append(ConversationHandler(entry_points, states, fallbacks, **kwargs))
|
|
||||||
else:
|
|
||||||
temp_dict = {"entry_points": entry_points, "states": states, "fallbacks": fallbacks}
|
|
||||||
reason = map(lambda x: f"'{x[0]}'", filter(lambda x: not x[1], temp_dict.items()))
|
|
||||||
logger.warning(
|
|
||||||
"'%s' 因缺少 '%s' 而生成无法生成 ConversationHandler", self.__class__.__name__, ", ".join(reason)
|
|
||||||
)
|
|
||||||
return self._handlers
|
|
||||||
|
|
||||||
|
|
||||||
class Plugin(_Plugin, ABC):
|
|
||||||
"""插件"""
|
|
||||||
|
|
||||||
Conversation = _Conversation
|
|
||||||
|
|
||||||
|
|
||||||
PluginType = TypeVar("PluginType", bound=_Plugin)
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_plugins() -> Iterable[Type[PluginType]]:
|
|
||||||
"""获取所有 Plugin 的子类"""
|
|
||||||
return filter(
|
|
||||||
lambda x: x.__name__[0] != "_" and not isabstract(x),
|
|
||||||
chain(Plugin.__subclasses__(), _Conversation.__subclasses__()),
|
|
||||||
)
|
|
@ -1,67 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import contextlib
|
|
||||||
from typing import Callable, Coroutine, Any, Union, List, Dict, Optional, TypeVar, Type
|
|
||||||
|
|
||||||
from telegram.error import RetryAfter
|
|
||||||
from telegram.ext import BaseRateLimiter, ApplicationHandlerStop
|
|
||||||
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
JSONDict: Type[dict[str, Any]] = Dict[str, Any]
|
|
||||||
RL_ARGS = TypeVar("RL_ARGS")
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter(BaseRateLimiter[int]):
|
|
||||||
_lock = asyncio.Lock()
|
|
||||||
__slots__ = (
|
|
||||||
"_limiter_info",
|
|
||||||
"_retry_after_event",
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._limiter_info: Dict[Union[str, int], float] = {}
|
|
||||||
self._retry_after_event = asyncio.Event()
|
|
||||||
self._retry_after_event.set()
|
|
||||||
|
|
||||||
async def process_request(
|
|
||||||
self,
|
|
||||||
callback: Callable[..., Coroutine[Any, Any, Union[bool, JSONDict, List[JSONDict]]]],
|
|
||||||
args: Any,
|
|
||||||
kwargs: Dict[str, Any],
|
|
||||||
endpoint: str,
|
|
||||||
data: Dict[str, Any],
|
|
||||||
rate_limit_args: Optional[RL_ARGS],
|
|
||||||
) -> Union[bool, JSONDict, List[JSONDict]]:
|
|
||||||
chat_id = data.get("chat_id")
|
|
||||||
|
|
||||||
with contextlib.suppress(ValueError, TypeError):
|
|
||||||
chat_id = int(chat_id)
|
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
time = loop.time()
|
|
||||||
|
|
||||||
await self._retry_after_event.wait()
|
|
||||||
|
|
||||||
async with self._lock:
|
|
||||||
chat_limit_time = self._limiter_info.get(chat_id)
|
|
||||||
if chat_limit_time:
|
|
||||||
if time >= chat_limit_time:
|
|
||||||
raise ApplicationHandlerStop
|
|
||||||
del self._limiter_info[chat_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await callback(*args, **kwargs)
|
|
||||||
except RetryAfter as exc:
|
|
||||||
logger.warning("chat_id[%s] 触发洪水限制 当前被服务器限制 retry_after[%s]秒", chat_id, exc.retry_after)
|
|
||||||
self._limiter_info[chat_id] = time + (exc.retry_after * 2)
|
|
||||||
sleep = exc.retry_after + 0.1
|
|
||||||
self._retry_after_event.clear()
|
|
||||||
await asyncio.sleep(sleep)
|
|
||||||
finally:
|
|
||||||
self._retry_after_event.set()
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
pass
|
|
@ -1,97 +1,3 @@
|
|||||||
from typing import List, Union
|
from gram_core.services.cookies.cache import PublicCookiesCache
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.basemodel import RegionEnum
|
|
||||||
from core.dependence.redisdb import RedisDB
|
|
||||||
from core.services.cookies.error import CookiesCachePoolExhausted
|
|
||||||
from utils.error import RegionNotFoundError
|
|
||||||
|
|
||||||
__all__ = ("PublicCookiesCache",)
|
__all__ = ("PublicCookiesCache",)
|
||||||
|
|
||||||
|
|
||||||
class PublicCookiesCache(BaseService.Component):
|
|
||||||
"""使用优先级(score)进行排序,对使用次数最少的Cookies进行审核"""
|
|
||||||
|
|
||||||
def __init__(self, redis: RedisDB):
|
|
||||||
self.client = redis.client
|
|
||||||
self.score_qname = "cookie:public"
|
|
||||||
self.user_times_qname = "cookie:public:times"
|
|
||||||
self.end = 20
|
|
||||||
self.user_times_ttl = 60 * 60 * 24
|
|
||||||
|
|
||||||
def get_public_cookies_queue_name(self, region: RegionEnum):
|
|
||||||
if region == RegionEnum.HYPERION:
|
|
||||||
return f"{self.score_qname}:yuanshen"
|
|
||||||
if region == RegionEnum.HOYOLAB:
|
|
||||||
return f"{self.score_qname}:genshin"
|
|
||||||
raise RegionNotFoundError(region.name)
|
|
||||||
|
|
||||||
async def putback_public_cookies(self, uid: int, region: RegionEnum):
|
|
||||||
"""重新添加单个到缓存列表
|
|
||||||
:param uid:
|
|
||||||
:param region:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
qname = self.get_public_cookies_queue_name(region)
|
|
||||||
score_maps = {f"{uid}": 0}
|
|
||||||
result = await self.client.zrem(qname, f"{uid}")
|
|
||||||
if result == 1:
|
|
||||||
await self.client.zadd(qname, score_maps)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def add_public_cookies(self, uid: Union[List[int], int], region: RegionEnum):
|
|
||||||
"""单个或批量添加到缓存列表
|
|
||||||
:param uid:
|
|
||||||
:param region:
|
|
||||||
:return: 成功返回列表大小
|
|
||||||
"""
|
|
||||||
qname = self.get_public_cookies_queue_name(region)
|
|
||||||
if isinstance(uid, int):
|
|
||||||
score_maps = {f"{uid}": 0}
|
|
||||||
elif isinstance(uid, list):
|
|
||||||
score_maps = {f"{i}": 0 for i in uid}
|
|
||||||
else:
|
|
||||||
raise TypeError("uid variable type error")
|
|
||||||
async with self.client.pipeline(transaction=True) as pipe:
|
|
||||||
# nx:只添加新元素。不要更新已经存在的元素
|
|
||||||
await pipe.zadd(qname, score_maps, nx=True)
|
|
||||||
await pipe.zcard(qname)
|
|
||||||
add, count = await pipe.execute()
|
|
||||||
return int(add), count
|
|
||||||
|
|
||||||
async def get_public_cookies(self, region: RegionEnum):
|
|
||||||
"""从缓存列表获取
|
|
||||||
:param region:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
qname = self.get_public_cookies_queue_name(region)
|
|
||||||
scores = await self.client.zrange(qname, 0, self.end, withscores=True, score_cast_func=int)
|
|
||||||
if len(scores) <= 0:
|
|
||||||
raise CookiesCachePoolExhausted
|
|
||||||
key = scores[0][0]
|
|
||||||
score = scores[0][1]
|
|
||||||
async with self.client.pipeline(transaction=True) as pipe:
|
|
||||||
await pipe.zincrby(qname, 1, key)
|
|
||||||
await pipe.execute()
|
|
||||||
return int(key), score + 1
|
|
||||||
|
|
||||||
async def delete_public_cookies(self, uid: int, region: RegionEnum):
|
|
||||||
qname = self.get_public_cookies_queue_name(region)
|
|
||||||
async with self.client.pipeline(transaction=True) as pipe:
|
|
||||||
await pipe.zrem(qname, uid)
|
|
||||||
return await pipe.execute()
|
|
||||||
|
|
||||||
async def get_public_cookies_count(self, limit: bool = True):
|
|
||||||
async with self.client.pipeline(transaction=True) as pipe:
|
|
||||||
if limit:
|
|
||||||
await pipe.zcount(0, self.end)
|
|
||||||
else:
|
|
||||||
await pipe.zcard(self.score_qname)
|
|
||||||
return await pipe.execute()
|
|
||||||
|
|
||||||
async def incr_by_user_times(self, user_id: Union[List[int], int], amount: int = 1):
|
|
||||||
qname = f"{self.user_times_qname}:{user_id}"
|
|
||||||
times = await self.client.incrby(qname, amount)
|
|
||||||
if times <= 1:
|
|
||||||
await self.client.expire(qname, self.user_times_ttl)
|
|
||||||
return times
|
|
||||||
|
@ -1,12 +1,3 @@
|
|||||||
class CookieServiceError(Exception):
|
from gram_core.services.cookies.error import CookieServiceError, CookiesCachePoolExhausted, TooManyRequestPublicCookies
|
||||||
pass
|
|
||||||
|
|
||||||
|
__all__ = ("CookieServiceError", "CookiesCachePoolExhausted", "TooManyRequestPublicCookies")
|
||||||
class CookiesCachePoolExhausted(CookieServiceError):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__("Cookies cache pool is exhausted")
|
|
||||||
|
|
||||||
|
|
||||||
class TooManyRequestPublicCookies(CookieServiceError):
|
|
||||||
def __init__(self, user_id):
|
|
||||||
super().__init__(f"{user_id} too many request public cookies")
|
|
||||||
|
@ -1,39 +1,3 @@
|
|||||||
import enum
|
from gram_core.services.cookies.models import Cookies, CookiesDataBase, CookiesStatusEnum
|
||||||
from typing import Optional, Dict
|
|
||||||
|
|
||||||
from sqlmodel import SQLModel, Field, Boolean, Column, Enum, JSON, Integer, BigInteger, Index
|
|
||||||
|
|
||||||
from core.basemodel import RegionEnum
|
|
||||||
|
|
||||||
__all__ = ("Cookies", "CookiesDataBase", "CookiesStatusEnum")
|
__all__ = ("Cookies", "CookiesDataBase", "CookiesStatusEnum")
|
||||||
|
|
||||||
|
|
||||||
class CookiesStatusEnum(int, enum.Enum):
|
|
||||||
STATUS_SUCCESS = 0
|
|
||||||
INVALID_COOKIES = 1
|
|
||||||
TOO_MANY_REQUESTS = 2
|
|
||||||
|
|
||||||
|
|
||||||
class Cookies(SQLModel):
|
|
||||||
__table_args__ = (
|
|
||||||
Index("index_user_account", "user_id", "account_id", unique=True),
|
|
||||||
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
|
||||||
)
|
|
||||||
id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
|
|
||||||
user_id: int = Field(
|
|
||||||
sa_column=Column(BigInteger()),
|
|
||||||
)
|
|
||||||
account_id: int = Field(
|
|
||||||
default=None,
|
|
||||||
sa_column=Column(
|
|
||||||
BigInteger(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
data: Optional[Dict[str, str]] = Field(sa_column=Column(JSON))
|
|
||||||
status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum)))
|
|
||||||
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
|
|
||||||
is_share: Optional[bool] = Field(sa_column=Column(Boolean))
|
|
||||||
|
|
||||||
|
|
||||||
class CookiesDataBase(Cookies, table=True):
|
|
||||||
__tablename__ = "cookies"
|
|
||||||
|
@ -1,55 +1,3 @@
|
|||||||
from typing import Optional, List
|
from gram_core.services.cookies.repositories import CookiesRepository
|
||||||
|
|
||||||
from sqlmodel import select
|
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.basemodel import RegionEnum
|
|
||||||
from core.dependence.database import Database
|
|
||||||
from core.services.cookies.models import CookiesDataBase as Cookies
|
|
||||||
from core.sqlmodel.session import AsyncSession
|
|
||||||
|
|
||||||
__all__ = ("CookiesRepository",)
|
__all__ = ("CookiesRepository",)
|
||||||
|
|
||||||
|
|
||||||
class CookiesRepository(BaseService.Component):
|
|
||||||
def __init__(self, database: Database):
|
|
||||||
self.engine = database.engine
|
|
||||||
|
|
||||||
async def get(
|
|
||||||
self,
|
|
||||||
user_id: int,
|
|
||||||
account_id: Optional[int] = None,
|
|
||||||
region: Optional[RegionEnum] = None,
|
|
||||||
) -> Optional[Cookies]:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
statement = select(Cookies).where(Cookies.user_id == user_id)
|
|
||||||
if account_id is not None:
|
|
||||||
statement = statement.where(Cookies.account_id == account_id)
|
|
||||||
if region is not None:
|
|
||||||
statement = statement.where(Cookies.region == region)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
return results.first()
|
|
||||||
|
|
||||||
async def add(self, cookies: Cookies) -> None:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(cookies)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def update(self, cookies: Cookies) -> Cookies:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(cookies)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(cookies)
|
|
||||||
return cookies
|
|
||||||
|
|
||||||
async def delete(self, cookies: Cookies) -> None:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
await session.delete(cookies)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def get_all_by_region(self, region: RegionEnum) -> List[Cookies]:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
statement = select(Cookies).where(Cookies.region == region)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
cookies = results.all()
|
|
||||||
return cookies
|
|
||||||
|
@ -1,159 +1,80 @@
|
|||||||
from typing import List, Optional
|
from gram_core.base_service import BaseService
|
||||||
|
from gram_core.basemodel import RegionEnum
|
||||||
|
from gram_core.services.cookies.error import CookieServiceError
|
||||||
|
from gram_core.services.cookies.models import CookiesStatusEnum, CookiesDataBase as Cookies
|
||||||
|
from gram_core.services.cookies.services import (
|
||||||
|
CookiesService,
|
||||||
|
PublicCookiesService as BasePublicCookiesService,
|
||||||
|
NeedContinue,
|
||||||
|
)
|
||||||
|
|
||||||
from simnet import StarRailClient, Region, Game
|
from simnet import StarRailClient, Region, Game
|
||||||
from simnet.errors import InvalidCookies, BadRequest as SimnetBadRequest, TooManyRequests
|
from simnet.errors import InvalidCookies, TooManyRequests, BadRequest as SimnetBadRequest
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.basemodel import RegionEnum
|
|
||||||
from core.services.cookies.cache import PublicCookiesCache
|
|
||||||
from core.services.cookies.error import CookieServiceError, TooManyRequestPublicCookies
|
|
||||||
from core.services.cookies.models import CookiesDataBase as Cookies, CookiesStatusEnum
|
|
||||||
from core.services.cookies.repositories import CookiesRepository
|
|
||||||
from utils.log import logger
|
from utils.log import logger
|
||||||
|
|
||||||
__all__ = ("CookiesService", "PublicCookiesService")
|
__all__ = ("CookiesService", "PublicCookiesService")
|
||||||
|
|
||||||
|
|
||||||
class CookiesService(BaseService):
|
class PublicCookiesService(BaseService, BasePublicCookiesService):
|
||||||
def __init__(self, cookies_repository: CookiesRepository) -> None:
|
async def check_public_cookie(self, region: RegionEnum, cookies: Cookies, public_id: int):
|
||||||
self._repository: CookiesRepository = cookies_repository
|
if region == RegionEnum.HYPERION:
|
||||||
|
client = StarRailClient(cookies=cookies.data, region=Region.CHINESE)
|
||||||
async def update(self, cookies: Cookies):
|
elif region == RegionEnum.HOYOLAB:
|
||||||
await self._repository.update(cookies)
|
client = StarRailClient(cookies=cookies.data, region=Region.OVERSEAS, lang="zh-cn")
|
||||||
|
else:
|
||||||
async def add(self, cookies: Cookies):
|
raise CookieServiceError
|
||||||
await self._repository.add(cookies)
|
try:
|
||||||
|
if client.account_id is None:
|
||||||
async def get(
|
raise RuntimeError("account_id not found")
|
||||||
self,
|
record_cards = await client.get_record_cards()
|
||||||
user_id: int,
|
for record_card in record_cards:
|
||||||
account_id: Optional[int] = None,
|
if record_card.game == Game.STARRAIL:
|
||||||
region: Optional[RegionEnum] = None,
|
await client.get_starrail_user(record_card.uid)
|
||||||
) -> Optional[Cookies]:
|
break
|
||||||
return await self._repository.get(user_id, account_id, region)
|
|
||||||
|
|
||||||
async def delete(self, cookies: Cookies) -> None:
|
|
||||||
return await self._repository.delete(cookies)
|
|
||||||
|
|
||||||
|
|
||||||
class PublicCookiesService(BaseService):
|
|
||||||
def __init__(self, cookies_repository: CookiesRepository, public_cookies_cache: PublicCookiesCache):
|
|
||||||
self._cache = public_cookies_cache
|
|
||||||
self._repository: CookiesRepository = cookies_repository
|
|
||||||
self.count: int = 0
|
|
||||||
self.user_times_limiter = 3 * 3
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
logger.info("正在初始化公共Cookies池")
|
|
||||||
await self.refresh()
|
|
||||||
logger.success("刷新公共Cookies池成功")
|
|
||||||
|
|
||||||
async def refresh(self):
|
|
||||||
"""刷新公共Cookies 定时任务
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
user_list: List[int] = []
|
|
||||||
cookies_list = await self._repository.get_all_by_region(RegionEnum.HYPERION) # 从数据库获取2
|
|
||||||
for cookies in cookies_list:
|
|
||||||
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
|
|
||||||
user_list.append(cookies.user_id)
|
|
||||||
if len(user_list) > 0:
|
|
||||||
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HYPERION)
|
|
||||||
logger.info("国服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
|
|
||||||
user_list.clear()
|
|
||||||
cookies_list = await self._repository.get_all_by_region(RegionEnum.HOYOLAB)
|
|
||||||
for cookies in cookies_list:
|
|
||||||
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
|
|
||||||
user_list.append(cookies.user_id)
|
|
||||||
if len(user_list) > 0:
|
|
||||||
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HOYOLAB)
|
|
||||||
logger.info("国际服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
|
|
||||||
|
|
||||||
async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL):
|
|
||||||
"""获取公共Cookies
|
|
||||||
:param user_id: 用户ID
|
|
||||||
:param region: 注册的服务器
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
user_times = await self._cache.incr_by_user_times(user_id)
|
|
||||||
if int(user_times) > self.user_times_limiter:
|
|
||||||
logger.warning("用户 %s 使用公共Cookies次数已经到达上限", user_id)
|
|
||||||
raise TooManyRequestPublicCookies(user_id)
|
|
||||||
while True:
|
|
||||||
public_id, count = await self._cache.get_public_cookies(region)
|
|
||||||
cookies = await self._repository.get(public_id, region=region)
|
|
||||||
if cookies is None:
|
|
||||||
await self._cache.delete_public_cookies(public_id, region)
|
|
||||||
continue
|
|
||||||
if region == RegionEnum.HYPERION:
|
|
||||||
client = StarRailClient(cookies=cookies.data, region=Region.CHINESE)
|
|
||||||
elif region == RegionEnum.HOYOLAB:
|
|
||||||
client = StarRailClient(cookies=cookies.data, region=Region.OVERSEAS, lang="zh-cn")
|
|
||||||
else:
|
else:
|
||||||
raise CookieServiceError
|
accounts = await client.get_game_accounts()
|
||||||
try:
|
for account in accounts:
|
||||||
if client.account_id is None:
|
if account.game == Game.STARRAIL:
|
||||||
raise RuntimeError("account_id not found")
|
await client.get_starrail_user(account.uid)
|
||||||
record_cards = await client.get_record_cards()
|
|
||||||
for record_card in record_cards:
|
|
||||||
if record_card.game == Game.STARRAIL:
|
|
||||||
await client.get_starrail_user(record_card.uid)
|
|
||||||
break
|
break
|
||||||
else:
|
except InvalidCookies as exc:
|
||||||
accounts = await client.get_game_accounts()
|
if exc.ret_code in (10001, -100):
|
||||||
for account in accounts:
|
logger.warning("用户 [%s] Cookies无效", public_id)
|
||||||
if account.game == Game.STARRAIL:
|
elif exc.ret_code == 10103:
|
||||||
await client.get_starrail_user(account.uid)
|
logger.warning("用户 [%s] Cookies有效,但没有绑定到游戏帐户", public_id)
|
||||||
break
|
else:
|
||||||
except InvalidCookies as exc:
|
logger.warning("Cookies无效 ")
|
||||||
if exc.ret_code in (10001, -100):
|
logger.exception(exc)
|
||||||
logger.warning("用户 [%s] Cookies无效", public_id)
|
cookies.status = CookiesStatusEnum.INVALID_COOKIES
|
||||||
elif exc.ret_code == 10103:
|
await self._repository.update(cookies)
|
||||||
logger.warning("用户 [%s] Cookies有效,但没有绑定到游戏帐户", public_id)
|
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||||
else:
|
raise NeedContinue
|
||||||
logger.warning("Cookies无效 ")
|
except TooManyRequests:
|
||||||
logger.exception(exc)
|
logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id)
|
||||||
|
cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS
|
||||||
|
await self._repository.update(cookies)
|
||||||
|
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||||
|
raise NeedContinue
|
||||||
|
except SimnetBadRequest as exc:
|
||||||
|
if "invalid content type" in exc.message:
|
||||||
|
raise exc
|
||||||
|
if exc.ret_code == 1034:
|
||||||
|
logger.warning("用户 [%s] 触发验证", public_id)
|
||||||
|
else:
|
||||||
|
logger.warning("用户 [%s] 获取账号信息发生错误,错误信息为", public_id)
|
||||||
|
logger.exception(exc)
|
||||||
|
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||||
|
raise NeedContinue
|
||||||
|
except RuntimeError as exc:
|
||||||
|
if "account_id not found" in str(exc):
|
||||||
cookies.status = CookiesStatusEnum.INVALID_COOKIES
|
cookies.status = CookiesStatusEnum.INVALID_COOKIES
|
||||||
await self._repository.update(cookies)
|
await self._repository.update(cookies)
|
||||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||||
continue
|
raise NeedContinue
|
||||||
except TooManyRequests:
|
raise exc
|
||||||
logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id)
|
except Exception as exc:
|
||||||
cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS
|
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||||
await self._repository.update(cookies)
|
raise exc
|
||||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
finally:
|
||||||
continue
|
await client.shutdown()
|
||||||
except SimnetBadRequest as exc:
|
|
||||||
if "invalid content type" in exc.message:
|
|
||||||
raise exc
|
|
||||||
if exc.ret_code == 1034:
|
|
||||||
logger.warning("用户 [%s] 触发验证", public_id)
|
|
||||||
else:
|
|
||||||
logger.warning("用户 [%s] 获取账号信息发生错误,错误信息为", public_id)
|
|
||||||
logger.exception(exc)
|
|
||||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
|
||||||
continue
|
|
||||||
except RuntimeError as exc:
|
|
||||||
if "account_id not found" in str(exc):
|
|
||||||
cookies.status = CookiesStatusEnum.INVALID_COOKIES
|
|
||||||
await self._repository.update(cookies)
|
|
||||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
|
||||||
continue
|
|
||||||
raise exc
|
|
||||||
except Exception as exc:
|
|
||||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
|
||||||
raise exc
|
|
||||||
finally:
|
|
||||||
await client.shutdown()
|
|
||||||
logger.info("用户 user_id[%s] 请求用户 user_id[%s] 的公共Cookies 该Cookies使用次数为%s次 ", user_id, public_id, count)
|
|
||||||
return cookies
|
|
||||||
|
|
||||||
async def undo(self, user_id: int, cookies: Optional[Cookies] = None, status: Optional[CookiesStatusEnum] = None):
|
|
||||||
await self._cache.incr_by_user_times(user_id, -1)
|
|
||||||
if cookies is not None and status is not None:
|
|
||||||
cookies.status = status
|
|
||||||
await self._repository.update(cookies)
|
|
||||||
await self._cache.delete_public_cookies(cookies.user_id, cookies.region)
|
|
||||||
logger.info("用户 user_id[%s] 反馈用户 user_id[%s] 的Cookies状态为 %s", user_id, cookies.user_id, status.name)
|
|
||||||
else:
|
|
||||||
logger.info("用户 user_id[%s] 撤销一次公共Cookies计数", user_id)
|
|
||||||
|
@ -1,23 +1,3 @@
|
|||||||
from typing import Optional
|
from gram_core.services.devices.models import Devices, DevicesDataBase
|
||||||
|
|
||||||
from sqlmodel import SQLModel, Field, Column, Integer, BigInteger
|
|
||||||
|
|
||||||
__all__ = ("Devices", "DevicesDataBase")
|
__all__ = ("Devices", "DevicesDataBase")
|
||||||
|
|
||||||
|
|
||||||
class Devices(SQLModel):
|
|
||||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
|
||||||
id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
|
|
||||||
account_id: int = Field(
|
|
||||||
default=None,
|
|
||||||
sa_column=Column(
|
|
||||||
BigInteger(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
device_id: str = Field()
|
|
||||||
device_fp: str = Field()
|
|
||||||
device_name: Optional[str] = Field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
class DevicesDataBase(Devices, table=True):
|
|
||||||
__tablename__ = "devices"
|
|
||||||
|
@ -1,41 +1,3 @@
|
|||||||
from typing import Optional
|
from gram_core.services.devices.repositories import DevicesRepository
|
||||||
|
|
||||||
from sqlmodel import select
|
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.dependence.database import Database
|
|
||||||
from core.services.devices.models import DevicesDataBase as Devices
|
|
||||||
from core.sqlmodel.session import AsyncSession
|
|
||||||
|
|
||||||
__all__ = ("DevicesRepository",)
|
__all__ = ("DevicesRepository",)
|
||||||
|
|
||||||
|
|
||||||
class DevicesRepository(BaseService.Component):
|
|
||||||
def __init__(self, database: Database):
|
|
||||||
self.engine = database.engine
|
|
||||||
|
|
||||||
async def get(
|
|
||||||
self,
|
|
||||||
account_id: int,
|
|
||||||
) -> Optional[Devices]:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
statement = select(Devices).where(Devices.account_id == account_id)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
return results.first()
|
|
||||||
|
|
||||||
async def add(self, devices: Devices) -> None:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(devices)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def update(self, devices: Devices) -> Devices:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(devices)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(devices)
|
|
||||||
return devices
|
|
||||||
|
|
||||||
async def delete(self, devices: Devices) -> None:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
await session.delete(devices)
|
|
||||||
await session.commit()
|
|
||||||
|
@ -1,25 +1,3 @@
|
|||||||
from typing import Optional
|
from gram_core.services.devices.services import DevicesService
|
||||||
|
|
||||||
from core.base_service import BaseService
|
__all__ = ("DevicesService",)
|
||||||
from core.services.devices.repositories import DevicesRepository
|
|
||||||
from core.services.devices.models import DevicesDataBase as Devices
|
|
||||||
|
|
||||||
|
|
||||||
class DevicesService(BaseService):
|
|
||||||
def __init__(self, devices_repository: DevicesRepository) -> None:
|
|
||||||
self._repository: DevicesRepository = devices_repository
|
|
||||||
|
|
||||||
async def update(self, devices: Devices):
|
|
||||||
await self._repository.update(devices)
|
|
||||||
|
|
||||||
async def add(self, devices: Devices):
|
|
||||||
await self._repository.add(devices)
|
|
||||||
|
|
||||||
async def get(
|
|
||||||
self,
|
|
||||||
account_id: int,
|
|
||||||
) -> Optional[Devices]:
|
|
||||||
return await self._repository.get(account_id)
|
|
||||||
|
|
||||||
async def delete(self, devices: Devices) -> None:
|
|
||||||
return await self._repository.delete(devices)
|
|
||||||
|
@ -1,2 +1,3 @@
|
|||||||
class PlayerNotFoundError(Exception):
|
from gram_core.services.players.error import PlayerNotFoundError
|
||||||
pass
|
|
||||||
|
__all__ = ("PlayerNotFoundError",)
|
||||||
|
@ -1,96 +1,3 @@
|
|||||||
from datetime import datetime
|
from gram_core.services.players.models import Player, PlayersDataBase, PlayerInfo, PlayerInfoSQLModel
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, BaseSettings
|
|
||||||
from sqlalchemy import TypeDecorator
|
|
||||||
from sqlmodel import Boolean, Column, Enum, Field, SQLModel, Integer, Index, BigInteger, VARCHAR, func, DateTime
|
|
||||||
|
|
||||||
from core.basemodel import RegionEnum
|
|
||||||
|
|
||||||
try:
|
|
||||||
import ujson as jsonlib
|
|
||||||
except ImportError:
|
|
||||||
import json as jsonlib
|
|
||||||
|
|
||||||
__all__ = ("Player", "PlayersDataBase", "PlayerInfo", "PlayerInfoSQLModel")
|
__all__ = ("Player", "PlayersDataBase", "PlayerInfo", "PlayerInfoSQLModel")
|
||||||
|
|
||||||
|
|
||||||
class Player(SQLModel):
|
|
||||||
__table_args__ = (
|
|
||||||
Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True),
|
|
||||||
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
|
||||||
)
|
|
||||||
id: Optional[int] = Field(
|
|
||||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
|
||||||
)
|
|
||||||
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
|
||||||
account_id: int = Field(default=None, primary_key=True, sa_column=Column(BigInteger()))
|
|
||||||
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
|
||||||
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
|
|
||||||
is_chosen: Optional[bool] = Field(sa_column=Column(Boolean))
|
|
||||||
|
|
||||||
|
|
||||||
class PlayersDataBase(Player, table=True):
|
|
||||||
__tablename__ = "players"
|
|
||||||
|
|
||||||
|
|
||||||
class ExtraPlayerInfo(BaseModel):
|
|
||||||
class Config(BaseSettings.Config):
|
|
||||||
json_loads = jsonlib.loads
|
|
||||||
json_dumps = jsonlib.dumps
|
|
||||||
|
|
||||||
waifu_id: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ExtraPlayerType(TypeDecorator): # pylint: disable=W0223
|
|
||||||
impl = VARCHAR(length=521)
|
|
||||||
|
|
||||||
cache_ok = True
|
|
||||||
|
|
||||||
def process_bind_param(self, value, dialect):
|
|
||||||
"""
|
|
||||||
:param value: ExtraPlayerInfo | obj | None
|
|
||||||
:param dialect:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if value is not None:
|
|
||||||
if isinstance(value, ExtraPlayerInfo):
|
|
||||||
return value.json()
|
|
||||||
raise TypeError
|
|
||||||
return value
|
|
||||||
|
|
||||||
def process_result_value(self, value, dialect):
|
|
||||||
"""
|
|
||||||
:param value: str | obj | None
|
|
||||||
:param dialect:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if value is not None:
|
|
||||||
return ExtraPlayerInfo.parse_raw(value)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class PlayerInfo(SQLModel):
|
|
||||||
__table_args__ = (
|
|
||||||
Index("index_user_account_player", "user_id", "player_id", unique=True),
|
|
||||||
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
|
||||||
)
|
|
||||||
id: Optional[int] = Field(
|
|
||||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
|
||||||
)
|
|
||||||
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
|
||||||
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
|
||||||
nickname: Optional[str] = Field()
|
|
||||||
signature: Optional[str] = Field()
|
|
||||||
hand_image: Optional[int] = Field()
|
|
||||||
name_card: Optional[int] = Field()
|
|
||||||
extra_data: Optional[ExtraPlayerInfo] = Field(sa_column=Column(ExtraPlayerType))
|
|
||||||
create_time: Optional[datetime] = Field(
|
|
||||||
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
|
|
||||||
)
|
|
||||||
last_save_time: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
|
|
||||||
is_update: Optional[bool] = Field(sa_column=Column(Boolean))
|
|
||||||
|
|
||||||
|
|
||||||
class PlayerInfoSQLModel(PlayerInfo, table=True):
|
|
||||||
__tablename__ = "players_info"
|
|
||||||
|
@ -1,110 +1,3 @@
|
|||||||
from typing import List, Optional
|
from gram_core.services.players.repositories import PlayersRepository, PlayerInfoRepository
|
||||||
|
|
||||||
from sqlmodel import select, delete
|
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.basemodel import RegionEnum
|
|
||||||
from core.dependence.database import Database
|
|
||||||
from core.services.players.models import PlayerInfoSQLModel
|
|
||||||
from core.services.players.models import PlayersDataBase as Player
|
|
||||||
from core.sqlmodel.session import AsyncSession
|
|
||||||
|
|
||||||
__all__ = ("PlayersRepository", "PlayerInfoRepository")
|
__all__ = ("PlayersRepository", "PlayerInfoRepository")
|
||||||
|
|
||||||
|
|
||||||
class PlayersRepository(BaseService.Component):
|
|
||||||
def __init__(self, database: Database):
|
|
||||||
self.engine = database.engine
|
|
||||||
|
|
||||||
async def get(
|
|
||||||
self,
|
|
||||||
user_id: int,
|
|
||||||
player_id: Optional[int] = None,
|
|
||||||
account_id: Optional[int] = None,
|
|
||||||
region: Optional[RegionEnum] = None,
|
|
||||||
is_chosen: Optional[bool] = None,
|
|
||||||
) -> Optional[Player]:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
statement = select(Player).where(Player.user_id == user_id)
|
|
||||||
if player_id is not None:
|
|
||||||
statement = statement.where(Player.player_id == player_id)
|
|
||||||
if account_id is not None:
|
|
||||||
statement = statement.where(Player.account_id == account_id)
|
|
||||||
if region is not None:
|
|
||||||
statement = statement.where(Player.region == region)
|
|
||||||
if is_chosen is not None:
|
|
||||||
statement = statement.where(Player.is_chosen == is_chosen)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
return results.first()
|
|
||||||
|
|
||||||
async def add(self, player: Player) -> None:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(player)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(player)
|
|
||||||
|
|
||||||
async def delete(self, player: Player) -> None:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
await session.delete(player)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def update(self, player: Player) -> None:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(player)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(player)
|
|
||||||
|
|
||||||
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
statement = select(Player).where(Player.user_id == user_id)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
players = results.all()
|
|
||||||
return players
|
|
||||||
|
|
||||||
|
|
||||||
class PlayerInfoRepository(BaseService.Component):
|
|
||||||
def __init__(self, database: Database):
|
|
||||||
self.engine = database.engine
|
|
||||||
|
|
||||||
async def get(
|
|
||||||
self,
|
|
||||||
user_id: int,
|
|
||||||
player_id: int,
|
|
||||||
) -> Optional[PlayerInfoSQLModel]:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
statement = (
|
|
||||||
select(PlayerInfoSQLModel)
|
|
||||||
.where(PlayerInfoSQLModel.player_id == player_id)
|
|
||||||
.where(PlayerInfoSQLModel.user_id == user_id)
|
|
||||||
)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
return results.first()
|
|
||||||
|
|
||||||
async def add(self, player: PlayerInfoSQLModel) -> None:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(player)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def delete(self, player: PlayerInfoSQLModel) -> None:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
await session.delete(player)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def delete_by_id(
|
|
||||||
self,
|
|
||||||
user_id: int,
|
|
||||||
player_id: int,
|
|
||||||
) -> None:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
statement = (
|
|
||||||
delete(PlayerInfoSQLModel)
|
|
||||||
.where(PlayerInfoSQLModel.player_id == player_id)
|
|
||||||
.where(PlayerInfoSQLModel.user_id == user_id)
|
|
||||||
)
|
|
||||||
await session.execute(statement)
|
|
||||||
|
|
||||||
async def update(self, player: PlayerInfoSQLModel) -> None:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(player)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(player)
|
|
||||||
|
@ -1,52 +1,18 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.base_service import BaseService
|
from core.base_service import BaseService
|
||||||
from core.basemodel import RegionEnum
|
|
||||||
from core.dependence.redisdb import RedisDB
|
from core.dependence.redisdb import RedisDB
|
||||||
from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel, PlayerInfo
|
from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel, PlayerInfo
|
||||||
from core.services.players.repositories import PlayersRepository, PlayerInfoRepository
|
from core.services.players.repositories import PlayerInfoRepository
|
||||||
from modules.apihelper.client.components.player_cards import PlayerCards, PlayerBaseInfo
|
from modules.apihelper.client.components.player_cards import PlayerCards, PlayerBaseInfo
|
||||||
from utils.log import logger
|
from utils.log import logger
|
||||||
|
|
||||||
|
from gram_core.services.players import PlayersService
|
||||||
|
|
||||||
__all__ = ("PlayersService", "PlayerInfoService")
|
__all__ = ("PlayersService", "PlayerInfoService")
|
||||||
|
|
||||||
|
|
||||||
class PlayersService(BaseService):
|
|
||||||
def __init__(self, players_repository: PlayersRepository) -> None:
|
|
||||||
self._repository = players_repository
|
|
||||||
|
|
||||||
async def get(
|
|
||||||
self,
|
|
||||||
user_id: int,
|
|
||||||
player_id: Optional[int] = None,
|
|
||||||
account_id: Optional[int] = None,
|
|
||||||
region: Optional[RegionEnum] = None,
|
|
||||||
is_chosen: Optional[bool] = None,
|
|
||||||
) -> Optional[Player]:
|
|
||||||
return await self._repository.get(user_id, player_id, account_id, region, is_chosen)
|
|
||||||
|
|
||||||
async def get_player(self, user_id: int, region: Optional[RegionEnum] = None) -> Optional[Player]:
|
|
||||||
return await self._repository.get(user_id, region=region, is_chosen=True)
|
|
||||||
|
|
||||||
async def add(self, player: Player) -> None:
|
|
||||||
await self._repository.add(player)
|
|
||||||
|
|
||||||
async def update(self, player: Player) -> None:
|
|
||||||
await self._repository.update(player)
|
|
||||||
|
|
||||||
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
|
|
||||||
return await self._repository.get_all_by_user_id(user_id)
|
|
||||||
|
|
||||||
async def remove_all_by_user_id(self, user_id: int):
|
|
||||||
players = await self._repository.get_all_by_user_id(user_id)
|
|
||||||
for player in players:
|
|
||||||
await self._repository.delete(player)
|
|
||||||
|
|
||||||
async def delete(self, player: Player):
|
|
||||||
await self._repository.delete(player)
|
|
||||||
|
|
||||||
|
|
||||||
class PlayerInfoService(BaseService):
|
class PlayerInfoService(BaseService):
|
||||||
def __init__(self, redis: RedisDB, players_info_repository: PlayerInfoRepository):
|
def __init__(self, redis: RedisDB, players_info_repository: PlayerInfoRepository):
|
||||||
self.cache = redis.client
|
self.cache = redis.client
|
||||||
|
@ -1,44 +1,3 @@
|
|||||||
import enum
|
from gram_core.services.task.models import Task, TaskStatusEnum, TaskTypeEnum
|
||||||
from datetime import datetime
|
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
|
|
||||||
from sqlalchemy import func, BigInteger, JSON
|
|
||||||
from sqlmodel import Column, DateTime, Enum, Field, SQLModel, Integer
|
|
||||||
|
|
||||||
__all__ = ("Task", "TaskStatusEnum", "TaskTypeEnum")
|
__all__ = ("Task", "TaskStatusEnum", "TaskTypeEnum")
|
||||||
|
|
||||||
|
|
||||||
class TaskStatusEnum(int, enum.Enum):
|
|
||||||
STATUS_SUCCESS = 0 # 任务执行成功
|
|
||||||
INVALID_COOKIES = 1 # Cookie无效
|
|
||||||
ALREADY_CLAIMED = 2 # 已经获取奖励
|
|
||||||
NEED_CHALLENGE = 3 # 需要验证码
|
|
||||||
GENSHIN_EXCEPTION = 4 # API异常
|
|
||||||
TIMEOUT_ERROR = 5 # 请求超时
|
|
||||||
BAD_REQUEST = 6 # 请求失败
|
|
||||||
FORBIDDEN = 7 # 这错误一般为通知失败 机器人被用户BAN
|
|
||||||
|
|
||||||
|
|
||||||
class TaskTypeEnum(int, enum.Enum):
|
|
||||||
SIGN = 0 # 签到
|
|
||||||
RESIN = 1 # 开拓力
|
|
||||||
REALM = 2 # 洞天宝钱
|
|
||||||
EXPEDITION = 3 # 委托
|
|
||||||
TRANSFORMER = 4 # 参量质变仪
|
|
||||||
CARD = 5 # 生日画片
|
|
||||||
|
|
||||||
|
|
||||||
class Task(SQLModel, table=True):
|
|
||||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
|
||||||
id: Optional[int] = Field(
|
|
||||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
|
||||||
)
|
|
||||||
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger(), index=True))
|
|
||||||
chat_id: Optional[int] = Field(default=None, sa_column=Column(BigInteger()))
|
|
||||||
time_created: Optional[datetime] = Field(
|
|
||||||
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
|
|
||||||
)
|
|
||||||
time_updated: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
|
|
||||||
type: TaskTypeEnum = Field(primary_key=True, sa_column=Column(Enum(TaskTypeEnum)))
|
|
||||||
status: Optional[TaskStatusEnum] = Field(sa_column=Column(Enum(TaskStatusEnum)))
|
|
||||||
data: Optional[Dict[str, Any]] = Field(sa_column=Column(JSON))
|
|
||||||
|
@ -1,50 +1,3 @@
|
|||||||
from typing import List, Optional
|
from gram_core.services.task.repositories import TaskRepository
|
||||||
|
|
||||||
from sqlmodel import select
|
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.dependence.database import Database
|
|
||||||
from core.services.task.models import Task, TaskTypeEnum
|
|
||||||
from core.sqlmodel.session import AsyncSession
|
|
||||||
|
|
||||||
__all__ = ("TaskRepository",)
|
__all__ = ("TaskRepository",)
|
||||||
|
|
||||||
|
|
||||||
class TaskRepository(BaseService.Component):
|
|
||||||
def __init__(self, database: Database):
|
|
||||||
self.engine = database.engine
|
|
||||||
|
|
||||||
async def add(self, task: Task):
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(task)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def remove(self, task: Task):
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
await session.delete(task)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def update(self, task: Task) -> Task:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(task)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(task)
|
|
||||||
return task
|
|
||||||
|
|
||||||
async def get_by_user_id(self, user_id: int, task_type: TaskTypeEnum) -> Optional[Task]:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
statement = select(Task).where(Task.user_id == user_id).where(Task.type == task_type)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
return results.first()
|
|
||||||
|
|
||||||
async def get_by_chat_id(self, chat_id: int, task_type: TaskTypeEnum) -> Optional[List[Task]]:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
statement = select(Task).where(Task.chat_id == chat_id).where(Task.type == task_type)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
return results.all()
|
|
||||||
|
|
||||||
async def get_all(self, task_type: TaskTypeEnum) -> List[Task]:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
query = select(Task).where(Task.type == task_type)
|
|
||||||
results = await session.exec(query)
|
|
||||||
return results.all()
|
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import datetime
|
from gram_core.services.task.services import (
|
||||||
from typing import Optional, Dict, Any
|
TaskServices,
|
||||||
|
SignServices,
|
||||||
from core.base_service import BaseService
|
TaskCardServices,
|
||||||
from core.services.task.models import Task, TaskTypeEnum
|
TaskResinServices,
|
||||||
from core.services.task.repositories import TaskRepository
|
TaskExpeditionServices,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TaskServices",
|
"TaskServices",
|
||||||
@ -12,168 +13,3 @@ __all__ = [
|
|||||||
"TaskResinServices",
|
"TaskResinServices",
|
||||||
"TaskExpeditionServices",
|
"TaskExpeditionServices",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TaskServices(BaseService):
|
|
||||||
TASK_TYPE: TaskTypeEnum
|
|
||||||
|
|
||||||
def __init__(self, task_repository: TaskRepository) -> None:
|
|
||||||
self._repository: TaskRepository = task_repository
|
|
||||||
|
|
||||||
async def add(self, task: Task):
|
|
||||||
return await self._repository.add(task)
|
|
||||||
|
|
||||||
async def remove(self, task: Task):
|
|
||||||
return await self._repository.remove(task)
|
|
||||||
|
|
||||||
async def update(self, task: Task):
|
|
||||||
task.time_updated = datetime.datetime.now()
|
|
||||||
return await self._repository.update(task)
|
|
||||||
|
|
||||||
async def get_by_user_id(self, user_id: int):
|
|
||||||
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
|
|
||||||
|
|
||||||
async def get_all(self):
|
|
||||||
return await self._repository.get_all(self.TASK_TYPE)
|
|
||||||
|
|
||||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
|
||||||
return Task(
|
|
||||||
user_id=user_id,
|
|
||||||
chat_id=chat_id,
|
|
||||||
time_created=datetime.datetime.now(),
|
|
||||||
status=status,
|
|
||||||
type=self.TASK_TYPE,
|
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SignServices(BaseService):
|
|
||||||
TASK_TYPE = TaskTypeEnum.SIGN
|
|
||||||
|
|
||||||
def __init__(self, task_repository: TaskRepository) -> None:
|
|
||||||
self._repository: TaskRepository = task_repository
|
|
||||||
|
|
||||||
async def add(self, task: Task):
|
|
||||||
return await self._repository.add(task)
|
|
||||||
|
|
||||||
async def remove(self, task: Task):
|
|
||||||
return await self._repository.remove(task)
|
|
||||||
|
|
||||||
async def update(self, task: Task):
|
|
||||||
task.time_updated = datetime.datetime.now()
|
|
||||||
return await self._repository.update(task)
|
|
||||||
|
|
||||||
async def get_by_user_id(self, user_id: int):
|
|
||||||
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
|
|
||||||
|
|
||||||
async def get_all(self):
|
|
||||||
return await self._repository.get_all(self.TASK_TYPE)
|
|
||||||
|
|
||||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
|
||||||
return Task(
|
|
||||||
user_id=user_id,
|
|
||||||
chat_id=chat_id,
|
|
||||||
time_created=datetime.datetime.now(),
|
|
||||||
status=status,
|
|
||||||
type=self.TASK_TYPE,
|
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TaskCardServices(BaseService):
|
|
||||||
TASK_TYPE = TaskTypeEnum.CARD
|
|
||||||
|
|
||||||
def __init__(self, task_repository: TaskRepository) -> None:
|
|
||||||
self._repository: TaskRepository = task_repository
|
|
||||||
|
|
||||||
async def add(self, task: Task):
|
|
||||||
return await self._repository.add(task)
|
|
||||||
|
|
||||||
async def remove(self, task: Task):
|
|
||||||
return await self._repository.remove(task)
|
|
||||||
|
|
||||||
async def update(self, task: Task):
|
|
||||||
task.time_updated = datetime.datetime.now()
|
|
||||||
return await self._repository.update(task)
|
|
||||||
|
|
||||||
async def get_by_user_id(self, user_id: int):
|
|
||||||
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
|
|
||||||
|
|
||||||
async def get_all(self):
|
|
||||||
return await self._repository.get_all(self.TASK_TYPE)
|
|
||||||
|
|
||||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
|
||||||
return Task(
|
|
||||||
user_id=user_id,
|
|
||||||
chat_id=chat_id,
|
|
||||||
time_created=datetime.datetime.now(),
|
|
||||||
status=status,
|
|
||||||
type=self.TASK_TYPE,
|
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TaskResinServices(BaseService):
|
|
||||||
TASK_TYPE = TaskTypeEnum.RESIN
|
|
||||||
|
|
||||||
def __init__(self, task_repository: TaskRepository) -> None:
|
|
||||||
self._repository: TaskRepository = task_repository
|
|
||||||
|
|
||||||
async def add(self, task: Task):
|
|
||||||
return await self._repository.add(task)
|
|
||||||
|
|
||||||
async def remove(self, task: Task):
|
|
||||||
return await self._repository.remove(task)
|
|
||||||
|
|
||||||
async def update(self, task: Task):
|
|
||||||
task.time_updated = datetime.datetime.now()
|
|
||||||
return await self._repository.update(task)
|
|
||||||
|
|
||||||
async def get_by_user_id(self, user_id: int):
|
|
||||||
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
|
|
||||||
|
|
||||||
async def get_all(self):
|
|
||||||
return await self._repository.get_all(self.TASK_TYPE)
|
|
||||||
|
|
||||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
|
||||||
return Task(
|
|
||||||
user_id=user_id,
|
|
||||||
chat_id=chat_id,
|
|
||||||
time_created=datetime.datetime.now(),
|
|
||||||
status=status,
|
|
||||||
type=self.TASK_TYPE,
|
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TaskExpeditionServices(BaseService):
|
|
||||||
TASK_TYPE = TaskTypeEnum.EXPEDITION
|
|
||||||
|
|
||||||
def __init__(self, task_repository: TaskRepository) -> None:
|
|
||||||
self._repository: TaskRepository = task_repository
|
|
||||||
|
|
||||||
async def add(self, task: Task):
|
|
||||||
return await self._repository.add(task)
|
|
||||||
|
|
||||||
async def remove(self, task: Task):
|
|
||||||
return await self._repository.remove(task)
|
|
||||||
|
|
||||||
async def update(self, task: Task):
|
|
||||||
task.time_updated = datetime.datetime.now()
|
|
||||||
return await self._repository.update(task)
|
|
||||||
|
|
||||||
async def get_by_user_id(self, user_id: int):
|
|
||||||
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
|
|
||||||
|
|
||||||
async def get_all(self):
|
|
||||||
return await self._repository.get_all(self.TASK_TYPE)
|
|
||||||
|
|
||||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
|
||||||
return Task(
|
|
||||||
user_id=user_id,
|
|
||||||
chat_id=chat_id,
|
|
||||||
time_created=datetime.datetime.now(),
|
|
||||||
status=status,
|
|
||||||
type=self.TASK_TYPE,
|
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
|
@ -1,58 +1,3 @@
|
|||||||
import gzip
|
from gram_core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache
|
||||||
import pickle # nosec B403
|
|
||||||
from hashlib import sha256
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.dependence.redisdb import RedisDB
|
|
||||||
|
|
||||||
__all__ = ["TemplatePreviewCache", "HtmlToFileIdCache"]
|
__all__ = ["TemplatePreviewCache", "HtmlToFileIdCache"]
|
||||||
|
|
||||||
|
|
||||||
class TemplatePreviewCache(BaseService.Component):
|
|
||||||
"""暂存渲染模板的数据用于预览"""
|
|
||||||
|
|
||||||
def __init__(self, redis: RedisDB):
|
|
||||||
self.client = redis.client
|
|
||||||
self.qname = "bot:template:preview"
|
|
||||||
|
|
||||||
async def get_data(self, key: str) -> Any:
|
|
||||||
data = await self.client.get(self.cache_key(key))
|
|
||||||
if data:
|
|
||||||
# skipcq: BAN-B301
|
|
||||||
return pickle.loads(gzip.decompress(data)) # nosec B301
|
|
||||||
|
|
||||||
async def set_data(self, key: str, data: Any, ttl: int = 8 * 60 * 60):
|
|
||||||
ck = self.cache_key(key)
|
|
||||||
await self.client.set(ck, gzip.compress(pickle.dumps(data)))
|
|
||||||
if ttl != -1:
|
|
||||||
await self.client.expire(ck, ttl)
|
|
||||||
|
|
||||||
def cache_key(self, key: str) -> str:
|
|
||||||
return f"{self.qname}:{key}"
|
|
||||||
|
|
||||||
|
|
||||||
class HtmlToFileIdCache(BaseService.Component):
|
|
||||||
"""html to file_id 的缓存"""
|
|
||||||
|
|
||||||
def __init__(self, redis: RedisDB):
|
|
||||||
self.client = redis.client
|
|
||||||
self.qname = "bot:template:html-to-file-id"
|
|
||||||
|
|
||||||
async def get_data(self, html: str, file_type: str) -> Optional[str]:
|
|
||||||
data = await self.client.get(self.cache_key(html, file_type))
|
|
||||||
if data:
|
|
||||||
return data.decode()
|
|
||||||
|
|
||||||
async def set_data(self, html: str, file_type: str, file_id: str, ttl: int = 24 * 60 * 60):
|
|
||||||
ck = self.cache_key(html, file_type)
|
|
||||||
await self.client.set(ck, file_id)
|
|
||||||
if ttl != -1:
|
|
||||||
await self.client.expire(ck, ttl)
|
|
||||||
|
|
||||||
async def delete_data(self, html: str, file_type: str) -> bool:
|
|
||||||
return await self.client.delete(self.cache_key(html, file_type))
|
|
||||||
|
|
||||||
def cache_key(self, html: str, file_type: str) -> str:
|
|
||||||
key = sha256(html.encode()).hexdigest()
|
|
||||||
return f"{self.qname}:{file_type}:{key}"
|
|
||||||
|
@ -1,14 +1,8 @@
|
|||||||
class TemplateException(Exception):
|
from gram_core.services.template.error import (
|
||||||
pass
|
ErrorFileType,
|
||||||
|
FileIdNotFound,
|
||||||
|
QuerySelectorNotFound,
|
||||||
|
TemplateException,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ("TemplateException", "QuerySelectorNotFound", "ErrorFileType", "FileIdNotFound")
|
||||||
class QuerySelectorNotFound(TemplateException):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorFileType(TemplateException):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class FileIdNotFound(TemplateException):
|
|
||||||
pass
|
|
||||||
|
@ -1,146 +1,3 @@
|
|||||||
from enum import Enum
|
from gram_core.services.template.models import FileType, RenderResult, RenderGroupResult
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
from telegram import InputMediaDocument, InputMediaPhoto, Message
|
|
||||||
from telegram.error import BadRequest
|
|
||||||
|
|
||||||
from core.services.template.cache import HtmlToFileIdCache
|
|
||||||
from core.services.template.error import ErrorFileType, FileIdNotFound
|
|
||||||
|
|
||||||
__all__ = ["FileType", "RenderResult", "RenderGroupResult"]
|
__all__ = ["FileType", "RenderResult", "RenderGroupResult"]
|
||||||
|
|
||||||
|
|
||||||
class FileType(Enum):
|
|
||||||
PHOTO = 1
|
|
||||||
DOCUMENT = 2
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def media_type(file_type: "FileType"):
|
|
||||||
"""对应的 Telegram media 类型"""
|
|
||||||
if file_type == FileType.PHOTO:
|
|
||||||
return InputMediaPhoto
|
|
||||||
if file_type == FileType.DOCUMENT:
|
|
||||||
return InputMediaDocument
|
|
||||||
raise ErrorFileType
|
|
||||||
|
|
||||||
|
|
||||||
class RenderResult:
|
|
||||||
"""渲染结果"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
html: str,
|
|
||||||
photo: Union[bytes, str],
|
|
||||||
file_type: FileType,
|
|
||||||
cache: HtmlToFileIdCache,
|
|
||||||
ttl: int = 24 * 60 * 60,
|
|
||||||
caption: Optional[str] = None,
|
|
||||||
parse_mode: Optional[str] = None,
|
|
||||||
filename: Optional[str] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
`html`: str 渲染生成的 html
|
|
||||||
`photo`: Union[bytes, str] 渲染生成的图片。bytes 表示是图片,str 则为 file_id
|
|
||||||
"""
|
|
||||||
self.caption = caption
|
|
||||||
self.parse_mode = parse_mode
|
|
||||||
self.filename = filename
|
|
||||||
self.html = html
|
|
||||||
self.photo = photo
|
|
||||||
self.file_type = file_type
|
|
||||||
self._cache = cache
|
|
||||||
self.ttl = ttl
|
|
||||||
|
|
||||||
async def reply_photo(self, message: Message, *args, **kwargs):
|
|
||||||
"""是 `message.reply_photo` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用"""
|
|
||||||
if self.file_type != FileType.PHOTO:
|
|
||||||
raise ErrorFileType
|
|
||||||
|
|
||||||
try:
|
|
||||||
reply = await message.reply_photo(photo=self.photo, *args, **kwargs)
|
|
||||||
except BadRequest as exc:
|
|
||||||
if "Wrong file identifier" in exc.message and isinstance(self.photo, str):
|
|
||||||
await self._cache.delete_data(self.html, self.file_type.name)
|
|
||||||
raise BadRequest(message="Wrong file identifier specified")
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
await self.cache_file_id(reply)
|
|
||||||
|
|
||||||
return reply
|
|
||||||
|
|
||||||
async def reply_document(self, message: Message, *args, **kwargs):
|
|
||||||
"""是 `message.reply_document` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用"""
|
|
||||||
if self.file_type != FileType.DOCUMENT:
|
|
||||||
raise ErrorFileType
|
|
||||||
|
|
||||||
try:
|
|
||||||
reply = await message.reply_document(document=self.photo, *args, **kwargs)
|
|
||||||
except BadRequest as exc:
|
|
||||||
if "Wrong file identifier" in exc.message and isinstance(self.photo, str):
|
|
||||||
await self._cache.delete_data(self.html, self.file_type.name)
|
|
||||||
raise BadRequest(message="Wrong file identifier specified")
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
await self.cache_file_id(reply)
|
|
||||||
|
|
||||||
return reply
|
|
||||||
|
|
||||||
async def edit_media(self, message: Message, *args, **kwargs):
|
|
||||||
"""是 `message.edit_media` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用"""
|
|
||||||
if self.file_type != FileType.PHOTO:
|
|
||||||
raise ErrorFileType
|
|
||||||
|
|
||||||
media = InputMediaPhoto(
|
|
||||||
media=self.photo, caption=self.caption, parse_mode=self.parse_mode, filename=self.filename
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
edit_media = await message.edit_media(media, *args, **kwargs)
|
|
||||||
except BadRequest as exc:
|
|
||||||
if "Wrong file identifier" in exc.message and isinstance(self.photo, str):
|
|
||||||
await self._cache.delete_data(self.html, self.file_type.name)
|
|
||||||
raise BadRequest(message="Wrong file identifier specified")
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
await self.cache_file_id(edit_media)
|
|
||||||
|
|
||||||
return edit_media
|
|
||||||
|
|
||||||
async def cache_file_id(self, reply: Message):
|
|
||||||
"""缓存 telegram 返回的 file_id"""
|
|
||||||
if self.is_file_id():
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.file_type == FileType.PHOTO and reply.photo:
|
|
||||||
file_id = reply.photo[0].file_id
|
|
||||||
elif self.file_type == FileType.DOCUMENT and reply.document:
|
|
||||||
file_id = reply.document.file_id
|
|
||||||
else:
|
|
||||||
raise FileIdNotFound
|
|
||||||
await self._cache.set_data(self.html, self.file_type.name, file_id, self.ttl)
|
|
||||||
|
|
||||||
def is_file_id(self) -> bool:
|
|
||||||
return isinstance(self.photo, str)
|
|
||||||
|
|
||||||
|
|
||||||
class RenderGroupResult:
|
|
||||||
def __init__(self, results: List[RenderResult]):
|
|
||||||
self.results = results
|
|
||||||
|
|
||||||
async def reply_media_group(self, message: Message, *args, **kwargs):
|
|
||||||
"""是 `message.reply_media_group` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用"""
|
|
||||||
|
|
||||||
reply = await message.reply_media_group(
|
|
||||||
media=[
|
|
||||||
FileType.media_type(result.file_type)(
|
|
||||||
media=result.photo, caption=result.caption, parse_mode=result.parse_mode, filename=result.filename
|
|
||||||
)
|
|
||||||
for result in self.results
|
|
||||||
],
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
for index, value in enumerate(reply):
|
|
||||||
result = self.results[index]
|
|
||||||
await result.cache_file_id(value)
|
|
||||||
|
@ -1,207 +1,3 @@
|
|||||||
import asyncio
|
from gram_core.services.template.services import TemplateService, TemplatePreviewer
|
||||||
from typing import Optional
|
|
||||||
from urllib.parse import urlencode, urljoin, urlsplit
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
|
||||||
from fastapi.responses import FileResponse, HTMLResponse
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
|
||||||
from jinja2 import Environment, FileSystemLoader, Template
|
|
||||||
from playwright.async_api import ViewportSize
|
|
||||||
|
|
||||||
from core.application import Application
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.config import config as application_config
|
|
||||||
from core.dependence.aiobrowser import AioBrowser
|
|
||||||
from core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache
|
|
||||||
from core.services.template.error import QuerySelectorNotFound
|
|
||||||
from core.services.template.models import FileType, RenderResult
|
|
||||||
from utils.const import PROJECT_ROOT
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
__all__ = ("TemplateService", "TemplatePreviewer")
|
__all__ = ("TemplateService", "TemplatePreviewer")
|
||||||
|
|
||||||
|
|
||||||
class TemplateService(BaseService):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
app: Application,
|
|
||||||
browser: AioBrowser,
|
|
||||||
html_to_file_id_cache: HtmlToFileIdCache,
|
|
||||||
preview_cache: TemplatePreviewCache,
|
|
||||||
template_dir: str = "resources",
|
|
||||||
):
|
|
||||||
self._browser = browser
|
|
||||||
self.template_dir = PROJECT_ROOT / template_dir
|
|
||||||
|
|
||||||
self._jinja2_env = Environment(
|
|
||||||
loader=FileSystemLoader(template_dir),
|
|
||||||
enable_async=True,
|
|
||||||
autoescape=True,
|
|
||||||
auto_reload=application_config.debug,
|
|
||||||
)
|
|
||||||
self.using_preview = application_config.debug and application_config.webserver.enable
|
|
||||||
|
|
||||||
if self.using_preview:
|
|
||||||
self.previewer = TemplatePreviewer(self, preview_cache, app.web_app)
|
|
||||||
|
|
||||||
self.html_to_file_id_cache = html_to_file_id_cache
|
|
||||||
|
|
||||||
def get_template(self, template_name: str) -> Template:
|
|
||||||
return self._jinja2_env.get_template(template_name)
|
|
||||||
|
|
||||||
async def render_async(self, template_name: str, template_data: dict) -> str:
|
|
||||||
"""模板渲染
|
|
||||||
:param template_name: 模板文件名
|
|
||||||
:param template_data: 模板数据
|
|
||||||
"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
start_time = loop.time()
|
|
||||||
template = self.get_template(template_name)
|
|
||||||
html = await template.render_async(**template_data)
|
|
||||||
logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
|
|
||||||
return html
|
|
||||||
|
|
||||||
async def render(
|
|
||||||
self,
|
|
||||||
template_name: str,
|
|
||||||
template_data: dict,
|
|
||||||
viewport: Optional[ViewportSize] = None,
|
|
||||||
full_page: bool = True,
|
|
||||||
evaluate: Optional[str] = None,
|
|
||||||
query_selector: Optional[str] = None,
|
|
||||||
file_type: FileType = FileType.PHOTO,
|
|
||||||
ttl: int = 24 * 60 * 60,
|
|
||||||
caption: Optional[str] = None,
|
|
||||||
parse_mode: Optional[str] = None,
|
|
||||||
filename: Optional[str] = None,
|
|
||||||
) -> RenderResult:
|
|
||||||
"""模板渲染成图片
|
|
||||||
:param template_name: 模板文件名
|
|
||||||
:param template_data: 模板数据
|
|
||||||
:param viewport: 截图大小
|
|
||||||
:param full_page: 是否长截图
|
|
||||||
:param evaluate: 页面加载后运行的 js
|
|
||||||
:param query_selector: 截图选择器
|
|
||||||
:param file_type: 缓存的文件类型
|
|
||||||
:param ttl: 缓存时间
|
|
||||||
:param caption: 图片描述
|
|
||||||
:param parse_mode: 图片描述解析模式
|
|
||||||
:param filename: 文件名字
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
start_time = loop.time()
|
|
||||||
template = self.get_template(template_name)
|
|
||||||
|
|
||||||
if self.using_preview:
|
|
||||||
preview_url = await self.previewer.get_preview_url(template_name, template_data)
|
|
||||||
logger.debug("调试模板 URL: \n%s", preview_url)
|
|
||||||
|
|
||||||
html = await template.render_async(**template_data)
|
|
||||||
logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
|
|
||||||
|
|
||||||
file_id = await self.html_to_file_id_cache.get_data(html, file_type.name)
|
|
||||||
if file_id and not application_config.debug:
|
|
||||||
logger.debug("%s 命中缓存,返回 file_id[%s]", template_name, file_id)
|
|
||||||
return RenderResult(
|
|
||||||
html=html,
|
|
||||||
photo=file_id,
|
|
||||||
file_type=file_type,
|
|
||||||
cache=self.html_to_file_id_cache,
|
|
||||||
ttl=ttl,
|
|
||||||
caption=caption,
|
|
||||||
parse_mode=parse_mode,
|
|
||||||
filename=filename,
|
|
||||||
)
|
|
||||||
|
|
||||||
browser = await self._browser.get_browser()
|
|
||||||
start_time = loop.time()
|
|
||||||
page = await browser.new_page(viewport=viewport)
|
|
||||||
uri = (PROJECT_ROOT / template.filename).as_uri()
|
|
||||||
await page.goto(uri)
|
|
||||||
await page.set_content(html, wait_until="networkidle")
|
|
||||||
if evaluate:
|
|
||||||
await page.evaluate(evaluate)
|
|
||||||
clip = None
|
|
||||||
if query_selector:
|
|
||||||
try:
|
|
||||||
card = await page.query_selector(query_selector)
|
|
||||||
if not card:
|
|
||||||
raise QuerySelectorNotFound
|
|
||||||
clip = await card.bounding_box()
|
|
||||||
if not clip:
|
|
||||||
raise QuerySelectorNotFound
|
|
||||||
except QuerySelectorNotFound:
|
|
||||||
logger.warning("未找到 %s 元素", query_selector)
|
|
||||||
png_data = await page.screenshot(clip=clip, full_page=full_page)
|
|
||||||
await page.close()
|
|
||||||
logger.debug("%s 图片渲染使用了 %s", template_name, str(loop.time() - start_time))
|
|
||||||
return RenderResult(
|
|
||||||
html=html,
|
|
||||||
photo=png_data,
|
|
||||||
file_type=file_type,
|
|
||||||
cache=self.html_to_file_id_cache,
|
|
||||||
ttl=ttl,
|
|
||||||
caption=caption,
|
|
||||||
parse_mode=parse_mode,
|
|
||||||
filename=filename,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TemplatePreviewer(BaseService, load=application_config.webserver.enable and application_config.debug):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
template_service: TemplateService,
|
|
||||||
cache: TemplatePreviewCache,
|
|
||||||
web_app: FastAPI,
|
|
||||||
):
|
|
||||||
self.web_app = web_app
|
|
||||||
self.template_service = template_service
|
|
||||||
self.cache = cache
|
|
||||||
self.register_routes()
|
|
||||||
|
|
||||||
async def get_preview_url(self, template: str, data: dict):
|
|
||||||
"""获取预览 URL"""
|
|
||||||
components = urlsplit(application_config.webserver.url)
|
|
||||||
path = urljoin("/preview/", template)
|
|
||||||
query = {}
|
|
||||||
|
|
||||||
# 如果有数据,暂存在 redis 中
|
|
||||||
if data:
|
|
||||||
key = str(uuid4())
|
|
||||||
await self.cache.set_data(key, data)
|
|
||||||
query["key"] = key
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
return components._replace(path=path, query=urlencode(query)).geturl()
|
|
||||||
|
|
||||||
def register_routes(self):
|
|
||||||
"""注册预览用到的路由"""
|
|
||||||
|
|
||||||
@self.web_app.get("/preview/{path:path}")
|
|
||||||
async def preview_template(path: str, key: Optional[str] = None): # pylint: disable=W0612
|
|
||||||
# 如果是 /preview/ 开头的静态文件,直接返回内容。比如使用相对链接 ../ 引入的静态资源
|
|
||||||
if not path.endswith((".html", ".jinja2")):
|
|
||||||
full_path = self.template_service.template_dir / path
|
|
||||||
if not full_path.is_file():
|
|
||||||
raise HTTPException(status_code=404, detail=f"Template '{path}' not found")
|
|
||||||
return FileResponse(full_path)
|
|
||||||
|
|
||||||
# 取回暂存的渲染数据
|
|
||||||
data = await self.cache.get_data(key) if key else {}
|
|
||||||
if key and data is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Template data {key} not found")
|
|
||||||
|
|
||||||
# 渲染 jinja2 模板
|
|
||||||
html = await self.template_service.render_async(path, data)
|
|
||||||
# 将本地 URL file:// 修改为 HTTP url,因为浏览器内不允许加载本地文件
|
|
||||||
# file:///project_dir/cache/image.jpg => /cache/image.jpg
|
|
||||||
html = html.replace(PROJECT_ROOT.as_uri(), "")
|
|
||||||
return HTMLResponse(html)
|
|
||||||
|
|
||||||
# 其他静态资源
|
|
||||||
for name in ["cache", "resources"]:
|
|
||||||
directory = PROJECT_ROOT / name
|
|
||||||
directory.mkdir(exist_ok=True)
|
|
||||||
self.web_app.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name)
|
|
||||||
|
@ -1,24 +1,3 @@
|
|||||||
from typing import List
|
from gram_core.services.users.cache import UserAdminCache
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.dependence.redisdb import RedisDB
|
|
||||||
|
|
||||||
__all__ = ("UserAdminCache",)
|
__all__ = ("UserAdminCache",)
|
||||||
|
|
||||||
|
|
||||||
class UserAdminCache(BaseService.Component):
|
|
||||||
def __init__(self, redis: RedisDB):
|
|
||||||
self.client = redis.client
|
|
||||||
self.qname = "users:admin"
|
|
||||||
|
|
||||||
async def ismember(self, user_id: int) -> bool:
|
|
||||||
return await self.client.sismember(self.qname, user_id)
|
|
||||||
|
|
||||||
async def get_all(self) -> List[int]:
|
|
||||||
return [int(str_data) for str_data in await self.client.smembers(self.qname)]
|
|
||||||
|
|
||||||
async def set(self, user_id: int) -> bool:
|
|
||||||
return await self.client.sadd(self.qname, user_id)
|
|
||||||
|
|
||||||
async def remove(self, user_id: int) -> bool:
|
|
||||||
return await self.client.srem(self.qname, user_id)
|
|
||||||
|
@ -1,34 +1,7 @@
|
|||||||
import enum
|
from gram_core.services.users.models import User, UserDataBase, PermissionsEnum
|
||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sqlmodel import SQLModel, Field, DateTime, Column, Enum, BigInteger, Integer
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"User",
|
"User",
|
||||||
"UserDataBase",
|
"UserDataBase",
|
||||||
"PermissionsEnum",
|
"PermissionsEnum",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PermissionsEnum(int, enum.Enum):
|
|
||||||
OWNER = 1
|
|
||||||
ADMIN = 2
|
|
||||||
PUBLIC = 3
|
|
||||||
|
|
||||||
|
|
||||||
class User(SQLModel):
|
|
||||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
|
||||||
id: Optional[int] = Field(
|
|
||||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
|
||||||
)
|
|
||||||
user_id: int = Field(unique=True, sa_column=Column(BigInteger()))
|
|
||||||
permissions: Optional[PermissionsEnum] = Field(sa_column=Column(Enum(PermissionsEnum)))
|
|
||||||
locale: Optional[str] = Field()
|
|
||||||
ban_end_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
|
|
||||||
ban_start_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
|
|
||||||
is_banned: Optional[int] = Field()
|
|
||||||
|
|
||||||
|
|
||||||
class UserDataBase(User, table=True):
|
|
||||||
__tablename__ = "users"
|
|
||||||
|
@ -1,44 +1,3 @@
|
|||||||
from typing import Optional, List
|
from gram_core.services.users.repositories import UserRepository
|
||||||
|
|
||||||
from sqlmodel import select
|
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.dependence.database import Database
|
|
||||||
from core.services.users.models import UserDataBase as User
|
|
||||||
from core.sqlmodel.session import AsyncSession
|
|
||||||
|
|
||||||
__all__ = ("UserRepository",)
|
__all__ = ("UserRepository",)
|
||||||
|
|
||||||
|
|
||||||
class UserRepository(BaseService.Component):
|
|
||||||
def __init__(self, database: Database):
|
|
||||||
self.engine = database.engine
|
|
||||||
|
|
||||||
async def get_by_user_id(self, user_id: int) -> Optional[User]:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
statement = select(User).where(User.user_id == user_id)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
return results.first()
|
|
||||||
|
|
||||||
async def add(self, user: User):
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(user)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def update(self, user: User) -> User:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
session.add(user)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(user)
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def remove(self, user: User):
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
await session.delete(user)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def get_all(self) -> List[User]:
|
|
||||||
async with AsyncSession(self.engine) as session:
|
|
||||||
statement = select(User)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
return results.all()
|
|
||||||
|
@ -1,83 +1,3 @@
|
|||||||
from typing import List, Optional
|
from gram_core.services.users.services import UserService, UserAdminService
|
||||||
|
|
||||||
from core.base_service import BaseService
|
|
||||||
from core.config import config
|
|
||||||
from core.services.users.cache import UserAdminCache
|
|
||||||
from core.services.users.models import PermissionsEnum, UserDataBase as User
|
|
||||||
from core.services.users.repositories import UserRepository
|
|
||||||
|
|
||||||
__all__ = ("UserService", "UserAdminService")
|
__all__ = ("UserService", "UserAdminService")
|
||||||
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
|
|
||||||
class UserService(BaseService):
|
|
||||||
def __init__(self, user_repository: UserRepository) -> None:
|
|
||||||
self._repository: UserRepository = user_repository
|
|
||||||
|
|
||||||
async def get_user_by_id(self, user_id: int) -> Optional[User]:
|
|
||||||
"""从数据库获取用户信息
|
|
||||||
:param user_id:用户ID
|
|
||||||
:return: User
|
|
||||||
"""
|
|
||||||
return await self._repository.get_by_user_id(user_id)
|
|
||||||
|
|
||||||
async def remove(self, user: User):
|
|
||||||
return await self._repository.remove(user)
|
|
||||||
|
|
||||||
async def update_user(self, user: User):
|
|
||||||
return await self._repository.add(user)
|
|
||||||
|
|
||||||
|
|
||||||
class UserAdminService(BaseService):
|
|
||||||
def __init__(self, user_repository: UserRepository, cache: UserAdminCache):
|
|
||||||
self.user_repository = user_repository
|
|
||||||
self._cache = cache
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
owner = config.owner
|
|
||||||
if owner:
|
|
||||||
user = await self.user_repository.get_by_user_id(owner)
|
|
||||||
if user:
|
|
||||||
if user.permissions != PermissionsEnum.OWNER:
|
|
||||||
user.permissions = PermissionsEnum.OWNER
|
|
||||||
await self._cache.set(user.user_id)
|
|
||||||
await self.user_repository.update(user)
|
|
||||||
else:
|
|
||||||
user = User(user_id=owner, permissions=PermissionsEnum.OWNER)
|
|
||||||
await self._cache.set(user.user_id)
|
|
||||||
await self.user_repository.add(user)
|
|
||||||
else:
|
|
||||||
logger.warning("检测到未配置Bot所有者 会导无法正常使用管理员权限")
|
|
||||||
users = await self.user_repository.get_all()
|
|
||||||
for user in users:
|
|
||||||
await self._cache.set(user.user_id)
|
|
||||||
|
|
||||||
async def is_admin(self, user_id: int) -> bool:
|
|
||||||
return await self._cache.ismember(user_id)
|
|
||||||
|
|
||||||
async def get_admin_list(self) -> List[int]:
|
|
||||||
return await self._cache.get_all()
|
|
||||||
|
|
||||||
async def add_admin(self, user_id: int) -> bool:
|
|
||||||
user = await self.user_repository.get_by_user_id(user_id)
|
|
||||||
if user:
|
|
||||||
if user.permissions == PermissionsEnum.OWNER:
|
|
||||||
return False
|
|
||||||
if user.permissions != PermissionsEnum.ADMIN:
|
|
||||||
user.permissions = PermissionsEnum.ADMIN
|
|
||||||
await self.user_repository.update(user)
|
|
||||||
else:
|
|
||||||
user = User(user_id=user_id, permissions=PermissionsEnum.ADMIN)
|
|
||||||
await self.user_repository.add(user)
|
|
||||||
return await self._cache.set(user_id)
|
|
||||||
|
|
||||||
async def delete_admin(self, user_id: int) -> bool:
|
|
||||||
user = await self.user_repository.get_by_user_id(user_id)
|
|
||||||
if user:
|
|
||||||
if user.permissions == PermissionsEnum.OWNER:
|
|
||||||
return True # 假装移除成功
|
|
||||||
user.permissions = PermissionsEnum.PUBLIC
|
|
||||||
await self.user_repository.update(user)
|
|
||||||
return await self._cache.remove(user.user_id)
|
|
||||||
return False
|
|
||||||
|
@ -1,118 +1,3 @@
|
|||||||
from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload
|
from gram_core.sqlmodel.session import AsyncSession
|
||||||
|
|
||||||
from sqlalchemy import util
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
|
|
||||||
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
|
|
||||||
from sqlalchemy.sql.base import Executable as _Executable
|
|
||||||
from sqlmodel.engine.result import Result, ScalarResult
|
|
||||||
from sqlmodel.orm.session import Session
|
|
||||||
from sqlmodel.sql.base import Executable
|
|
||||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
|
||||||
from typing_extensions import Literal
|
|
||||||
|
|
||||||
_TSelectParam = TypeVar("_TSelectParam")
|
|
||||||
|
|
||||||
__all__ = ("AsyncSession",)
|
__all__ = ("AsyncSession",)
|
||||||
|
|
||||||
|
|
||||||
class AsyncSession(_AsyncSession): # pylint: disable=W0223
|
|
||||||
sync_session_class = Session
|
|
||||||
sync_session: Session
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
|
|
||||||
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
|
|
||||||
sync_session_class: Type[Session] = Session,
|
|
||||||
**kw: Any,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
bind=bind,
|
|
||||||
binds=binds,
|
|
||||||
sync_session_class=sync_session_class,
|
|
||||||
**kw,
|
|
||||||
)
|
|
||||||
|
|
||||||
@overload
|
|
||||||
async def exec(
|
|
||||||
self,
|
|
||||||
statement: Select[_TSelectParam],
|
|
||||||
*,
|
|
||||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
|
||||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
|
||||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
|
||||||
**kw: Any,
|
|
||||||
) -> Result[_TSelectParam]:
|
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
async def exec(
|
|
||||||
self,
|
|
||||||
statement: SelectOfScalar[_TSelectParam],
|
|
||||||
*,
|
|
||||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
|
||||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
|
||||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
|
||||||
**kw: Any,
|
|
||||||
) -> ScalarResult[_TSelectParam]:
|
|
||||||
...
|
|
||||||
|
|
||||||
async def exec(
|
|
||||||
self,
|
|
||||||
statement: Union[
|
|
||||||
Select[_TSelectParam],
|
|
||||||
SelectOfScalar[_TSelectParam],
|
|
||||||
Executable[_TSelectParam],
|
|
||||||
],
|
|
||||||
*,
|
|
||||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
|
||||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
|
||||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
|
||||||
**kw: Any,
|
|
||||||
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
|
|
||||||
results = super().execute(
|
|
||||||
statement,
|
|
||||||
params=params,
|
|
||||||
execution_options=execution_options,
|
|
||||||
bind_arguments=bind_arguments,
|
|
||||||
**kw,
|
|
||||||
)
|
|
||||||
if isinstance(statement, SelectOfScalar):
|
|
||||||
return (await results).scalars() # type: ignore
|
|
||||||
return await results # type: ignore
|
|
||||||
|
|
||||||
async def execute( # pylint: disable=W0221
|
|
||||||
self,
|
|
||||||
statement: _Executable,
|
|
||||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
|
||||||
execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
|
|
||||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
|
||||||
**kw: Any,
|
|
||||||
) -> Result[Any]:
|
|
||||||
return await super().execute( # type: ignore
|
|
||||||
statement=statement,
|
|
||||||
params=params,
|
|
||||||
execution_options=execution_options,
|
|
||||||
bind_arguments=bind_arguments,
|
|
||||||
**kw,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get( # pylint: disable=W0221
|
|
||||||
self,
|
|
||||||
entity: Type[_TSelectParam],
|
|
||||||
ident: Any,
|
|
||||||
options: Optional[Sequence[Any]] = None,
|
|
||||||
populate_existing: bool = False,
|
|
||||||
with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
|
|
||||||
identity_token: Optional[Any] = None,
|
|
||||||
execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT,
|
|
||||||
) -> Optional[_TSelectParam]:
|
|
||||||
return await super().get(
|
|
||||||
entity=entity,
|
|
||||||
ident=ident,
|
|
||||||
options=options,
|
|
||||||
populate_existing=populate_existing,
|
|
||||||
with_for_update=with_for_update,
|
|
||||||
identity_token=identity_token,
|
|
||||||
execution_options=execution_options,
|
|
||||||
)
|
|
||||||
|
17
poetry.lock
generated
17
poetry.lock
generated
@ -606,6 +606,21 @@ files = [
|
|||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
gitdb = ">=4.0.1,<5"
|
gitdb = ">=4.0.1,<5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "gram-core"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "telegram robot base core."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = []
|
||||||
|
develop = false
|
||||||
|
|
||||||
|
[package.source]
|
||||||
|
type = "git"
|
||||||
|
url = "https://github.com/PaiGramTeam/GramCore.git"
|
||||||
|
reference = "HEAD"
|
||||||
|
resolved_reference = "7fb5d4c0e01731e6901829fe317be19023c2c4c7"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "greenlet"
|
name = "greenlet"
|
||||||
version = "1.1.3"
|
version = "1.1.3"
|
||||||
@ -2316,4 +2331,4 @@ test = ["flaky", "pytest", "pytest-asyncio"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.8"
|
python-versions = "^3.8"
|
||||||
content-hash = "8378e2149d042db3c75d422e969b658c255654e517c32bdb60b6dafef01e071e"
|
content-hash = "847b0e55ac501bd6126b5370fc08c018401147d8dd5c68c3216833b64b1213ea"
|
||||||
|
@ -44,6 +44,7 @@ pillow = "^10.0.0"
|
|||||||
playwright = "^1.27.1"
|
playwright = "^1.27.1"
|
||||||
aiosqlite = { extras = ["sqlite"], version = "^0.19.0" }
|
aiosqlite = { extras = ["sqlite"], version = "^0.19.0" }
|
||||||
simnet = { git = "https://github.com/PaiGramTeam/SIMNet" }
|
simnet = { git = "https://github.com/PaiGramTeam/SIMNet" }
|
||||||
|
gram-core = {git = "https://github.com/PaiGramTeam/GramCore.git"}
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
pyro = ["Pyrogram", "TgCrypto"]
|
pyro = ["Pyrogram", "TgCrypto"]
|
||||||
|
@ -26,6 +26,7 @@ fastapi==0.99.1 ; python_version >= "3.8" and python_version < "4.0"
|
|||||||
flaky==3.7.0 ; python_version >= "3.8" and python_version < "4.0"
|
flaky==3.7.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||||
gitdb==4.0.10 ; python_version >= "3.8" and python_version < "4.0"
|
gitdb==4.0.10 ; python_version >= "3.8" and python_version < "4.0"
|
||||||
gitpython==3.1.32 ; python_version >= "3.8" and python_version < "4.0"
|
gitpython==3.1.32 ; python_version >= "3.8" and python_version < "4.0"
|
||||||
|
gram-core @ git+https://github.com/PaiGramTeam/GramCore.git@main ; python_version >= "3.8" and python_version < "4.0"
|
||||||
greenlet==1.1.3 ; python_version >= "3.8" and python_version < "4.0"
|
greenlet==1.1.3 ; python_version >= "3.8" and python_version < "4.0"
|
||||||
h11==0.14.0 ; python_version >= "3.8" and python_version < "4.0"
|
h11==0.14.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||||
httpcore==0.17.3 ; python_version >= "3.8" and python_version < "4.0"
|
httpcore==0.17.3 ; python_version >= "3.8" and python_version < "4.0"
|
||||||
|
Loading…
Reference in New Issue
Block a user