mirror of
https://github.com/PaiGramTeam/MibooGram.git
synced 2024-11-22 07:08:04 +00:00
♻️ separate core code
Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
This commit is contained in:
parent
af2e9bdb9b
commit
865f29bd77
@ -1,289 +1,5 @@
|
||||
"""BOT"""
|
||||
import asyncio
|
||||
import signal
|
||||
from functools import wraps
|
||||
from signal import SIGABRT, SIGINT, SIGTERM, signal as signal_func
|
||||
from ssl import SSLZeroReturnError
|
||||
from typing import Callable, List, Optional, TYPE_CHECKING, TypeVar
|
||||
|
||||
import pytz
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from telegram import Bot, Update
|
||||
from telegram.error import NetworkError, TelegramError, TimedOut
|
||||
from telegram.ext import (
|
||||
Application as TelegramApplication,
|
||||
ApplicationBuilder as TelegramApplicationBuilder,
|
||||
Defaults,
|
||||
JobQueue,
|
||||
)
|
||||
from typing_extensions import ParamSpec
|
||||
from uvicorn import Server
|
||||
|
||||
from core.config import config as application_config
|
||||
from core.handler.limiterhandler import LimiterHandler
|
||||
from core.manager import Managers
|
||||
from core.override.telegram import HTTPXRequest
|
||||
from core.ratelimiter import RateLimiter
|
||||
from utils.const import WRAPPER_ASSIGNMENTS
|
||||
from utils.log import logger
|
||||
from utils.models.signal import Singleton
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from asyncio import Task
|
||||
from types import FrameType
|
||||
from gram_core.application import Application
|
||||
|
||||
__all__ = ("Application",)
|
||||
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class Application(Singleton):
|
||||
"""Application"""
|
||||
|
||||
_web_server_task: Optional["Task"] = None
|
||||
|
||||
_startup_funcs: List[Callable] = []
|
||||
_shutdown_funcs: List[Callable] = []
|
||||
|
||||
def __init__(self, managers: "Managers", telegram: "TelegramApplication", web_server: "Server") -> None:
|
||||
self._running = False
|
||||
self.managers = managers
|
||||
self.telegram = telegram
|
||||
self.web_server = web_server
|
||||
self.managers.set_application(application=self) # 给 managers 设置 application
|
||||
self.managers.build_executor("Application")
|
||||
|
||||
@classmethod
|
||||
def build(cls):
|
||||
managers = Managers()
|
||||
telegram = (
|
||||
TelegramApplicationBuilder()
|
||||
.get_updates_read_timeout(application_config.update_read_timeout)
|
||||
.get_updates_write_timeout(application_config.update_write_timeout)
|
||||
.get_updates_connect_timeout(application_config.update_connect_timeout)
|
||||
.get_updates_pool_timeout(application_config.update_pool_timeout)
|
||||
.defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai")))
|
||||
.token(application_config.bot_token)
|
||||
.request(
|
||||
HTTPXRequest(
|
||||
connection_pool_size=application_config.connection_pool_size,
|
||||
proxy_url=application_config.proxy_url,
|
||||
read_timeout=application_config.read_timeout,
|
||||
write_timeout=application_config.write_timeout,
|
||||
connect_timeout=application_config.connect_timeout,
|
||||
pool_timeout=application_config.pool_timeout,
|
||||
)
|
||||
)
|
||||
.rate_limiter(RateLimiter())
|
||||
.build()
|
||||
)
|
||||
web_server = Server(
|
||||
uvicorn.Config(
|
||||
app=FastAPI(debug=application_config.debug),
|
||||
port=application_config.webserver.port,
|
||||
host=application_config.webserver.host,
|
||||
log_config=None,
|
||||
)
|
||||
)
|
||||
return cls(managers, telegram, web_server)
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""bot 是否正在运行"""
|
||||
with self._lock:
|
||||
return self._running
|
||||
|
||||
@property
|
||||
def web_app(self) -> FastAPI:
|
||||
"""fastapi app"""
|
||||
return self.web_server.config.app
|
||||
|
||||
@property
|
||||
def bot(self) -> Optional[Bot]:
|
||||
return self.telegram.bot
|
||||
|
||||
@property
|
||||
def job_queue(self) -> Optional[JobQueue]:
|
||||
return self.telegram.job_queue
|
||||
|
||||
async def _on_startup(self) -> None:
|
||||
for func in self._startup_funcs:
|
||||
await self.managers.executor(func, block=getattr(func, "block", False))
|
||||
|
||||
async def _on_shutdown(self) -> None:
|
||||
for func in self._shutdown_funcs:
|
||||
await self.managers.executor(func, block=getattr(func, "block", False))
|
||||
|
||||
async def initialize(self):
|
||||
"""BOT 初始化"""
|
||||
self.telegram.add_handler(LimiterHandler(limit_time=10), group=-1) # 启用入口洪水限制
|
||||
await self.managers.start_dependency() # 启动基础服务
|
||||
await self.managers.init_components() # 实例化组件
|
||||
await self.managers.start_services() # 启动其他服务
|
||||
await self.managers.install_plugins() # 安装插件
|
||||
|
||||
async def shutdown(self):
|
||||
"""BOT 关闭"""
|
||||
await self.managers.uninstall_plugins() # 卸载插件
|
||||
await self.managers.stop_services() # 终止其他服务
|
||||
await self.managers.stop_dependency() # 终止基础服务
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动 BOT"""
|
||||
logger.info("正在启动 BOT 中...")
|
||||
|
||||
def error_callback(exc: TelegramError) -> None:
|
||||
"""错误信息回调"""
|
||||
self.telegram.create_task(self.telegram.process_error(error=exc, update=None))
|
||||
|
||||
await self.telegram.initialize()
|
||||
logger.info("[blue]Telegram[/] 初始化成功", extra={"markup": True})
|
||||
|
||||
if application_config.webserver.enable: # 如果使用 web app
|
||||
server_config = self.web_server.config
|
||||
server_config.setup_event_loop()
|
||||
if not server_config.loaded:
|
||||
server_config.load()
|
||||
self.web_server.lifespan = server_config.lifespan_class(server_config)
|
||||
try:
|
||||
await self.web_server.startup()
|
||||
except OSError as e:
|
||||
if e.errno == 10048:
|
||||
logger.error("Web Server 端口被占用:%s", e)
|
||||
logger.error("Web Server 启动失败,正在退出")
|
||||
raise SystemExit from None
|
||||
|
||||
if self.web_server.should_exit:
|
||||
logger.error("Web Server 启动失败,正在退出")
|
||||
raise SystemExit from None
|
||||
logger.success("Web Server 启动成功")
|
||||
|
||||
self._web_server_task = asyncio.create_task(self.web_server.main_loop())
|
||||
|
||||
for _ in range(5): # 连接至 telegram 服务器
|
||||
try:
|
||||
await self.telegram.updater.start_polling(
|
||||
error_callback=error_callback, allowed_updates=Update.ALL_TYPES
|
||||
)
|
||||
break
|
||||
except TimedOut:
|
||||
logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
|
||||
continue
|
||||
except NetworkError as e:
|
||||
logger.exception()
|
||||
if isinstance(e, SSLZeroReturnError):
|
||||
logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
||||
else:
|
||||
logger.error("网络连接出现问题, 请检查您的网络状况.")
|
||||
raise SystemExit from e
|
||||
|
||||
await self.initialize()
|
||||
logger.success("BOT 初始化成功")
|
||||
logger.debug("BOT 开始启动")
|
||||
|
||||
await self._on_startup()
|
||||
await self.telegram.start()
|
||||
self._running = True
|
||||
logger.success("BOT 启动成功")
|
||||
|
||||
def stop_signal_handler(self, signum: int):
|
||||
"""终止信号处理"""
|
||||
signals = {k: v for v, k in signal.__dict__.items() if v.startswith("SIG") and not v.startswith("SIG_")}
|
||||
logger.debug("接收到了终止信号 %s 正在退出...", signals[signum])
|
||||
if self._web_server_task:
|
||||
self._web_server_task.cancel()
|
||||
|
||||
async def idle(self) -> None:
|
||||
"""在接收到中止信号之前,堵塞loop"""
|
||||
|
||||
task = None
|
||||
|
||||
def stop_handler(signum: int, _: "FrameType") -> None:
|
||||
self.stop_signal_handler(signum)
|
||||
task.cancel()
|
||||
|
||||
for s in (SIGINT, SIGTERM, SIGABRT):
|
||||
signal_func(s, stop_handler)
|
||||
|
||||
while True:
|
||||
task = asyncio.create_task(asyncio.sleep(600))
|
||||
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""关闭"""
|
||||
logger.info("BOT 正在关闭")
|
||||
self._running = False
|
||||
|
||||
await self._on_shutdown()
|
||||
|
||||
if self.telegram.updater.running:
|
||||
await self.telegram.updater.stop()
|
||||
|
||||
await self.shutdown()
|
||||
|
||||
if self.telegram.running:
|
||||
await self.telegram.stop()
|
||||
|
||||
await self.telegram.shutdown()
|
||||
if self.web_server is not None:
|
||||
try:
|
||||
await self.web_server.shutdown()
|
||||
logger.info("Web Server 已经关闭")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
logger.success("BOT 关闭成功")
|
||||
|
||||
def launch(self) -> None:
|
||||
"""启动"""
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(self.start())
|
||||
loop.run_until_complete(self.idle())
|
||||
except (SystemExit, KeyboardInterrupt) as exc:
|
||||
logger.debug("接收到了终止信号,BOT 即将关闭", exc_info=exc) # 接收到了终止信号
|
||||
except NetworkError as e:
|
||||
if isinstance(e, SSLZeroReturnError):
|
||||
logger.critical("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
||||
else:
|
||||
logger.critical("网络连接出现问题, 请检查您的网络状况.")
|
||||
except Exception as e:
|
||||
logger.critical("遇到了未知错误: %s", {type(e)}, exc_info=e)
|
||||
finally:
|
||||
loop.run_until_complete(self.stop())
|
||||
|
||||
if application_config.reload:
|
||||
raise SystemExit from None
|
||||
|
||||
def on_startup(self, func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""注册一个在 BOT 启动时执行的函数"""
|
||||
|
||||
if func not in self._startup_funcs:
|
||||
self._startup_funcs.append(func)
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def on_shutdown(self, func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""注册一个在 BOT 停止时执行的函数"""
|
||||
|
||||
if func not in self._shutdown_funcs:
|
||||
self._shutdown_funcs.append(func)
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
@ -1,60 +1,3 @@
|
||||
from abc import ABC
|
||||
from itertools import chain
|
||||
from typing import ClassVar, Iterable, Type, TypeVar
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from utils.helpers import isabstract
|
||||
from gram_core.base_service import BaseService, BaseServiceType, DependenceType, ComponentType, get_all_services
|
||||
|
||||
__all__ = ("BaseService", "BaseServiceType", "DependenceType", "ComponentType", "get_all_services")
|
||||
|
||||
|
||||
class _BaseService:
|
||||
"""服务基类"""
|
||||
|
||||
_is_component: ClassVar[bool] = False
|
||||
_is_dependence: ClassVar[bool] = False
|
||||
|
||||
def __init_subclass__(cls, load: bool = True, **kwargs):
|
||||
cls.is_dependence = cls._is_dependence
|
||||
cls.is_component = cls._is_component
|
||||
cls.load = load
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
await self.initialize()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
await self.shutdown()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize resources used by this service"""
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Stop & clear resources used by this service"""
|
||||
|
||||
|
||||
class _Dependence(_BaseService, ABC):
|
||||
_is_dependence: ClassVar[bool] = True
|
||||
|
||||
|
||||
class _Component(_BaseService, ABC):
|
||||
_is_component: ClassVar[bool] = True
|
||||
|
||||
|
||||
class BaseService(_BaseService, ABC):
|
||||
Dependence: Type[_BaseService] = _Dependence
|
||||
Component: Type[_BaseService] = _Component
|
||||
|
||||
|
||||
BaseServiceType = TypeVar("BaseServiceType", bound=_BaseService)
|
||||
DependenceType = TypeVar("DependenceType", bound=_Dependence)
|
||||
ComponentType = TypeVar("ComponentType", bound=_Component)
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def get_all_services() -> Iterable[Type[_BaseService]]:
|
||||
return filter(
|
||||
lambda x: x.__name__[0] != "_" and x.load and not isabstract(x),
|
||||
chain(BaseService.__subclasses__(), _Dependence.__subclasses__(), _Component.__subclasses__()),
|
||||
)
|
||||
|
@ -1,29 +1,3 @@
|
||||
import enum
|
||||
|
||||
try:
|
||||
import ujson as jsonlib
|
||||
except ImportError:
|
||||
import json as jsonlib
|
||||
|
||||
from pydantic import BaseSettings
|
||||
from gram_core.basemodel import RegionEnum, Settings
|
||||
|
||||
__all__ = ("RegionEnum", "Settings")
|
||||
|
||||
|
||||
class RegionEnum(int, enum.Enum):
|
||||
"""账号数据所在服务器"""
|
||||
|
||||
NULL = 0
|
||||
HYPERION = 1 # 米忽悠国服 hyperion
|
||||
HOYOLAB = 2 # 米忽悠国际服 hoyolab
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
def __new__(cls, *args, **kwargs):
|
||||
cls.update_forward_refs()
|
||||
return super(Settings, cls).__new__(cls) # pylint: disable=E1120
|
||||
|
||||
class Config(BaseSettings.Config):
|
||||
case_sensitive = False
|
||||
json_loads = jsonlib.loads
|
||||
json_dumps = jsonlib.dumps
|
||||
|
@ -1 +0,0 @@
|
||||
"""bot builtins"""
|
@ -1,38 +0,0 @@
|
||||
"""上下文管理"""
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from telegram.ext import CallbackContext
|
||||
from telegram import Update
|
||||
|
||||
__all__ = [
|
||||
"CallbackContextCV",
|
||||
"UpdateCV",
|
||||
"handler_contexts",
|
||||
"job_contexts",
|
||||
]
|
||||
|
||||
CallbackContextCV: ContextVar["CallbackContext"] = ContextVar("TelegramContextCallback")
|
||||
UpdateCV: ContextVar["Update"] = ContextVar("TelegramUpdate")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def handler_contexts(update: "Update", context: "CallbackContext") -> None:
|
||||
context_token = CallbackContextCV.set(context)
|
||||
update_token = UpdateCV.set(update)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CallbackContextCV.reset(context_token)
|
||||
UpdateCV.reset(update_token)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def job_contexts(context: "CallbackContext") -> None:
|
||||
token = CallbackContextCV.set(context)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CallbackContextCV.reset(token)
|
@ -1,309 +0,0 @@
|
||||
"""参数分发器"""
|
||||
import asyncio
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import AbstractEventLoop
|
||||
from functools import cached_property, lru_cache, partial, wraps
|
||||
from inspect import Parameter, Signature
|
||||
from itertools import chain
|
||||
from types import GenericAlias, MethodType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from arkowrapper import ArkoWrapper
|
||||
from fastapi import FastAPI
|
||||
from telegram import Bot as TelegramBot, Chat, Message, Update, User
|
||||
from telegram.ext import Application as TelegramApplication, CallbackContext, Job
|
||||
from typing_extensions import ParamSpec
|
||||
from uvicorn import Server
|
||||
|
||||
from core.application import Application
|
||||
from utils.const import WRAPPER_ASSIGNMENTS
|
||||
from utils.typedefs import R, T
|
||||
|
||||
__all__ = (
|
||||
"catch",
|
||||
"AbstractDispatcher",
|
||||
"BaseDispatcher",
|
||||
"HandlerDispatcher",
|
||||
"JobDispatcher",
|
||||
"dispatched",
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
TargetType = Union[Type, str, Callable[[Any], bool]]
|
||||
|
||||
_CATCH_TARGET_ATTR = "_catch_targets"
|
||||
|
||||
|
||||
def catch(*targets: Union[str, Type]) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorate(func: Callable[P, R]) -> Callable[P, R]:
|
||||
setattr(func, _CATCH_TARGET_ATTR, targets)
|
||||
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
@lru_cache(64)
|
||||
def get_signature(func: Union[type, Callable]) -> Signature:
|
||||
if isinstance(func, type):
|
||||
return inspect.signature(func.__init__)
|
||||
return inspect.signature(func)
|
||||
|
||||
|
||||
class AbstractDispatcher(ABC):
|
||||
"""参数分发器"""
|
||||
|
||||
IGNORED_ATTRS = []
|
||||
|
||||
_args: List[Any] = []
|
||||
_kwargs: Dict[Union[str, Type], Any] = {}
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
||||
return self._application
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self._args = list(args)
|
||||
self._kwargs = dict(kwargs)
|
||||
|
||||
for _, value in kwargs.items():
|
||||
type_arg = type(value)
|
||||
if type_arg != str:
|
||||
self._kwargs[type_arg] = value
|
||||
|
||||
for arg in args:
|
||||
type_arg = type(arg)
|
||||
if type_arg != str:
|
||||
self._kwargs[type_arg] = arg
|
||||
|
||||
@cached_property
|
||||
def catch_funcs(self) -> List[MethodType]:
|
||||
# noinspection PyTypeChecker
|
||||
return list(
|
||||
ArkoWrapper(dir(self))
|
||||
.filter(lambda x: not x.startswith("_"))
|
||||
.filter(
|
||||
lambda x: x not in self.IGNORED_ATTRS + ["dispatch", "catch_funcs", "catch_func_map", "dispatch_funcs"]
|
||||
)
|
||||
.map(lambda x: getattr(self, x))
|
||||
.filter(lambda x: isinstance(x, MethodType))
|
||||
.filter(lambda x: hasattr(x, "_catch_targets"))
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def catch_func_map(self) -> Dict[Union[str, Type[T]], Callable[..., T]]:
|
||||
result = {}
|
||||
for catch_func in self.catch_funcs:
|
||||
catch_targets = getattr(catch_func, _CATCH_TARGET_ATTR)
|
||||
for catch_target in catch_targets:
|
||||
result[catch_target] = catch_func
|
||||
return result
|
||||
|
||||
@cached_property
|
||||
def dispatch_funcs(self) -> List[MethodType]:
|
||||
return list(
|
||||
ArkoWrapper(dir(self))
|
||||
.filter(lambda x: x.startswith("dispatch_by_"))
|
||||
.map(lambda x: getattr(self, x))
|
||||
.filter(lambda x: isinstance(x, MethodType))
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
||||
"""默认的 dispatch 方法"""
|
||||
|
||||
@abstractmethod
|
||||
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
|
||||
"""使用 catch_func 获取并分配参数"""
|
||||
|
||||
def dispatch(self, func: Callable[P, R]) -> Callable[..., R]:
|
||||
"""将参数分配给函数,从而合成一个无需参数即可执行的函数"""
|
||||
params = {}
|
||||
signature = get_signature(func)
|
||||
parameters: Dict[str, Parameter] = dict(signature.parameters)
|
||||
|
||||
for name, parameter in list(parameters.items()):
|
||||
parameter: Parameter
|
||||
if any(
|
||||
[
|
||||
name == "self" and isinstance(func, (type, MethodType)),
|
||||
parameter.kind in [Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL],
|
||||
]
|
||||
):
|
||||
del parameters[name]
|
||||
continue
|
||||
|
||||
for dispatch_func in self.dispatch_funcs:
|
||||
parameters[name] = dispatch_func(parameter)
|
||||
|
||||
for name, parameter in parameters.items():
|
||||
if parameter.default != Parameter.empty:
|
||||
params[name] = parameter.default
|
||||
else:
|
||||
params[name] = None
|
||||
|
||||
return partial(func, **params)
|
||||
|
||||
@catch(Application)
|
||||
def catch_application(self) -> Application:
|
||||
return self.application
|
||||
|
||||
|
||||
class BaseDispatcher(AbstractDispatcher):
|
||||
"""默认参数分发器"""
|
||||
|
||||
_instances: Sequence[Any]
|
||||
|
||||
def _get_kwargs(self) -> Dict[Type[T], T]:
|
||||
result = self._get_default_kwargs()
|
||||
result[AbstractDispatcher] = self
|
||||
result.update(self._kwargs)
|
||||
return result
|
||||
|
||||
def _get_default_kwargs(self) -> Dict[Type[T], T]:
|
||||
application = self.application
|
||||
_default_kwargs = {
|
||||
FastAPI: application.web_app,
|
||||
Server: application.web_server,
|
||||
TelegramApplication: application.telegram,
|
||||
TelegramBot: application.telegram.bot,
|
||||
}
|
||||
if not application.running:
|
||||
for obj in chain(
|
||||
application.managers.dependency,
|
||||
application.managers.components,
|
||||
application.managers.services,
|
||||
application.managers.plugins,
|
||||
):
|
||||
_default_kwargs[type(obj)] = obj
|
||||
return {k: v for k, v in _default_kwargs.items() if v is not None}
|
||||
|
||||
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
||||
annotation = parameter.annotation
|
||||
# noinspection PyTypeChecker
|
||||
if isinstance(annotation, type) and (value := self._get_kwargs().get(annotation, None)) is not None:
|
||||
parameter._default = value # pylint: disable=W0212
|
||||
return parameter
|
||||
|
||||
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
|
||||
annotation = parameter.annotation
|
||||
if annotation != Any and isinstance(annotation, GenericAlias):
|
||||
return parameter
|
||||
|
||||
catch_func = self.catch_func_map.get(annotation) or self.catch_func_map.get(parameter.name)
|
||||
if catch_func is not None:
|
||||
# noinspection PyUnresolvedReferences,PyProtectedMember
|
||||
parameter._default = catch_func() # pylint: disable=W0212
|
||||
return parameter
|
||||
|
||||
@catch(AbstractEventLoop)
|
||||
def catch_loop(self) -> AbstractEventLoop:
|
||||
return asyncio.get_event_loop()
|
||||
|
||||
|
||||
class HandlerDispatcher(BaseDispatcher):
|
||||
"""Handler 参数分发器"""
|
||||
|
||||
def __init__(self, update: Optional[Update] = None, context: Optional[CallbackContext] = None, **kwargs) -> None:
|
||||
super().__init__(update=update, context=context, **kwargs)
|
||||
self._update = update
|
||||
self._context = context
|
||||
|
||||
def dispatch(
|
||||
self, func: Callable[P, R], *, update: Optional[Update] = None, context: Optional[CallbackContext] = None
|
||||
) -> Callable[..., R]:
|
||||
self._update = update or self._update
|
||||
self._context = context or self._context
|
||||
if self._update is None:
|
||||
from core.builtins.contexts import UpdateCV
|
||||
|
||||
self._update = UpdateCV.get()
|
||||
if self._context is None:
|
||||
from core.builtins.contexts import CallbackContextCV
|
||||
|
||||
self._context = CallbackContextCV.get()
|
||||
return super().dispatch(func)
|
||||
|
||||
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
||||
"""HandlerDispatcher 默认不使用 dispatch_by_default"""
|
||||
return parameter
|
||||
|
||||
@catch(Update)
|
||||
def catch_update(self) -> Update:
|
||||
return self._update
|
||||
|
||||
@catch(CallbackContext)
|
||||
def catch_context(self) -> CallbackContext:
|
||||
return self._context
|
||||
|
||||
@catch(Message)
|
||||
def catch_message(self) -> Message:
|
||||
return self._update.effective_message
|
||||
|
||||
@catch(User)
|
||||
def catch_user(self) -> User:
|
||||
return self._update.effective_user
|
||||
|
||||
@catch(Chat)
|
||||
def catch_chat(self) -> Chat:
|
||||
return self._update.effective_chat
|
||||
|
||||
|
||||
class JobDispatcher(BaseDispatcher):
|
||||
"""Job 参数分发器"""
|
||||
|
||||
def __init__(self, context: Optional[CallbackContext] = None, **kwargs) -> None:
|
||||
super().__init__(context=context, **kwargs)
|
||||
self._context = context
|
||||
|
||||
def dispatch(self, func: Callable[P, R], *, context: Optional[CallbackContext] = None) -> Callable[..., R]:
|
||||
self._context = context or self._context
|
||||
if self._context is None:
|
||||
from core.builtins.contexts import CallbackContextCV
|
||||
|
||||
self._context = CallbackContextCV.get()
|
||||
return super().dispatch(func)
|
||||
|
||||
@catch("data")
|
||||
def catch_data(self) -> Any:
|
||||
return self._context.job.data
|
||||
|
||||
@catch(Job)
|
||||
def catch_job(self) -> Job:
|
||||
return self._context.job
|
||||
|
||||
@catch(CallbackContext)
|
||||
def catch_context(self) -> CallbackContext:
|
||||
return self._context
|
||||
|
||||
|
||||
def dispatched(dispatcher: Type[AbstractDispatcher] = BaseDispatcher):
|
||||
def decorate(func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return dispatcher().dispatch(func)(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorate
|
@ -1,131 +0,0 @@
|
||||
"""执行器"""
|
||||
import inspect
|
||||
from functools import cached_property
|
||||
from multiprocessing import RLock as Lock
|
||||
from typing import Callable, ClassVar, Dict, Generic, Optional, TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import CallbackContext
|
||||
from typing_extensions import ParamSpec, Self
|
||||
|
||||
from core.builtins.contexts import handler_contexts, job_contexts
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.application import Application
|
||||
from core.builtins.dispatcher import AbstractDispatcher, HandlerDispatcher
|
||||
from multiprocessing.synchronize import RLock as LockType
|
||||
|
||||
__all__ = ("BaseExecutor", "Executor", "HandlerExecutor", "JobExecutor")
|
||||
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class BaseExecutor:
|
||||
"""执行器
|
||||
Args:
|
||||
name(str): 该执行器的名称。执行器的名称是唯一的。
|
||||
|
||||
只支持执行只拥有 POSITIONAL_OR_KEYWORD 和 KEYWORD_ONLY 两种参数类型的函数
|
||||
"""
|
||||
|
||||
_lock: ClassVar["LockType"] = Lock()
|
||||
_instances: ClassVar[Dict[str, Self]] = {}
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
||||
return self._application
|
||||
|
||||
def __new__(cls: Type[T], name: str, *args, **kwargs) -> T:
|
||||
with cls._lock:
|
||||
if (instance := cls._instances.get(name)) is None:
|
||||
instance = object.__new__(cls)
|
||||
instance.__init__(name, *args, **kwargs)
|
||||
cls._instances.update({name: instance})
|
||||
return instance
|
||||
|
||||
@cached_property
|
||||
def name(self) -> str:
|
||||
"""当前执行器的名称"""
|
||||
return self._name
|
||||
|
||||
def __init__(self, name: str, dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
|
||||
self._name = name
|
||||
self._dispatcher = dispatcher
|
||||
|
||||
|
||||
class Executor(BaseExecutor, Generic[P, R]):
|
||||
async def __call__(
|
||||
self,
|
||||
target: Callable[P, R],
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
**kwargs,
|
||||
) -> R:
|
||||
dispatcher = self._dispatcher or dispatcher
|
||||
dispatcher_instance = dispatcher(**kwargs)
|
||||
dispatcher_instance.set_application(application=self.application)
|
||||
dispatched_func = dispatcher_instance.dispatch(target) # 分发参数,组成新函数
|
||||
|
||||
# 执行
|
||||
if inspect.iscoroutinefunction(target):
|
||||
result = await dispatched_func()
|
||||
else:
|
||||
result = dispatched_func()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class HandlerExecutor(BaseExecutor, Generic[P, R]):
|
||||
"""Handler专用执行器"""
|
||||
|
||||
_callback: Callable[P, R]
|
||||
_dispatcher: "HandlerDispatcher"
|
||||
|
||||
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["HandlerDispatcher"]] = None) -> None:
|
||||
if dispatcher is None:
|
||||
from core.builtins.dispatcher import HandlerDispatcher
|
||||
|
||||
dispatcher = HandlerDispatcher
|
||||
super().__init__("handler", dispatcher)
|
||||
self._callback = func
|
||||
self._dispatcher = dispatcher()
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
if self._dispatcher is not None:
|
||||
self._dispatcher.set_application(application)
|
||||
|
||||
async def __call__(self, update: Update, context: CallbackContext) -> R:
|
||||
with handler_contexts(update, context):
|
||||
dispatched_func = self._dispatcher.dispatch(self._callback, update=update, context=context)
|
||||
return await dispatched_func()
|
||||
|
||||
|
||||
class JobExecutor(BaseExecutor):
|
||||
"""Job 专用执行器"""
|
||||
|
||||
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
|
||||
if dispatcher is None:
|
||||
from core.builtins.dispatcher import JobDispatcher
|
||||
|
||||
dispatcher = JobDispatcher
|
||||
super().__init__("job", dispatcher)
|
||||
self._callback = func
|
||||
self._dispatcher = dispatcher()
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
if self._dispatcher is not None:
|
||||
self._dispatcher.set_application(application)
|
||||
|
||||
async def __call__(self, context: CallbackContext) -> R:
|
||||
with job_contexts(context):
|
||||
dispatched_func = self._dispatcher.dispatch(self._callback, context=context)
|
||||
return await dispatched_func()
|
@ -1,185 +0,0 @@
|
||||
import inspect
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator, List, Optional, TYPE_CHECKING
|
||||
|
||||
from watchfiles import watch
|
||||
|
||||
from utils.const import HANDLED_SIGNALS, PROJECT_ROOT
|
||||
from utils.log import logger
|
||||
from utils.typedefs import StrOrPath
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from multiprocessing.process import BaseProcess
|
||||
|
||||
__all__ = ("Reloader",)
|
||||
|
||||
multiprocessing.allow_connection_pickling()
|
||||
spawn = multiprocessing.get_context("spawn")
|
||||
|
||||
|
||||
class FileFilter:
|
||||
"""监控文件过滤"""
|
||||
|
||||
def __init__(self, includes: List[str], excludes: List[str]) -> None:
|
||||
default_includes = ["*.py"]
|
||||
self.includes = [default for default in default_includes if default not in excludes]
|
||||
self.includes.extend(includes)
|
||||
self.includes = list(set(self.includes))
|
||||
|
||||
default_excludes = [".*", ".py[cod]", ".sw.*", "~*", __file__]
|
||||
self.excludes = [default for default in default_excludes if default not in includes]
|
||||
self.exclude_dirs = []
|
||||
for e in excludes:
|
||||
p = Path(e)
|
||||
try:
|
||||
is_dir = p.is_dir()
|
||||
except OSError:
|
||||
is_dir = False
|
||||
|
||||
if is_dir:
|
||||
self.exclude_dirs.append(p)
|
||||
else:
|
||||
self.excludes.append(e)
|
||||
self.excludes = list(set(self.excludes))
|
||||
|
||||
def __call__(self, path: Path) -> bool:
|
||||
for include_pattern in self.includes:
|
||||
if path.match(include_pattern):
|
||||
for exclude_dir in self.exclude_dirs:
|
||||
if exclude_dir in path.parents:
|
||||
return False
|
||||
|
||||
for exclude_pattern in self.excludes:
|
||||
if path.match(exclude_pattern):
|
||||
return False
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class Reloader:
|
||||
_target: Callable[..., None]
|
||||
_process: "BaseProcess"
|
||||
|
||||
@property
|
||||
def process(self) -> "BaseProcess":
|
||||
return self._process
|
||||
|
||||
@property
|
||||
def target(self) -> Callable[..., None]:
|
||||
return self._target
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Callable[..., None],
|
||||
*,
|
||||
reload_delay: float = 0.25,
|
||||
reload_dirs: List[StrOrPath] = None,
|
||||
reload_includes: List[str] = None,
|
||||
reload_excludes: List[str] = None,
|
||||
):
|
||||
if inspect.iscoroutinefunction(target):
|
||||
raise ValueError("不支持异步函数")
|
||||
self._target = target
|
||||
|
||||
self.reload_delay = reload_delay
|
||||
|
||||
_reload_dirs = []
|
||||
for reload_dir in reload_dirs or []:
|
||||
_reload_dirs.append(PROJECT_ROOT.joinpath(Path(reload_dir)))
|
||||
|
||||
self.reload_dirs = []
|
||||
for reload_dir in _reload_dirs:
|
||||
append = True
|
||||
for parent in reload_dir.parents:
|
||||
if parent in _reload_dirs:
|
||||
append = False
|
||||
break
|
||||
if append:
|
||||
self.reload_dirs.append(reload_dir)
|
||||
|
||||
if not self.reload_dirs:
|
||||
logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"})
|
||||
|
||||
self._should_exit = threading.Event()
|
||||
|
||||
frame = inspect.currentframe().f_back
|
||||
|
||||
self.watch_filter = FileFilter(reload_includes or [], (reload_excludes or []) + [frame.f_globals["__file__"]])
|
||||
self.watcher = watch(
|
||||
*self.reload_dirs,
|
||||
watch_filter=None,
|
||||
stop_event=self._should_exit,
|
||||
yield_on_timeout=True,
|
||||
)
|
||||
|
||||
def get_changes(self) -> Optional[List[Path]]:
|
||||
if not self._process.is_alive():
|
||||
logger.info("目标进程已经关闭", extra={"tag": "Reloader"})
|
||||
self._should_exit.set()
|
||||
try:
|
||||
changes = next(self.watcher)
|
||||
except StopIteration:
|
||||
return None
|
||||
if changes:
|
||||
unique_paths = {Path(c[1]) for c in changes}
|
||||
return [p for p in unique_paths if self.watch_filter(p)]
|
||||
return None
|
||||
|
||||
def __iter__(self) -> Iterator[Optional[List[Path]]]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> Optional[List[Path]]:
|
||||
return self.get_changes()
|
||||
|
||||
def run(self) -> None:
|
||||
self.startup()
|
||||
for changes in self:
|
||||
if changes:
|
||||
logger.warning(
|
||||
"检测到文件 %s 发生改变, 正在重载...",
|
||||
[str(c.relative_to(PROJECT_ROOT)).replace(os.sep, "/") for c in changes],
|
||||
extra={"tag": "Reloader"},
|
||||
)
|
||||
self.restart()
|
||||
|
||||
self.shutdown()
|
||||
|
||||
def signal_handler(self, *_) -> None:
|
||||
"""当接收到结束信号量时"""
|
||||
self._process.join(3)
|
||||
if self._process.is_alive():
|
||||
self._process.terminate()
|
||||
self._process.join()
|
||||
self._should_exit.set()
|
||||
|
||||
def startup(self) -> None:
|
||||
"""启动进程"""
|
||||
logger.info("目标进程正在启动", extra={"tag": "Reloader"})
|
||||
|
||||
for sig in HANDLED_SIGNALS:
|
||||
signal.signal(sig, self.signal_handler)
|
||||
|
||||
self._process = spawn.Process(target=self._target)
|
||||
self._process.start()
|
||||
logger.success("目标进程启动成功", extra={"tag": "Reloader"})
|
||||
|
||||
def restart(self) -> None:
|
||||
"""重启进程"""
|
||||
self._process.terminate()
|
||||
self._process.join(10)
|
||||
|
||||
self._process = spawn.Process(target=self._target)
|
||||
self._process.start()
|
||||
logger.info("目标进程已经重载", extra={"tag": "Reloader"})
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""关闭进程"""
|
||||
self._process.terminate()
|
||||
self._process.join(10)
|
||||
|
||||
logger.info("重载器已经关闭", extra={"tag": "Reloader"})
|
164
core/config.py
164
core/config.py
@ -1,165 +1,3 @@
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import dotenv
|
||||
from pydantic import AnyUrl, Field
|
||||
|
||||
from core.basemodel import Settings
|
||||
from utils.const import PROJECT_ROOT
|
||||
from utils.typedefs import NaturalNumber
|
||||
from gram_core.config import ApplicationConfig, config, JoinGroups
|
||||
|
||||
__all__ = ("ApplicationConfig", "config", "JoinGroups")
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
class JoinGroups(str, Enum):
|
||||
NO_ALLOW = "NO_ALLOW"
|
||||
ALLOW_AUTH_USER = "ALLOW_AUTH_USER"
|
||||
ALLOW_USER = "ALLOW_USER"
|
||||
ALLOW_ALL = "ALLOW_ALL"
|
||||
|
||||
|
||||
class DatabaseConfig(Settings):
|
||||
driver_name: str = "mysql+asyncmy"
|
||||
host: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
database: Optional[str] = None
|
||||
|
||||
class Config(Settings.Config):
|
||||
env_prefix = "db_"
|
||||
|
||||
|
||||
class RedisConfig(Settings):
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 6379
|
||||
database: int = Field(default=0, env="redis_db")
|
||||
password: Optional[str] = None
|
||||
|
||||
class Config(Settings.Config):
|
||||
env_prefix = "redis_"
|
||||
|
||||
|
||||
class LoggerConfig(Settings):
|
||||
name: str = "PaiGram"
|
||||
width: Optional[int] = None
|
||||
time_format: str = "[%Y-%m-%d %X]"
|
||||
traceback_max_frames: int = 20
|
||||
path: Path = PROJECT_ROOT / "logs"
|
||||
render_keywords: List[str] = ["BOT"]
|
||||
locals_max_length: int = 10
|
||||
locals_max_string: int = 80
|
||||
locals_max_depth: Optional[NaturalNumber] = None
|
||||
filtered_names: List[str] = ["uvicorn"]
|
||||
|
||||
class Config(Settings.Config):
|
||||
env_prefix = "logger_"
|
||||
|
||||
|
||||
class MTProtoConfig(Settings):
|
||||
api_id: Optional[int] = None
|
||||
api_hash: Optional[str] = None
|
||||
|
||||
|
||||
class WebServerConfig(Settings):
|
||||
enable: bool = False
|
||||
"""是否启用WebServer"""
|
||||
|
||||
host: str = "localhost"
|
||||
port: int = 8080
|
||||
|
||||
class Config(Settings.Config):
|
||||
env_prefix = "web_"
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
# noinspection HttpUrlsUsage
|
||||
return "http://" + self.host + ":" + str(self.port)
|
||||
|
||||
|
||||
class ErrorConfig(Settings):
|
||||
pb_url: str = ""
|
||||
pb_sunset: int = 43200
|
||||
pb_max_lines: int = 1000
|
||||
sentry_dsn: str = ""
|
||||
notification_chat_id: Optional[str] = None
|
||||
|
||||
class Config(Settings.Config):
|
||||
env_prefix = "error_"
|
||||
|
||||
|
||||
class ReloadConfig(Settings):
|
||||
delay: float = 0.25
|
||||
dirs: List[str] = []
|
||||
include: List[str] = []
|
||||
exclude: List[str] = []
|
||||
|
||||
class Config(Settings.Config):
|
||||
env_prefix = "reload_"
|
||||
|
||||
|
||||
class NoticeConfig(Settings):
|
||||
user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!"
|
||||
|
||||
class Config(Settings.Config):
|
||||
env_prefix = "notice_"
|
||||
|
||||
|
||||
class ApplicationConfig(Settings):
|
||||
debug: bool = False
|
||||
"""debug 开关"""
|
||||
retry: int = 5
|
||||
"""重试次数"""
|
||||
auto_reload: bool = False
|
||||
"""自动重载"""
|
||||
|
||||
proxy_url: Optional[AnyUrl] = None
|
||||
"""代理链接"""
|
||||
upload_bbs_host: Optional[AnyUrl] = "https://upload-bbs.miyoushe.com"
|
||||
|
||||
bot_token: str = ""
|
||||
"""BOT的token"""
|
||||
|
||||
owner: Optional[int] = None
|
||||
|
||||
channels: List[int] = []
|
||||
"""文章推送群组"""
|
||||
|
||||
verify_groups: List[Union[int, str]] = []
|
||||
"""启用群验证功能的群组"""
|
||||
join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW
|
||||
"""是否允许机器人被邀请到其它群组"""
|
||||
|
||||
timeout: int = 10
|
||||
connection_pool_size: int = 256
|
||||
read_timeout: Optional[float] = None
|
||||
write_timeout: Optional[float] = None
|
||||
connect_timeout: Optional[float] = None
|
||||
pool_timeout: Optional[float] = None
|
||||
update_read_timeout: Optional[float] = None
|
||||
update_write_timeout: Optional[float] = None
|
||||
update_connect_timeout: Optional[float] = None
|
||||
update_pool_timeout: Optional[float] = None
|
||||
|
||||
genshin_ttl: Optional[int] = None
|
||||
|
||||
enka_network_api_agent: str = ""
|
||||
pass_challenge_api: str = ""
|
||||
pass_challenge_app_key: str = ""
|
||||
pass_challenge_user_web: str = ""
|
||||
|
||||
reload: ReloadConfig = ReloadConfig()
|
||||
database: DatabaseConfig = DatabaseConfig()
|
||||
logger: LoggerConfig = LoggerConfig()
|
||||
webserver: WebServerConfig = WebServerConfig()
|
||||
redis: RedisConfig = RedisConfig()
|
||||
mtproto: MTProtoConfig = MTProtoConfig()
|
||||
error: ErrorConfig = ErrorConfig()
|
||||
notice: NoticeConfig = NoticeConfig()
|
||||
|
||||
|
||||
ApplicationConfig.update_forward_refs()
|
||||
config = ApplicationConfig()
|
||||
|
@ -1,56 +1,3 @@
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from playwright.async_api import Error, async_playwright
|
||||
|
||||
from core.base_service import BaseService
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from playwright.async_api import Playwright as AsyncPlaywright, Browser
|
||||
from gram_core.dependence.aiobrowser import AioBrowser
|
||||
|
||||
__all__ = ("AioBrowser",)
|
||||
|
||||
|
||||
class AioBrowser(BaseService.Dependence):
|
||||
@property
|
||||
def browser(self):
|
||||
return self._browser
|
||||
|
||||
def __init__(self, loop=None):
|
||||
self._browser: Optional["Browser"] = None
|
||||
self._playwright: Optional["AsyncPlaywright"] = None
|
||||
self._loop = loop
|
||||
|
||||
async def get_browser(self):
|
||||
if self._browser is None:
|
||||
await self.initialize()
|
||||
return self._browser
|
||||
|
||||
async def initialize(self):
|
||||
if self._playwright is None:
|
||||
logger.info("正在尝试启动 [blue]Playwright[/]", extra={"markup": True})
|
||||
self._playwright = await async_playwright().start()
|
||||
logger.success("[blue]Playwright[/] 启动成功", extra={"markup": True})
|
||||
if self._browser is None:
|
||||
logger.info("正在尝试启动 [blue]Browser[/]", extra={"markup": True})
|
||||
try:
|
||||
self._browser = await self._playwright.chromium.launch(timeout=5000)
|
||||
logger.success("[blue]Browser[/] 启动成功", extra={"markup": True})
|
||||
except Error as err:
|
||||
if "playwright install" in str(err):
|
||||
logger.error(
|
||||
"检查到 [blue]playwright[/] 刚刚安装或者未升级\n"
|
||||
"请运行以下命令下载新浏览器\n"
|
||||
"[blue bold]playwright install chromium[/]",
|
||||
extra={"markup": True},
|
||||
)
|
||||
raise RuntimeError("检查到 playwright 刚刚安装或者未升级\n请运行以下命令下载新浏览器\nplaywright install chromium")
|
||||
raise err
|
||||
|
||||
return self._browser
|
||||
|
||||
async def shutdown(self):
|
||||
if self._browser is not None:
|
||||
await self._browser.close()
|
||||
if self._playwright is not None:
|
||||
self._playwright.stop()
|
||||
|
@ -1,16 +0,0 @@
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
from playwright.async_api import Browser, Playwright as AsyncPlaywright
|
||||
|
||||
from core.base_service import BaseService
|
||||
|
||||
__all__ = ("AioBrowser",)
|
||||
|
||||
class AioBrowser(BaseService.Dependence):
|
||||
_browser: Browser | None
|
||||
_playwright: AsyncPlaywright | None
|
||||
_loop: AbstractEventLoop
|
||||
|
||||
@property
|
||||
def browser(self) -> Browser | None: ...
|
||||
async def get_browser(self) -> Browser: ...
|
@ -1,51 +1,3 @@
|
||||
import contextlib
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from typing_extensions import Self
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.config import ApplicationConfig
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
from gram_core.dependence.database import Database
|
||||
|
||||
__all__ = ("Database",)
|
||||
|
||||
|
||||
class Database(BaseService.Dependence):
|
||||
@classmethod
|
||||
def from_config(cls, config: ApplicationConfig) -> Self:
|
||||
return cls(**config.database.dict())
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
driver_name: str,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
database: Optional[str] = None,
|
||||
):
|
||||
self.database = database # skipcq: PTC-W0052
|
||||
self.password = password
|
||||
self.username = username
|
||||
self.port = port
|
||||
self.host = host
|
||||
self.url = URL.create(
|
||||
driver_name,
|
||||
username=self.username,
|
||||
password=self.password,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
database=self.database,
|
||||
)
|
||||
self.engine = create_async_engine(self.url)
|
||||
self.Session = sessionmaker(bind=self.engine, class_=AsyncSession)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def session(self) -> AsyncSession:
|
||||
yield self.Session()
|
||||
|
||||
async def shutdown(self):
|
||||
self.Session.close_all()
|
||||
|
@ -1,67 +1,3 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
from gram_core.dependence.mtproto import MTProto
|
||||
|
||||
import aiofiles
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.config import config as bot_config
|
||||
from utils.log import logger
|
||||
|
||||
try:
|
||||
from pyrogram import Client
|
||||
from pyrogram.session import session
|
||||
|
||||
session.log.debug = lambda *args, **kwargs: None # 关闭日记
|
||||
PYROGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
Client = None
|
||||
session = None
|
||||
PYROGRAM_AVAILABLE = False
|
||||
|
||||
|
||||
class MTProto(BaseService.Dependence):
|
||||
async def get_session(self):
|
||||
async with aiofiles.open(self.session_path, mode="r") as f:
|
||||
return await f.read()
|
||||
|
||||
async def set_session(self, b: str):
|
||||
async with aiofiles.open(self.session_path, mode="w+") as f:
|
||||
await f.write(b)
|
||||
|
||||
def session_exists(self):
|
||||
return os.path.exists(self.session_path)
|
||||
|
||||
def __init__(self):
|
||||
self.name = "paigram"
|
||||
current_dir = os.getcwd()
|
||||
self.session_path = os.path.join(current_dir, "paigram.session")
|
||||
self.client: Optional[Client] = None
|
||||
self.proxy: Optional[dict] = None
|
||||
http_proxy = os.environ.get("HTTP_PROXY")
|
||||
if http_proxy is not None:
|
||||
http_proxy_url = urlparse(http_proxy)
|
||||
self.proxy = {"scheme": "http", "hostname": http_proxy_url.hostname, "port": http_proxy_url.port}
|
||||
|
||||
async def initialize(self): # pylint: disable=W0221
|
||||
if not PYROGRAM_AVAILABLE:
|
||||
logger.info("MTProto 服务需要的 pyrogram 模块未导入 本次服务 client 为 None")
|
||||
return
|
||||
if bot_config.mtproto.api_id is None:
|
||||
logger.info("MTProto 服务需要的 api_id 未配置 本次服务 client 为 None")
|
||||
return
|
||||
if bot_config.mtproto.api_hash is None:
|
||||
logger.info("MTProto 服务需要的 api_hash 未配置 本次服务 client 为 None")
|
||||
return
|
||||
self.client = Client(
|
||||
api_id=bot_config.mtproto.api_id,
|
||||
api_hash=bot_config.mtproto.api_hash,
|
||||
name=self.name,
|
||||
bot_token=bot_config.bot_token,
|
||||
proxy=self.proxy,
|
||||
)
|
||||
await self.client.start()
|
||||
|
||||
async def shutdown(self): # pylint: disable=W0221
|
||||
if self.client is not None:
|
||||
await self.client.stop(block=False)
|
||||
__all__ = ("MTProto",)
|
||||
|
@ -1,31 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from typing import TypedDict
|
||||
|
||||
from core.base_service import BaseService
|
||||
|
||||
try:
|
||||
from pyrogram import Client
|
||||
from pyrogram.session import session
|
||||
|
||||
PYROGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
Client = None
|
||||
session = None
|
||||
PYROGRAM_AVAILABLE = False
|
||||
|
||||
__all__ = ("MTProto",)
|
||||
|
||||
class _ProxyType(TypedDict):
|
||||
scheme: str
|
||||
hostname: str | None
|
||||
port: int | None
|
||||
|
||||
class MTProto(BaseService.Dependence):
|
||||
name: str
|
||||
session_path: str
|
||||
client: Client | None
|
||||
proxy: _ProxyType | None
|
||||
|
||||
async def get_session(self) -> str: ...
|
||||
async def set_session(self, b: str) -> None: ...
|
||||
def session_exists(self) -> bool: ...
|
@ -1,50 +1,3 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import fakeredis.aioredis
|
||||
from redis import asyncio as aioredis
|
||||
from redis.exceptions import ConnectionError as RedisConnectionError, TimeoutError as RedisTimeoutError
|
||||
from typing_extensions import Self
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.config import ApplicationConfig
|
||||
from utils.log import logger
|
||||
from gram_core.dependence.redisdb import RedisDB
|
||||
|
||||
__all__ = ["RedisDB"]
|
||||
|
||||
|
||||
class RedisDB(BaseService.Dependence):
|
||||
@classmethod
|
||||
def from_config(cls, config: ApplicationConfig) -> Self:
|
||||
return cls(**config.redis.dict())
|
||||
|
||||
def __init__(
|
||||
self, host: str = "127.0.0.1", port: int = 6379, database: Union[str, int] = 0, password: Optional[str] = None
|
||||
):
|
||||
self.client = aioredis.Redis(host=host, port=port, db=database, password=password)
|
||||
self.ttl = 600
|
||||
|
||||
async def ping(self):
|
||||
# noinspection PyUnresolvedReferences
|
||||
if await self.client.ping():
|
||||
logger.info("连接 [red]Redis[/] 成功", extra={"markup": True})
|
||||
else:
|
||||
logger.info("连接 [red]Redis[/] 失败", extra={"markup": True})
|
||||
raise RuntimeError("连接 Redis 失败")
|
||||
|
||||
async def start_fake_redis(self):
|
||||
self.client = fakeredis.aioredis.FakeRedis()
|
||||
await self.ping()
|
||||
|
||||
async def initialize(self):
|
||||
logger.info("正在尝试建立与 [red]Redis[/] 连接", extra={"markup": True})
|
||||
try:
|
||||
await self.ping()
|
||||
except (RedisTimeoutError, RedisConnectionError) as exc:
|
||||
if isinstance(exc, RedisTimeoutError):
|
||||
logger.warning("连接 [red]Redis[/] 超时,使用 [red]fakeredis[/] 模拟", extra={"markup": True})
|
||||
if isinstance(exc, RedisConnectionError):
|
||||
logger.warning("连接 [red]Redis[/] 失败,使用 [red]fakeredis[/] 模拟", extra={"markup": True})
|
||||
await self.start_fake_redis()
|
||||
|
||||
async def shutdown(self):
|
||||
await self.client.close()
|
||||
|
@ -1,7 +1,4 @@
|
||||
"""此模块包含核心模块的错误的基类"""
|
||||
from typing import Union
|
||||
from gram_core.error import ServiceNotFoundError
|
||||
|
||||
|
||||
class ServiceNotFoundError(Exception):
|
||||
def __init__(self, name: Union[str, type]):
|
||||
super().__init__(f"No service named '{name if isinstance(name, str) else name.__name__}'")
|
||||
__all__ = ("ServiceNotFoundError",)
|
||||
|
@ -1,59 +1,3 @@
|
||||
import asyncio
|
||||
from typing import TypeVar, TYPE_CHECKING, Any, Optional
|
||||
from gram_core.handler.adminhandler import AdminHandler
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import ApplicationHandlerStop, BaseHandler
|
||||
|
||||
from core.error import ServiceNotFoundError
|
||||
from core.services.users.services import UserAdminService
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.application import Application
|
||||
from telegram.ext import Application as TelegramApplication
|
||||
|
||||
RT = TypeVar("RT")
|
||||
UT = TypeVar("UT")
|
||||
|
||||
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
||||
|
||||
|
||||
class AdminHandler(BaseHandler[Update, CCT]):
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, handler: BaseHandler[Update, CCT], application: "Application") -> None:
|
||||
self.handler = handler
|
||||
self.application = application
|
||||
self.user_service: Optional["UserAdminService"] = None
|
||||
super().__init__(self.handler.callback, self.handler.block)
|
||||
|
||||
def check_update(self, update: object) -> bool:
|
||||
if not isinstance(update, Update):
|
||||
return False
|
||||
return self.handler.check_update(update)
|
||||
|
||||
async def _user_service(self) -> "UserAdminService":
|
||||
async with self._lock:
|
||||
if self.user_service is not None:
|
||||
return self.user_service
|
||||
user_service: UserAdminService = self.application.managers.services_map.get(UserAdminService, None)
|
||||
if user_service is None:
|
||||
raise ServiceNotFoundError("UserAdminService")
|
||||
self.user_service = user_service
|
||||
return self.user_service
|
||||
|
||||
async def handle_update(
|
||||
self,
|
||||
update: "UT",
|
||||
application: "TelegramApplication[Any, CCT, Any, Any, Any, Any]",
|
||||
check_result: Any,
|
||||
context: "CCT",
|
||||
) -> RT:
|
||||
user_service = await self._user_service()
|
||||
user = update.effective_user
|
||||
if await user_service.is_admin(user.id):
|
||||
return await self.handler.handle_update(update, application, check_result, context)
|
||||
message = update.effective_message
|
||||
logger.warning("用户 %s[%s] 触发尝试调用Admin命令但权限不足", user.full_name, user.id)
|
||||
await message.reply_text("权限不足")
|
||||
raise ApplicationHandlerStop
|
||||
__all__ = ("AdminHandler",)
|
||||
|
@ -1,62 +1,3 @@
|
||||
import asyncio
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from types import TracebackType
|
||||
from typing import TypeVar, TYPE_CHECKING, Any, Optional, Type
|
||||
from gram_core.handler.callbackqueryhandler import CallbackQueryHandler, OverlappingException, OverlappingContext
|
||||
|
||||
from telegram.ext import CallbackQueryHandler as BaseCallbackQueryHandler, ApplicationHandlerStop
|
||||
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from telegram.ext import Application
|
||||
|
||||
RT = TypeVar("RT")
|
||||
UT = TypeVar("UT")
|
||||
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
||||
|
||||
|
||||
class OverlappingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class OverlappingContext(AbstractAsyncContextManager):
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, context: "CCT"):
|
||||
self.context = context
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
async with self._lock:
|
||||
flag = self.context.user_data.get("overlapping", False)
|
||||
if flag:
|
||||
raise OverlappingException
|
||||
self.context.user_data["overlapping"] = True
|
||||
return None
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc: Optional[BaseException],
|
||||
tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
async with self._lock:
|
||||
del self.context.user_data["overlapping"]
|
||||
return None
|
||||
|
||||
|
||||
class CallbackQueryHandler(BaseCallbackQueryHandler):
|
||||
async def handle_update(
|
||||
self,
|
||||
update: "UT",
|
||||
application: "Application[Any, CCT, Any, Any, Any, Any]",
|
||||
check_result: Any,
|
||||
context: "CCT",
|
||||
) -> RT:
|
||||
self.collect_additional_context(context, update, application, check_result)
|
||||
try:
|
||||
async with OverlappingContext(context):
|
||||
return await self.callback(update, context)
|
||||
except OverlappingException as exc:
|
||||
user = update.effective_user
|
||||
logger.warning("用户 %s[%s] 触发 overlapping 该次命令已忽略", user.full_name, user.id)
|
||||
raise ApplicationHandlerStop from exc
|
||||
__all__ = ("CallbackQueryHandler", "OverlappingException", "OverlappingContext")
|
||||
|
@ -1,71 +1,3 @@
|
||||
import asyncio
|
||||
from typing import TypeVar, Optional
|
||||
from gram_core.handler.limiterhandler import LimiterHandler
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import ContextTypes, ApplicationHandlerStop, TypeHandler
|
||||
|
||||
from utils.log import logger
|
||||
|
||||
UT = TypeVar("UT")
|
||||
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
||||
|
||||
|
||||
class LimiterHandler(TypeHandler[UT, CCT]):
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __init__(
|
||||
self, max_rate: float = 5, time_period: float = 10, amount: float = 1, limit_time: Optional[float] = None
|
||||
):
|
||||
"""Limiter Handler 通过
|
||||
`Leaky bucket algorithm <https://en.wikipedia.org/wiki/Leaky_bucket>`_
|
||||
实现对用户的输入的精确控制
|
||||
|
||||
输入超过一定速率后,代码会抛出
|
||||
:class:`telegram.ext.ApplicationHandlerStop`
|
||||
异常并在一段时间内防止用户执行任何其他操作
|
||||
|
||||
:param max_rate: 在抛出异常之前最多允许 频率/秒 的速度
|
||||
:param time_period: 在限制速率的时间段的持续时间
|
||||
:param amount: 提供的容量
|
||||
:param limit_time: 限制时间 如果不提供限制时间为 max_rate / time_period * amount
|
||||
"""
|
||||
self.max_rate = max_rate
|
||||
self.amount = amount
|
||||
self._rate_per_sec = max_rate / time_period
|
||||
self.limit_time = limit_time
|
||||
super().__init__(Update, self.limiter_callback)
|
||||
|
||||
async def limiter_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
if update.inline_query is not None:
|
||||
return
|
||||
loop = asyncio.get_running_loop()
|
||||
async with self._lock:
|
||||
time = loop.time()
|
||||
user_data = context.user_data
|
||||
if user_data is None:
|
||||
return
|
||||
user_limit_time = user_data.get("limit_time")
|
||||
if user_limit_time is not None:
|
||||
if time >= user_limit_time:
|
||||
del user_data["limit_time"]
|
||||
else:
|
||||
raise ApplicationHandlerStop
|
||||
last_task_time = user_data.get("last_task_time", 0)
|
||||
if last_task_time:
|
||||
task_level = user_data.get("task_level", 0)
|
||||
elapsed = time - last_task_time
|
||||
decrement = elapsed * self._rate_per_sec
|
||||
task_level = max(task_level - decrement, 0)
|
||||
user_data["task_level"] = task_level
|
||||
if not task_level + self.amount <= self.max_rate:
|
||||
if self.limit_time:
|
||||
limit_time = self.limit_time
|
||||
else:
|
||||
limit_time = 1 / self._rate_per_sec * self.amount
|
||||
user_data["limit_time"] = time + limit_time
|
||||
user = update.effective_user
|
||||
logger.warning("用户 %s[%s] 触发洪水限制 已被限制 %s 秒", user.full_name, user.id, limit_time)
|
||||
raise ApplicationHandlerStop
|
||||
user_data["last_task_time"] = time
|
||||
task_level = user_data.get("task_level", 0)
|
||||
user_data["task_level"] = task_level + self.amount
|
||||
__all__ = ("LimiterHandler",)
|
||||
|
286
core/manager.py
286
core/manager.py
@ -1,286 +0,0 @@
|
||||
import asyncio
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
from arkowrapper import ArkoWrapper
|
||||
from async_timeout import timeout
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services
|
||||
from core.config import config as bot_config
|
||||
from utils.const import PLUGIN_DIR, PROJECT_ROOT
|
||||
from utils.helpers import gen_pkg
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.application import Application
|
||||
from core.plugin import PluginType
|
||||
from core.builtins.executor import Executor
|
||||
|
||||
__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers")
|
||||
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def _load_module(path: Path) -> None:
|
||||
for pkg in gen_pkg(path):
|
||||
try:
|
||||
logger.debug('正在导入 "%s"', pkg)
|
||||
import_module(pkg)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'在导入 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
|
||||
)
|
||||
raise SystemExit from e
|
||||
|
||||
|
||||
class Manager(Generic[T]):
|
||||
"""生命周期控制基类"""
|
||||
|
||||
_executor: Optional["Executor"] = None
|
||||
_lib: Dict[Type[T], T] = {}
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
||||
return self._application
|
||||
|
||||
@property
|
||||
def executor(self) -> "Executor":
|
||||
"""执行器"""
|
||||
if self._executor is None:
|
||||
raise RuntimeError(f"No executor was set for this {self.__class__.__name__}.")
|
||||
return self._executor
|
||||
|
||||
def build_executor(self, name: str):
|
||||
from core.builtins.executor import Executor
|
||||
from core.builtins.dispatcher import BaseDispatcher
|
||||
|
||||
self._executor = Executor(name, dispatcher=BaseDispatcher)
|
||||
self._executor.set_application(self.application)
|
||||
|
||||
|
||||
class DependenceManager(Manager[DependenceType]):
|
||||
"""基础依赖管理"""
|
||||
|
||||
_dependency: Dict[Type[DependenceType], DependenceType] = {}
|
||||
|
||||
@property
|
||||
def dependency(self) -> List[DependenceType]:
|
||||
return list(self._dependency.values())
|
||||
|
||||
@property
|
||||
def dependency_map(self) -> Dict[Type[DependenceType], DependenceType]:
|
||||
return self._dependency
|
||||
|
||||
async def start_dependency(self) -> None:
|
||||
_load_module(PROJECT_ROOT / "core/dependence")
|
||||
|
||||
for dependence in filter(lambda x: x.is_dependence, get_all_services()):
|
||||
dependence: Type[DependenceType]
|
||||
instance: DependenceType
|
||||
try:
|
||||
if hasattr(dependence, "from_config"): # 如果有 from_config 方法
|
||||
instance = dependence.from_config(bot_config) # 用 from_config 实例化服务
|
||||
else:
|
||||
instance = await self.executor(dependence)
|
||||
|
||||
await instance.initialize()
|
||||
logger.success('基础服务 "%s" 启动成功', dependence.__name__)
|
||||
|
||||
self._lib[dependence] = instance
|
||||
self._dependency[dependence] = instance
|
||||
|
||||
except Exception as e:
|
||||
logger.exception('基础服务 "%s" 初始化失败,BOT 将自动关闭', dependence.__name__)
|
||||
raise SystemExit from e
|
||||
|
||||
async def stop_dependency(self) -> None:
|
||||
async def task(d):
|
||||
try:
|
||||
async with timeout(5):
|
||||
await d.shutdown()
|
||||
logger.debug('基础服务 "%s" 关闭成功', d.__class__.__name__)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning('基础服务 "%s" 关闭超时', d.__class__.__name__)
|
||||
except Exception as e:
|
||||
logger.error('基础服务 "%s" 关闭错误', d.__class__.__name__, exc_info=e)
|
||||
|
||||
tasks = []
|
||||
for dependence in self._dependency.values():
|
||||
tasks.append(asyncio.create_task(task(dependence)))
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
class ComponentManager(Manager[ComponentType]):
|
||||
"""组件管理"""
|
||||
|
||||
_components: Dict[Type[ComponentType], ComponentType] = {}
|
||||
|
||||
@property
|
||||
def components(self) -> List[ComponentType]:
|
||||
return list(self._components.values())
|
||||
|
||||
@property
|
||||
def components_map(self) -> Dict[Type[ComponentType], ComponentType]:
|
||||
return self._components
|
||||
|
||||
async def init_components(self):
|
||||
for path in filter(
|
||||
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
|
||||
):
|
||||
_load_module(path)
|
||||
components = ArkoWrapper(get_all_services()).filter(lambda x: x.is_component)
|
||||
retry_times = 0
|
||||
max_retry_times = len(components)
|
||||
while components:
|
||||
start_len = len(components)
|
||||
for component in list(components):
|
||||
component: Type[ComponentType]
|
||||
instance: ComponentType
|
||||
try:
|
||||
instance = await self.executor(component)
|
||||
self._lib[component] = instance
|
||||
self._components[component] = instance
|
||||
components = components.remove(component)
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.debug('组件 "%s" 初始化失败: [red]%s[/]', component.__name__, e, extra={"markup": True})
|
||||
end_len = len(list(components))
|
||||
if start_len == end_len:
|
||||
retry_times += 1
|
||||
|
||||
if retry_times == max_retry_times and components:
|
||||
for component in components:
|
||||
logger.error('组件 "%s" 初始化失败', component.__name__)
|
||||
raise SystemExit
|
||||
|
||||
|
||||
class ServiceManager(Manager[BaseServiceType]):
|
||||
"""服务控制类"""
|
||||
|
||||
_services: Dict[Type[BaseServiceType], BaseServiceType] = {}
|
||||
|
||||
@property
|
||||
def services(self) -> List[BaseServiceType]:
|
||||
return list(self._services.values())
|
||||
|
||||
@property
|
||||
def services_map(self) -> Dict[Type[BaseServiceType], BaseServiceType]:
|
||||
return self._services
|
||||
|
||||
async def _initialize_service(self, target: Type[BaseServiceType]) -> BaseServiceType:
|
||||
instance: BaseServiceType
|
||||
try:
|
||||
if hasattr(target, "from_config"): # 如果有 from_config 方法
|
||||
instance = target.from_config(bot_config) # 用 from_config 实例化服务
|
||||
else:
|
||||
instance = await self.executor(target)
|
||||
|
||||
await instance.initialize()
|
||||
logger.success('服务 "%s" 启动成功', target.__name__)
|
||||
|
||||
return instance
|
||||
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.exception('服务 "%s" 初始化失败,BOT 将自动关闭', target.__name__)
|
||||
raise SystemExit from e
|
||||
|
||||
async def start_services(self) -> None:
|
||||
for path in filter(
|
||||
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
|
||||
):
|
||||
_load_module(path)
|
||||
|
||||
for service in filter(lambda x: not x.is_component and not x.is_dependence, get_all_services()): # 遍历所有服务类
|
||||
instance = await self._initialize_service(service)
|
||||
|
||||
self._lib[service] = instance
|
||||
self._services[service] = instance
|
||||
|
||||
async def stop_services(self) -> None:
|
||||
"""关闭服务"""
|
||||
if not self._services:
|
||||
return
|
||||
|
||||
async def task(s):
|
||||
try:
|
||||
async with timeout(5):
|
||||
await s.shutdown()
|
||||
logger.success('服务 "%s" 关闭成功', s.__class__.__name__)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning('服务 "%s" 关闭超时', s.__class__.__name__)
|
||||
except Exception as e:
|
||||
logger.warning('服务 "%s" 关闭失败', s.__class__.__name__, exc_info=e)
|
||||
|
||||
logger.info("正在关闭服务")
|
||||
tasks = []
|
||||
for service in self._services.values():
|
||||
tasks.append(asyncio.create_task(task(service)))
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
class PluginManager(Manager["PluginType"]):
|
||||
"""插件管理"""
|
||||
|
||||
_plugins: Dict[Type["PluginType"], "PluginType"] = {}
|
||||
|
||||
@property
|
||||
def plugins(self) -> List["PluginType"]:
|
||||
"""所有已经加载的插件"""
|
||||
return list(self._plugins.values())
|
||||
|
||||
@property
|
||||
def plugins_map(self) -> Dict[Type["PluginType"], "PluginType"]:
|
||||
return self._plugins
|
||||
|
||||
async def install_plugins(self) -> None:
|
||||
"""安装所有插件"""
|
||||
from core.plugin import get_all_plugins
|
||||
|
||||
for path in filter(lambda x: x.is_dir(), PLUGIN_DIR.iterdir()):
|
||||
_load_module(path)
|
||||
|
||||
for plugin in get_all_plugins():
|
||||
plugin: Type["PluginType"]
|
||||
|
||||
try:
|
||||
instance: "PluginType" = await self.executor(plugin)
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.error('插件 "%s" 初始化失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
||||
continue
|
||||
|
||||
self._plugins[plugin] = instance
|
||||
|
||||
if self._application is not None:
|
||||
instance.set_application(self._application)
|
||||
|
||||
await asyncio.create_task(self.plugin_install_task(plugin, instance))
|
||||
|
||||
@staticmethod
|
||||
async def plugin_install_task(plugin: Type["PluginType"], instance: "PluginType"):
|
||||
try:
|
||||
await instance.install()
|
||||
logger.success('插件 "%s" 安装成功', f"{plugin.__module__}.{plugin.__name__}")
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.error('插件 "%s" 安装失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
||||
|
||||
async def uninstall_plugins(self) -> None:
|
||||
for plugin in self._plugins.values():
|
||||
try:
|
||||
await plugin.uninstall()
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.error('插件 "%s" 卸载失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
||||
|
||||
|
||||
class Managers(DependenceManager, ComponentManager, ServiceManager, PluginManager):
|
||||
"""BOT 除自身外的生命周期管理类"""
|
@ -1,117 +0,0 @@
|
||||
"""重写 telegram.request.HTTPXRequest 使其使用 ujson 库进行 json 序列化"""
|
||||
from typing import Any, AsyncIterable, Optional
|
||||
|
||||
import httpcore
|
||||
from httpx import (
|
||||
AsyncByteStream,
|
||||
AsyncHTTPTransport as DefaultAsyncHTTPTransport,
|
||||
Limits,
|
||||
Response as DefaultResponse,
|
||||
Timeout,
|
||||
)
|
||||
from telegram.request import HTTPXRequest as DefaultHTTPXRequest
|
||||
|
||||
try:
|
||||
import ujson as jsonlib
|
||||
except ImportError:
|
||||
import json as jsonlib
|
||||
|
||||
__all__ = ("HTTPXRequest",)
|
||||
|
||||
|
||||
class Response(DefaultResponse):
|
||||
def json(self, **kwargs: Any) -> Any:
|
||||
# noinspection PyProtectedMember
|
||||
from httpx._utils import guess_json_utf
|
||||
|
||||
if self.charset_encoding is None and self.content and len(self.content) > 3:
|
||||
encoding = guess_json_utf(self.content)
|
||||
if encoding is not None:
|
||||
return jsonlib.loads(self.content.decode(encoding), **kwargs)
|
||||
return jsonlib.loads(self.text, **kwargs)
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
class AsyncHTTPTransport(DefaultAsyncHTTPTransport):
|
||||
async def handle_async_request(self, request) -> Response:
|
||||
from httpx._transports.default import (
|
||||
map_httpcore_exceptions,
|
||||
AsyncResponseStream,
|
||||
)
|
||||
|
||||
if not isinstance(request.stream, AsyncByteStream):
|
||||
raise AssertionError
|
||||
|
||||
req = httpcore.Request(
|
||||
method=request.method,
|
||||
url=httpcore.URL(
|
||||
scheme=request.url.raw_scheme,
|
||||
host=request.url.raw_host,
|
||||
port=request.url.port,
|
||||
target=request.url.raw_path,
|
||||
),
|
||||
headers=request.headers.raw,
|
||||
content=request.stream,
|
||||
extensions=request.extensions,
|
||||
)
|
||||
with map_httpcore_exceptions():
|
||||
resp = await self._pool.handle_async_request(req)
|
||||
|
||||
if not isinstance(resp.stream, AsyncIterable):
|
||||
raise AssertionError
|
||||
|
||||
return Response(
|
||||
status_code=resp.status,
|
||||
headers=resp.headers,
|
||||
stream=AsyncResponseStream(resp.stream),
|
||||
extensions=resp.extensions,
|
||||
)
|
||||
|
||||
|
||||
class HTTPXRequest(DefaultHTTPXRequest):
|
||||
def __init__( # pylint: disable=W0231
|
||||
self,
|
||||
connection_pool_size: int = 1,
|
||||
proxy_url: str = None,
|
||||
read_timeout: Optional[float] = 5.0,
|
||||
write_timeout: Optional[float] = 5.0,
|
||||
connect_timeout: Optional[float] = 5.0,
|
||||
pool_timeout: Optional[float] = 1.0,
|
||||
http_version: str = "1.1",
|
||||
):
|
||||
self._http_version = http_version
|
||||
timeout = Timeout(
|
||||
connect=connect_timeout,
|
||||
read=read_timeout,
|
||||
write=write_timeout,
|
||||
pool=pool_timeout,
|
||||
)
|
||||
limits = Limits(
|
||||
max_connections=connection_pool_size,
|
||||
max_keepalive_connections=connection_pool_size,
|
||||
)
|
||||
if http_version not in ("1.1", "2"):
|
||||
raise ValueError("`http_version` must be either '1.1' or '2'.")
|
||||
http1 = http_version == "1.1"
|
||||
self._client_kwargs = dict(
|
||||
timeout=timeout,
|
||||
proxies=proxy_url,
|
||||
limits=limits,
|
||||
transport=AsyncHTTPTransport(limits=limits),
|
||||
http1=http1,
|
||||
http2=not http1,
|
||||
)
|
||||
|
||||
try:
|
||||
self._client = self._build_client()
|
||||
except ImportError as exc:
|
||||
if "httpx[http2]" not in str(exc) and "httpx[socks]" not in str(exc):
|
||||
raise exc
|
||||
|
||||
if "httpx[socks]" in str(exc):
|
||||
raise RuntimeError(
|
||||
"To use Socks5 proxies, PTB must be installed via `pip install " "python-telegram-bot[socks]`."
|
||||
) from exc
|
||||
raise RuntimeError(
|
||||
"To use HTTP/2, PTB must be installed via `pip install " "python-telegram-bot[http2]`."
|
||||
) from exc
|
@ -1,8 +1,8 @@
|
||||
"""插件"""
|
||||
|
||||
from core.plugin._handler import conversation, error_handler, handler
|
||||
from core.plugin._job import TimeType, job
|
||||
from core.plugin._plugin import Plugin, PluginType, get_all_plugins
|
||||
from gram_core.plugin._handler import conversation, error_handler, handler
|
||||
from gram_core.plugin._job import TimeType, job
|
||||
from gram_core.plugin._plugin import Plugin, PluginType, get_all_plugins
|
||||
|
||||
__all__ = (
|
||||
"Plugin",
|
||||
|
@ -1,178 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union, TYPE_CHECKING
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
from httpx import UnsupportedProtocol
|
||||
from telegram import Chat, Message, ReplyKeyboardRemove, Update
|
||||
from telegram.error import Forbidden, NetworkError
|
||||
from telegram.ext import CallbackContext, ConversationHandler, Job
|
||||
|
||||
from core.dependence.redisdb import RedisDB
|
||||
from core.plugin._handler import conversation, handler
|
||||
from utils.const import CACHE_DIR, REQUEST_HEADERS
|
||||
from utils.error import UrlResourcesNotFoundError
|
||||
from utils.helpers import sha1
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.application import Application
|
||||
|
||||
try:
|
||||
import ujson as json
|
||||
except ImportError:
|
||||
import json
|
||||
|
||||
__all__ = (
|
||||
"PluginFuncs",
|
||||
"ConversationFuncs",
|
||||
)
|
||||
|
||||
|
||||
class PluginFuncs:
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError("No application was set for this PluginManager.")
|
||||
return self._application
|
||||
|
||||
async def _delete_message(self, context: CallbackContext) -> None:
|
||||
job = context.job
|
||||
message_id = job.data
|
||||
chat_info = f"chat_id[{job.chat_id}]"
|
||||
|
||||
try:
|
||||
chat = await self.get_chat(job.chat_id)
|
||||
full_name = chat.full_name
|
||||
if full_name:
|
||||
chat_info = f"{full_name}[{chat.id}]"
|
||||
else:
|
||||
chat_info = f"{chat.title}[{chat.id}]"
|
||||
except (NetworkError, Forbidden) as exc:
|
||||
logger.warning("获取 chat info 失败 %s", exc.message)
|
||||
except Exception as exc:
|
||||
logger.warning("获取 chat info 消息失败 %s", str(exc))
|
||||
|
||||
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
|
||||
|
||||
try:
|
||||
# noinspection PyTypeChecker
|
||||
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
|
||||
except NetworkError as exc:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||
except Forbidden as exc:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||
except Exception as exc:
|
||||
logger.error("删除消息 %s message_id[%s] 失败", chat_info, message_id, exc_info=exc)
|
||||
|
||||
async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, expire: int = 86400) -> Chat:
|
||||
application = self.application
|
||||
redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None)
|
||||
|
||||
if not redis_db:
|
||||
return await application.bot.get_chat(chat_id)
|
||||
|
||||
qname = f"bot:chat:{chat_id}"
|
||||
|
||||
data = await redis_db.client.get(qname)
|
||||
if data:
|
||||
json_data = json.loads(data)
|
||||
return Chat.de_json(json_data, application.telegram.bot)
|
||||
|
||||
chat_info = await application.telegram.bot.get_chat(chat_id)
|
||||
await redis_db.client.set(qname, chat_info.to_json(), ex=expire)
|
||||
return chat_info
|
||||
|
||||
def add_delete_message_job(
|
||||
self,
|
||||
message: Optional[Union[int, Message]] = None,
|
||||
*,
|
||||
delay: int = 60,
|
||||
name: Optional[str] = None,
|
||||
chat: Optional[Union[int, Chat]] = None,
|
||||
context: Optional[CallbackContext] = None,
|
||||
) -> Job:
|
||||
"""延迟删除消息"""
|
||||
|
||||
if isinstance(message, Message):
|
||||
if chat is None:
|
||||
chat = message.chat_id
|
||||
message = message.id
|
||||
|
||||
chat = chat.id if isinstance(chat, Chat) else chat
|
||||
|
||||
job_queue = self.application.job_queue or context.job_queue
|
||||
|
||||
if job_queue is None or chat is None:
|
||||
raise RuntimeError
|
||||
|
||||
return job_queue.run_once(
|
||||
callback=self._delete_message,
|
||||
when=delay,
|
||||
data=message,
|
||||
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
|
||||
chat_id=chat,
|
||||
job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def download_resource(url: str, return_path: bool = False) -> str:
|
||||
url_sha1 = sha1(url) # url 的 hash 值
|
||||
pathed_url = Path(url)
|
||||
|
||||
file_name = url_sha1 + pathed_url.suffix
|
||||
file_path = CACHE_DIR.joinpath(file_name)
|
||||
|
||||
if not file_path.exists(): # 若文件不存在,则下载
|
||||
async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client:
|
||||
try:
|
||||
response = await client.get(url)
|
||||
except UnsupportedProtocol:
|
||||
logger.error("链接不支持 url[%s]", url)
|
||||
return ""
|
||||
|
||||
if response.is_error:
|
||||
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
|
||||
raise UrlResourcesNotFoundError(url)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code)
|
||||
raise UrlResourcesNotFoundError(url)
|
||||
|
||||
async with aiofiles.open(file_path, mode="wb") as f:
|
||||
await f.write(response.content)
|
||||
|
||||
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
|
||||
|
||||
return file_path if return_path else Path(file_path).as_uri()
|
||||
|
||||
@staticmethod
|
||||
def get_args(context: CallbackContext) -> List[str]:
|
||||
args = context.args
|
||||
match = context.match
|
||||
|
||||
if args is None:
|
||||
if match is not None and (command := match.groups()[0]):
|
||||
temp = []
|
||||
command_parts = command.split(" ")
|
||||
for command_part in command_parts:
|
||||
if command_part:
|
||||
temp.append(command_part)
|
||||
return temp
|
||||
return []
|
||||
if len(args) >= 1:
|
||||
return args
|
||||
return []
|
||||
|
||||
|
||||
class ConversationFuncs:
|
||||
@conversation.fallback
|
||||
@handler.command(command="cancel", block=False)
|
||||
async def cancel(self, update: Update, _) -> int:
|
||||
await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove())
|
||||
return ConversationHandler.END
|
@ -1,380 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from importlib import import_module
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Pattern,
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from telegram._utils.defaultvalue import DEFAULT_TRUE
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from telegram._utils.types import DVInput
|
||||
from telegram.ext import BaseHandler
|
||||
from telegram.ext.filters import BaseFilter
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from core.handler.callbackqueryhandler import CallbackQueryHandler
|
||||
from utils.const import WRAPPER_ASSIGNMENTS as _WRAPPER_ASSIGNMENTS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.builtins.dispatcher import AbstractDispatcher
|
||||
|
||||
__all__ = (
|
||||
"handler",
|
||||
"conversation",
|
||||
"ConversationDataType",
|
||||
"ConversationData",
|
||||
"HandlerData",
|
||||
"ErrorHandlerData",
|
||||
"error_handler",
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
UT = TypeVar("UT")
|
||||
|
||||
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
|
||||
HandlerCls = Type[HandlerType]
|
||||
|
||||
Module = import_module("telegram.ext")
|
||||
|
||||
HANDLER_DATA_ATTR_NAME = "_handler_datas"
|
||||
"""用于储存生成 handler 时所需要的参数(例如 block)的属性名"""
|
||||
|
||||
ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
|
||||
|
||||
CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
|
||||
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名"""
|
||||
|
||||
WRAPPER_ASSIGNMENTS = list(
|
||||
set(
|
||||
_WRAPPER_ASSIGNMENTS
|
||||
+ [
|
||||
HANDLER_DATA_ATTR_NAME,
|
||||
ERROR_HANDLER_ATTR_NAME,
|
||||
CONVERSATION_HANDLER_ATTR_NAME,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dataclass(init=True)
|
||||
class HandlerData:
|
||||
type: Type[HandlerType]
|
||||
admin: bool
|
||||
kwargs: Dict[str, Any]
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None
|
||||
|
||||
|
||||
class _Handler:
|
||||
_type: Type["HandlerType"]
|
||||
|
||||
kwargs: Dict[str, Any] = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
"""用于获取 python-telegram-bot 中对应的 handler class"""
|
||||
|
||||
handler_name = f"{cls.__name__.strip('_')}Handler"
|
||||
|
||||
if handler_name == "CallbackQueryHandler":
|
||||
cls._type = CallbackQueryHandler
|
||||
return
|
||||
|
||||
cls._type = getattr(Module, handler_name, None)
|
||||
|
||||
def __init__(self, admin: bool = False, dispatcher: Optional[Type["AbstractDispatcher"]] = None, **kwargs) -> None:
|
||||
self.dispatcher = dispatcher
|
||||
self.admin = admin
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self, func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""decorator实现,从 func 生成 Handler"""
|
||||
|
||||
handler_datas = getattr(func, HANDLER_DATA_ATTR_NAME, [])
|
||||
handler_datas.append(
|
||||
HandlerData(type=self._type, admin=self.admin, kwargs=self.kwargs, dispatcher=self.dispatcher)
|
||||
)
|
||||
setattr(func, HANDLER_DATA_ATTR_NAME, handler_datas)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
class _CallbackQuery(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
admin: bool = False,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_CallbackQuery, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _ChatJoinRequest(_Handler):
|
||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||
super(_ChatJoinRequest, self).__init__(block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _ChatMember(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
chat_member_types: int = -1,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(chat_member_types=chat_member_types, block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _ChosenInlineResult(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
*,
|
||||
pattern: Union[str, Pattern] = None,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(block=block, pattern=pattern, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _Command(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
command: Union[str, List[str]],
|
||||
filters: "BaseFilter" = None,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
admin: bool = False,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_Command, self).__init__(
|
||||
command=command, filters=filters, block=block, admin=admin, dispatcher=dispatcher
|
||||
)
|
||||
|
||||
|
||||
class _InlineQuery(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
pattern: Union[str, Pattern] = None,
|
||||
chat_types: List[str] = None,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_InlineQuery, self).__init__(pattern=pattern, block=block, chat_types=chat_types, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _Message(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
filters: BaseFilter,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
admin: bool = False,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
) -> None:
|
||||
super(_Message, self).__init__(filters=filters, block=block, admin=admin, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _PollAnswer(_Handler):
|
||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||
super(_PollAnswer, self).__init__(block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _Poll(_Handler):
|
||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||
super(_Poll, self).__init__(block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _PreCheckoutQuery(_Handler):
|
||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||
super(_PreCheckoutQuery, self).__init__(block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _Prefix(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
command: str,
|
||||
filters: BaseFilter = None,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_Prefix, self).__init__(
|
||||
prefix=prefix, command=command, filters=filters, block=block, dispatcher=dispatcher
|
||||
)
|
||||
|
||||
|
||||
class _ShippingQuery(_Handler):
|
||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||
super(_ShippingQuery, self).__init__(block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _StringCommand(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
command: str,
|
||||
*,
|
||||
admin: bool = False,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_StringCommand, self).__init__(command=command, block=block, admin=admin, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _StringRegex(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
pattern: Union[str, Pattern],
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
admin: bool = False,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_StringRegex, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _Type(_Handler):
|
||||
# noinspection PyShadowingBuiltins
|
||||
def __init__(
|
||||
self,
|
||||
type: Type[UT], # pylint: disable=W0622
|
||||
strict: bool = False,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
): # pylint: disable=redefined-builtin
|
||||
super(_Type, self).__init__(type=type, strict=strict, block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class handler(_Handler):
|
||||
callback_query = _CallbackQuery
|
||||
chat_join_request = _ChatJoinRequest
|
||||
chat_member = _ChatMember
|
||||
chosen_inline_result = _ChosenInlineResult
|
||||
command = _Command
|
||||
inline_query = _InlineQuery
|
||||
message = _Message
|
||||
poll_answer = _PollAnswer
|
||||
pool = _Poll
|
||||
pre_checkout_query = _PreCheckoutQuery
|
||||
prefix = _Prefix
|
||||
shipping_query = _ShippingQuery
|
||||
string_command = _StringCommand
|
||||
string_regex = _StringRegex
|
||||
type = _Type
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler_type: Union[Callable[P, "HandlerType"], Type["HandlerType"]],
|
||||
*,
|
||||
admin: bool = False,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
self._type = handler_type
|
||||
super().__init__(admin=admin, dispatcher=dispatcher, **kwargs)
|
||||
|
||||
|
||||
class ConversationDataType(Enum):
|
||||
"""conversation handler 的类型"""
|
||||
|
||||
Entry = "entry"
|
||||
State = "state"
|
||||
Fallback = "fallback"
|
||||
|
||||
|
||||
class ConversationData(BaseModel):
|
||||
"""用于储存 conversation handler 的数据"""
|
||||
|
||||
type: ConversationDataType
|
||||
state: Optional[Any] = None
|
||||
|
||||
|
||||
class _ConversationType:
|
||||
_type: ClassVar[ConversationDataType]
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
cls._type = ConversationDataType(cls.__name__.lstrip("_").lower())
|
||||
|
||||
|
||||
def _entry(func: Callable[P, R]) -> Callable[P, R]:
|
||||
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Entry))
|
||||
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
class _State(_ConversationType):
|
||||
def __init__(self, state: Any) -> None:
|
||||
self.state = state
|
||||
|
||||
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
|
||||
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=self._type, state=self.state))
|
||||
return func
|
||||
|
||||
|
||||
def _fallback(func: Callable[P, R]) -> Callable[P, R]:
|
||||
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Fallback))
|
||||
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class conversation(_Handler):
|
||||
entry_point = _entry
|
||||
state = _State
|
||||
fallback = _fallback
|
||||
|
||||
|
||||
@dataclass(init=True)
|
||||
class ErrorHandlerData:
|
||||
block: bool
|
||||
func: Optional[Callable] = None
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class error_handler:
|
||||
_func: Callable[P, R]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
block: bool = DEFAULT_TRUE,
|
||||
):
|
||||
self._block = block
|
||||
|
||||
def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
|
||||
self._func = func
|
||||
wraps(func, assigned=WRAPPER_ASSIGNMENTS)(self)
|
||||
|
||||
handler_datas = getattr(func, ERROR_HANDLER_ATTR_NAME, [])
|
||||
handler_datas.append(ErrorHandlerData(block=self._block))
|
||||
setattr(self._func, ERROR_HANDLER_ATTR_NAME, handler_datas)
|
||||
|
||||
return self._func
|
@ -1,173 +0,0 @@
|
||||
"""插件"""
|
||||
import datetime
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from telegram._utils.types import JSONDict
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from telegram.ext._utils.types import JobCallback
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.builtins.dispatcher import AbstractDispatcher
|
||||
|
||||
__all__ = ["TimeType", "job", "JobData"]
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
|
||||
TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time]
|
||||
|
||||
_JOB_ATTR_NAME = "_job_data"
|
||||
|
||||
|
||||
@dataclass(init=True)
|
||||
class JobData:
|
||||
name: str
|
||||
data: Any
|
||||
chat_id: int
|
||||
user_id: int
|
||||
type: str
|
||||
job_kwargs: JSONDict = field(default_factory=dict)
|
||||
kwargs: JSONDict = field(default_factory=dict)
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None
|
||||
|
||||
|
||||
class _Job:
|
||||
kwargs: Dict = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = None,
|
||||
data: object = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.name = name
|
||||
self.data = data
|
||||
self.chat_id = chat_id
|
||||
self.user_id = user_id
|
||||
self.job_kwargs = {} if job_kwargs is None else job_kwargs
|
||||
self.kwargs = kwargs
|
||||
if dispatcher is None:
|
||||
from core.builtins.dispatcher import JobDispatcher
|
||||
|
||||
dispatcher = JobDispatcher
|
||||
|
||||
self.dispatcher = dispatcher
|
||||
|
||||
def __call__(self, func: JobCallback) -> JobCallback:
|
||||
data = JobData(
|
||||
name=self.name,
|
||||
data=self.data,
|
||||
chat_id=self.chat_id,
|
||||
user_id=self.user_id,
|
||||
job_kwargs=self.job_kwargs,
|
||||
kwargs=self.kwargs,
|
||||
type=re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"),
|
||||
dispatcher=self.dispatcher,
|
||||
)
|
||||
if hasattr(func, _JOB_ATTR_NAME):
|
||||
job_datas = getattr(func, _JOB_ATTR_NAME)
|
||||
job_datas.append(data)
|
||||
setattr(func, _JOB_ATTR_NAME, job_datas)
|
||||
else:
|
||||
setattr(func, _JOB_ATTR_NAME, [data])
|
||||
return func
|
||||
|
||||
|
||||
class _RunOnce(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
when: TimeType,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when)
|
||||
|
||||
|
||||
class _RunRepeating(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
interval: Union[float, datetime.timedelta],
|
||||
first: TimeType = None,
|
||||
last: TimeType = None,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, interval=interval, first=first, last=last
|
||||
)
|
||||
|
||||
|
||||
class _RunMonthly(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
when: datetime.time,
|
||||
day: int,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when, day=day)
|
||||
|
||||
|
||||
class _RunDaily(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
time: datetime.time,
|
||||
days: Tuple[int, ...] = tuple(range(7)),
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, time=time, days=days)
|
||||
|
||||
|
||||
class _RunCustom(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher)
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class job:
|
||||
run_once = _RunOnce
|
||||
run_repeating = _RunRepeating
|
||||
run_monthly = _RunMonthly
|
||||
run_daily = _RunDaily
|
||||
run_custom = _RunCustom
|
@ -1,314 +0,0 @@
|
||||
"""插件"""
|
||||
import asyncio
|
||||
from abc import ABC
|
||||
from dataclasses import asdict
|
||||
from datetime import timedelta
|
||||
from functools import partial, wraps
|
||||
from itertools import chain
|
||||
from multiprocessing import RLock as Lock
|
||||
from types import MethodType
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from telegram.ext import BaseHandler, ConversationHandler, Job, TypeHandler
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from core.handler.adminhandler import AdminHandler
|
||||
from core.plugin._funcs import ConversationFuncs, PluginFuncs
|
||||
from core.plugin._handler import ConversationDataType
|
||||
from utils.const import WRAPPER_ASSIGNMENTS
|
||||
from utils.helpers import isabstract
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.application import Application
|
||||
from core.plugin._handler import ConversationData, HandlerData, ErrorHandlerData
|
||||
from core.plugin._job import JobData
|
||||
from multiprocessing.synchronize import RLock as LockType
|
||||
|
||||
__all__ = ("Plugin", "PluginType", "get_all_plugins")
|
||||
|
||||
wraps = partial(wraps, assigned=WRAPPER_ASSIGNMENTS)
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
|
||||
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
|
||||
|
||||
_HANDLER_DATA_ATTR_NAME = "_handler_datas"
|
||||
"""用于储存生成 handler 时所需要的参数(例如 block)的属性名"""
|
||||
|
||||
_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
|
||||
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名"""
|
||||
|
||||
_ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
|
||||
|
||||
_JOB_ATTR_NAME = "_job_data"
|
||||
|
||||
_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"]
|
||||
|
||||
|
||||
class _Plugin(PluginFuncs):
|
||||
"""插件"""
|
||||
|
||||
_lock: ClassVar["LockType"] = Lock()
|
||||
_asyncio_lock: ClassVar["LockType"] = asyncio.Lock()
|
||||
_installed: bool = False
|
||||
|
||||
_handlers: Optional[List[HandlerType]] = None
|
||||
_error_handlers: Optional[List["ErrorHandlerData"]] = None
|
||||
_jobs: Optional[List[Job]] = None
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError("No application was set for this Plugin.")
|
||||
return self._application
|
||||
|
||||
@property
|
||||
def handlers(self) -> List[HandlerType]:
|
||||
"""该插件的所有 handler"""
|
||||
with self._lock:
|
||||
if self._handlers is None:
|
||||
self._handlers = []
|
||||
|
||||
for attr in dir(self):
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and isinstance(func := getattr(self, attr), MethodType)
|
||||
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
|
||||
):
|
||||
for data in datas:
|
||||
data: "HandlerData"
|
||||
if data.admin:
|
||||
self._handlers.append(
|
||||
AdminHandler(
|
||||
handler=data.type(
|
||||
callback=func,
|
||||
**data.kwargs,
|
||||
),
|
||||
application=self.application,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._handlers.append(
|
||||
data.type(
|
||||
callback=func,
|
||||
**data.kwargs,
|
||||
)
|
||||
)
|
||||
return self._handlers
|
||||
|
||||
@property
|
||||
def error_handlers(self) -> List["ErrorHandlerData"]:
|
||||
with self._lock:
|
||||
if self._error_handlers is None:
|
||||
self._error_handlers = []
|
||||
for attr in dir(self):
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and isinstance(func := getattr(self, attr), MethodType)
|
||||
and (datas := getattr(func, _ERROR_HANDLER_ATTR_NAME, []))
|
||||
):
|
||||
for data in datas:
|
||||
data: "ErrorHandlerData"
|
||||
data.func = func
|
||||
self._error_handlers.append(data)
|
||||
|
||||
return self._error_handlers
|
||||
|
||||
def _install_jobs(self) -> None:
|
||||
if self._jobs is None:
|
||||
self._jobs = []
|
||||
for attr in dir(self):
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and isinstance(func := getattr(self, attr), MethodType)
|
||||
and (datas := getattr(func, _JOB_ATTR_NAME, []))
|
||||
):
|
||||
for data in datas:
|
||||
data: "JobData"
|
||||
self._jobs.append(
|
||||
getattr(self.application.telegram.job_queue, data.type)(
|
||||
callback=func,
|
||||
**data.kwargs,
|
||||
**{
|
||||
key: value
|
||||
for key, value in asdict(data).items()
|
||||
if key not in ["type", "kwargs", "dispatcher"]
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def jobs(self) -> List[Job]:
|
||||
with self._lock:
|
||||
if self._jobs is None:
|
||||
self._jobs = []
|
||||
self._install_jobs()
|
||||
return self._jobs
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化插件"""
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""销毁插件"""
|
||||
|
||||
async def install(self) -> None:
|
||||
"""安装"""
|
||||
group = id(self)
|
||||
if not self._installed:
|
||||
await self.initialize()
|
||||
# initialize 必须先执行 如果出现异常不会执行 add_handler 以免出现问题
|
||||
async with self._asyncio_lock:
|
||||
self._install_jobs()
|
||||
|
||||
for h in self.handlers:
|
||||
if not isinstance(h, TypeHandler):
|
||||
self.application.telegram.add_handler(h, group)
|
||||
else:
|
||||
self.application.telegram.add_handler(h, -1)
|
||||
|
||||
for h in self.error_handlers:
|
||||
self.application.telegram.add_error_handler(h.func, h.block)
|
||||
self._installed = True
|
||||
|
||||
async def uninstall(self) -> None:
|
||||
"""卸载"""
|
||||
group = id(self)
|
||||
|
||||
with self._lock:
|
||||
if self._installed:
|
||||
if group in self.application.telegram.handlers:
|
||||
del self.application.telegram.handlers[id(self)]
|
||||
|
||||
for h in self.handlers:
|
||||
if isinstance(h, TypeHandler):
|
||||
self.application.telegram.remove_handler(h, -1)
|
||||
for h in self.error_handlers:
|
||||
self.application.telegram.remove_error_handler(h.func)
|
||||
|
||||
for j in self.application.telegram.job_queue.jobs():
|
||||
j.schedule_removal()
|
||||
await self.shutdown()
|
||||
self._installed = False
|
||||
|
||||
async def reload(self) -> None:
|
||||
await self.uninstall()
|
||||
await self.install()
|
||||
|
||||
|
||||
class _Conversation(_Plugin, ConversationFuncs, ABC):
|
||||
"""Conversation类"""
|
||||
|
||||
# noinspection SpellCheckingInspection
|
||||
class Config(BaseModel):
|
||||
allow_reentry: bool = False
|
||||
per_chat: bool = True
|
||||
per_user: bool = True
|
||||
per_message: bool = False
|
||||
conversation_timeout: Optional[Union[float, timedelta]] = None
|
||||
name: Optional[str] = None
|
||||
map_to_parent: Optional[Dict[object, object]] = None
|
||||
block: bool = False
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
cls._conversation_kwargs = kwargs
|
||||
super(_Conversation, cls).__init_subclass__()
|
||||
return cls
|
||||
|
||||
@property
|
||||
def handlers(self) -> List[HandlerType]:
|
||||
with self._lock:
|
||||
if self._handlers is None:
|
||||
self._handlers = []
|
||||
|
||||
entry_points: List[HandlerType] = []
|
||||
states: Dict[Any, List[HandlerType]] = {}
|
||||
fallbacks: List[HandlerType] = []
|
||||
for attr in dir(self):
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and (func := getattr(self, attr, None)) is not None
|
||||
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
|
||||
):
|
||||
conversation_data: "ConversationData"
|
||||
|
||||
handlers: List[HandlerType] = []
|
||||
for data in datas:
|
||||
if data.admin:
|
||||
handlers.append(
|
||||
AdminHandler(
|
||||
handler=data.type(
|
||||
callback=func,
|
||||
**data.kwargs,
|
||||
),
|
||||
application=self.application,
|
||||
)
|
||||
)
|
||||
else:
|
||||
handlers.append(
|
||||
data.type(
|
||||
callback=func,
|
||||
**data.kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
if conversation_data := getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None):
|
||||
if (_type := conversation_data.type) is ConversationDataType.Entry:
|
||||
entry_points.extend(handlers)
|
||||
elif _type is ConversationDataType.State:
|
||||
if conversation_data.state in states:
|
||||
states[conversation_data.state].extend(handlers)
|
||||
else:
|
||||
states[conversation_data.state] = handlers
|
||||
elif _type is ConversationDataType.Fallback:
|
||||
fallbacks.extend(handlers)
|
||||
else:
|
||||
self._handlers.extend(handlers)
|
||||
else:
|
||||
self._handlers.extend(handlers)
|
||||
if entry_points and states and fallbacks:
|
||||
kwargs = self._conversation_kwargs
|
||||
kwargs.update(self.Config().dict())
|
||||
self._handlers.append(ConversationHandler(entry_points, states, fallbacks, **kwargs))
|
||||
else:
|
||||
temp_dict = {"entry_points": entry_points, "states": states, "fallbacks": fallbacks}
|
||||
reason = map(lambda x: f"'{x[0]}'", filter(lambda x: not x[1], temp_dict.items()))
|
||||
logger.warning(
|
||||
"'%s' 因缺少 '%s' 而生成无法生成 ConversationHandler", self.__class__.__name__, ", ".join(reason)
|
||||
)
|
||||
return self._handlers
|
||||
|
||||
|
||||
class Plugin(_Plugin, ABC):
|
||||
"""插件"""
|
||||
|
||||
Conversation = _Conversation
|
||||
|
||||
|
||||
PluginType = TypeVar("PluginType", bound=_Plugin)
|
||||
|
||||
|
||||
def get_all_plugins() -> Iterable[Type[PluginType]]:
|
||||
"""获取所有 Plugin 的子类"""
|
||||
return filter(
|
||||
lambda x: x.__name__[0] != "_" and not isabstract(x),
|
||||
chain(Plugin.__subclasses__(), _Conversation.__subclasses__()),
|
||||
)
|
@ -1,67 +0,0 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import Callable, Coroutine, Any, Union, List, Dict, Optional, TypeVar, Type
|
||||
|
||||
from telegram.error import RetryAfter
|
||||
from telegram.ext import BaseRateLimiter, ApplicationHandlerStop
|
||||
|
||||
from utils.log import logger
|
||||
|
||||
JSONDict: Type[dict[str, Any]] = Dict[str, Any]
|
||||
RL_ARGS = TypeVar("RL_ARGS")
|
||||
|
||||
|
||||
class RateLimiter(BaseRateLimiter[int]):
|
||||
_lock = asyncio.Lock()
|
||||
__slots__ = (
|
||||
"_limiter_info",
|
||||
"_retry_after_event",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
self._limiter_info: Dict[Union[str, int], float] = {}
|
||||
self._retry_after_event = asyncio.Event()
|
||||
self._retry_after_event.set()
|
||||
|
||||
async def process_request(
|
||||
self,
|
||||
callback: Callable[..., Coroutine[Any, Any, Union[bool, JSONDict, List[JSONDict]]]],
|
||||
args: Any,
|
||||
kwargs: Dict[str, Any],
|
||||
endpoint: str,
|
||||
data: Dict[str, Any],
|
||||
rate_limit_args: Optional[RL_ARGS],
|
||||
) -> Union[bool, JSONDict, List[JSONDict]]:
|
||||
chat_id = data.get("chat_id")
|
||||
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
chat_id = int(chat_id)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
time = loop.time()
|
||||
|
||||
await self._retry_after_event.wait()
|
||||
|
||||
async with self._lock:
|
||||
chat_limit_time = self._limiter_info.get(chat_id)
|
||||
if chat_limit_time:
|
||||
if time >= chat_limit_time:
|
||||
raise ApplicationHandlerStop
|
||||
del self._limiter_info[chat_id]
|
||||
|
||||
try:
|
||||
return await callback(*args, **kwargs)
|
||||
except RetryAfter as exc:
|
||||
logger.warning("chat_id[%s] 触发洪水限制 当前被服务器限制 retry_after[%s]秒", chat_id, exc.retry_after)
|
||||
self._limiter_info[chat_id] = time + (exc.retry_after * 2)
|
||||
sleep = exc.retry_after + 0.1
|
||||
self._retry_after_event.clear()
|
||||
await asyncio.sleep(sleep)
|
||||
finally:
|
||||
self._retry_after_event.set()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
@ -1,97 +1,3 @@
|
||||
from typing import List, Union
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.basemodel import RegionEnum
|
||||
from core.dependence.redisdb import RedisDB
|
||||
from core.services.cookies.error import CookiesCachePoolExhausted
|
||||
from utils.error import RegionNotFoundError
|
||||
from gram_core.services.cookies.cache import PublicCookiesCache
|
||||
|
||||
__all__ = ("PublicCookiesCache",)
|
||||
|
||||
|
||||
class PublicCookiesCache(BaseService.Component):
|
||||
"""使用优先级(score)进行排序,对使用次数最少的Cookies进行审核"""
|
||||
|
||||
def __init__(self, redis: RedisDB):
|
||||
self.client = redis.client
|
||||
self.score_qname = "cookie:public"
|
||||
self.user_times_qname = "cookie:public:times"
|
||||
self.end = 20
|
||||
self.user_times_ttl = 60 * 60 * 24
|
||||
|
||||
def get_public_cookies_queue_name(self, region: RegionEnum):
|
||||
if region == RegionEnum.HYPERION:
|
||||
return f"{self.score_qname}:yuanshen"
|
||||
if region == RegionEnum.HOYOLAB:
|
||||
return f"{self.score_qname}:genshin"
|
||||
raise RegionNotFoundError(region.name)
|
||||
|
||||
async def putback_public_cookies(self, uid: int, region: RegionEnum):
|
||||
"""重新添加单个到缓存列表
|
||||
:param uid:
|
||||
:param region:
|
||||
:return:
|
||||
"""
|
||||
qname = self.get_public_cookies_queue_name(region)
|
||||
score_maps = {f"{uid}": 0}
|
||||
result = await self.client.zrem(qname, f"{uid}")
|
||||
if result == 1:
|
||||
await self.client.zadd(qname, score_maps)
|
||||
return result
|
||||
|
||||
async def add_public_cookies(self, uid: Union[List[int], int], region: RegionEnum):
|
||||
"""单个或批量添加到缓存列表
|
||||
:param uid:
|
||||
:param region:
|
||||
:return: 成功返回列表大小
|
||||
"""
|
||||
qname = self.get_public_cookies_queue_name(region)
|
||||
if isinstance(uid, int):
|
||||
score_maps = {f"{uid}": 0}
|
||||
elif isinstance(uid, list):
|
||||
score_maps = {f"{i}": 0 for i in uid}
|
||||
else:
|
||||
raise TypeError("uid variable type error")
|
||||
async with self.client.pipeline(transaction=True) as pipe:
|
||||
# nx:只添加新元素。不要更新已经存在的元素
|
||||
await pipe.zadd(qname, score_maps, nx=True)
|
||||
await pipe.zcard(qname)
|
||||
add, count = await pipe.execute()
|
||||
return int(add), count
|
||||
|
||||
async def get_public_cookies(self, region: RegionEnum):
|
||||
"""从缓存列表获取
|
||||
:param region:
|
||||
:return:
|
||||
"""
|
||||
qname = self.get_public_cookies_queue_name(region)
|
||||
scores = await self.client.zrange(qname, 0, self.end, withscores=True, score_cast_func=int)
|
||||
if len(scores) <= 0:
|
||||
raise CookiesCachePoolExhausted
|
||||
key = scores[0][0]
|
||||
score = scores[0][1]
|
||||
async with self.client.pipeline(transaction=True) as pipe:
|
||||
await pipe.zincrby(qname, 1, key)
|
||||
await pipe.execute()
|
||||
return int(key), score + 1
|
||||
|
||||
async def delete_public_cookies(self, uid: int, region: RegionEnum):
|
||||
qname = self.get_public_cookies_queue_name(region)
|
||||
async with self.client.pipeline(transaction=True) as pipe:
|
||||
await pipe.zrem(qname, uid)
|
||||
return await pipe.execute()
|
||||
|
||||
async def get_public_cookies_count(self, limit: bool = True):
|
||||
async with self.client.pipeline(transaction=True) as pipe:
|
||||
if limit:
|
||||
await pipe.zcount(0, self.end)
|
||||
else:
|
||||
await pipe.zcard(self.score_qname)
|
||||
return await pipe.execute()
|
||||
|
||||
async def incr_by_user_times(self, user_id: Union[List[int], int], amount: int = 1):
|
||||
qname = f"{self.user_times_qname}:{user_id}"
|
||||
times = await self.client.incrby(qname, amount)
|
||||
if times <= 1:
|
||||
await self.client.expire(qname, self.user_times_ttl)
|
||||
return times
|
||||
|
@ -1,12 +1,3 @@
|
||||
class CookieServiceError(Exception):
|
||||
pass
|
||||
from gram_core.services.cookies.error import CookieServiceError, CookiesCachePoolExhausted, TooManyRequestPublicCookies
|
||||
|
||||
|
||||
class CookiesCachePoolExhausted(CookieServiceError):
|
||||
def __init__(self):
|
||||
super().__init__("Cookies cache pool is exhausted")
|
||||
|
||||
|
||||
class TooManyRequestPublicCookies(CookieServiceError):
|
||||
def __init__(self, user_id):
|
||||
super().__init__(f"{user_id} too many request public cookies")
|
||||
__all__ = ("CookieServiceError", "CookiesCachePoolExhausted", "TooManyRequestPublicCookies")
|
||||
|
@ -1,39 +1,3 @@
|
||||
import enum
|
||||
from typing import Optional, Dict
|
||||
|
||||
from sqlmodel import SQLModel, Field, Boolean, Column, Enum, JSON, Integer, BigInteger, Index
|
||||
|
||||
from core.basemodel import RegionEnum
|
||||
from gram_core.services.cookies.models import Cookies, CookiesDataBase, CookiesStatusEnum
|
||||
|
||||
__all__ = ("Cookies", "CookiesDataBase", "CookiesStatusEnum")
|
||||
|
||||
|
||||
class CookiesStatusEnum(int, enum.Enum):
|
||||
STATUS_SUCCESS = 0
|
||||
INVALID_COOKIES = 1
|
||||
TOO_MANY_REQUESTS = 2
|
||||
|
||||
|
||||
class Cookies(SQLModel):
|
||||
__table_args__ = (
|
||||
Index("index_user_account", "user_id", "account_id", unique=True),
|
||||
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
||||
)
|
||||
id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger()),
|
||||
)
|
||||
account_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger(),
|
||||
),
|
||||
)
|
||||
data: Optional[Dict[str, str]] = Field(sa_column=Column(JSON))
|
||||
status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum)))
|
||||
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
|
||||
is_share: Optional[bool] = Field(sa_column=Column(Boolean))
|
||||
|
||||
|
||||
class CookiesDataBase(Cookies, table=True):
|
||||
__tablename__ = "cookies"
|
||||
|
@ -1,55 +1,3 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.basemodel import RegionEnum
|
||||
from core.dependence.database import Database
|
||||
from core.services.cookies.models import CookiesDataBase as Cookies
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
from gram_core.services.cookies.repositories import CookiesRepository
|
||||
|
||||
__all__ = ("CookiesRepository",)
|
||||
|
||||
|
||||
class CookiesRepository(BaseService.Component):
|
||||
def __init__(self, database: Database):
|
||||
self.engine = database.engine
|
||||
|
||||
async def get(
|
||||
self,
|
||||
user_id: int,
|
||||
account_id: Optional[int] = None,
|
||||
region: Optional[RegionEnum] = None,
|
||||
) -> Optional[Cookies]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Cookies).where(Cookies.user_id == user_id)
|
||||
if account_id is not None:
|
||||
statement = statement.where(Cookies.account_id == account_id)
|
||||
if region is not None:
|
||||
statement = statement.where(Cookies.region == region)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def add(self, cookies: Cookies) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(cookies)
|
||||
await session.commit()
|
||||
|
||||
async def update(self, cookies: Cookies) -> Cookies:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(cookies)
|
||||
await session.commit()
|
||||
await session.refresh(cookies)
|
||||
return cookies
|
||||
|
||||
async def delete(self, cookies: Cookies) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(cookies)
|
||||
await session.commit()
|
||||
|
||||
async def get_all_by_region(self, region: RegionEnum) -> List[Cookies]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Cookies).where(Cookies.region == region)
|
||||
results = await session.exec(statement)
|
||||
cookies = results.all()
|
||||
return cookies
|
||||
|
@ -1,90 +1,23 @@
|
||||
from typing import List, Optional
|
||||
from gram_core.base_service import BaseService
|
||||
from gram_core.basemodel import RegionEnum
|
||||
from gram_core.services.cookies.error import CookieServiceError
|
||||
from gram_core.services.cookies.models import CookiesStatusEnum, CookiesDataBase as Cookies
|
||||
from gram_core.services.cookies.services import (
|
||||
CookiesService,
|
||||
PublicCookiesService as BasePublicCookiesService,
|
||||
NeedContinue,
|
||||
)
|
||||
|
||||
from simnet import GenshinClient, Region, Game
|
||||
from simnet.errors import InvalidCookies, BadRequest as SimnetBadRequest, TooManyRequests
|
||||
from simnet.errors import InvalidCookies, TooManyRequests, BadRequest as SimnetBadRequest
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.basemodel import RegionEnum
|
||||
from core.services.cookies.cache import PublicCookiesCache
|
||||
from core.services.cookies.error import CookieServiceError, TooManyRequestPublicCookies
|
||||
from core.services.cookies.models import CookiesDataBase as Cookies, CookiesStatusEnum
|
||||
from core.services.cookies.repositories import CookiesRepository
|
||||
from utils.log import logger
|
||||
|
||||
__all__ = ("CookiesService", "PublicCookiesService")
|
||||
|
||||
|
||||
class CookiesService(BaseService):
|
||||
def __init__(self, cookies_repository: CookiesRepository) -> None:
|
||||
self._repository: CookiesRepository = cookies_repository
|
||||
|
||||
async def update(self, cookies: Cookies):
|
||||
await self._repository.update(cookies)
|
||||
|
||||
async def add(self, cookies: Cookies):
|
||||
await self._repository.add(cookies)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
user_id: int,
|
||||
account_id: Optional[int] = None,
|
||||
region: Optional[RegionEnum] = None,
|
||||
) -> Optional[Cookies]:
|
||||
return await self._repository.get(user_id, account_id, region)
|
||||
|
||||
async def delete(self, cookies: Cookies) -> None:
|
||||
return await self._repository.delete(cookies)
|
||||
|
||||
|
||||
class PublicCookiesService(BaseService):
|
||||
def __init__(self, cookies_repository: CookiesRepository, public_cookies_cache: PublicCookiesCache):
|
||||
self._cache = public_cookies_cache
|
||||
self._repository: CookiesRepository = cookies_repository
|
||||
self.count: int = 0
|
||||
self.user_times_limiter = 3 * 3
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info("正在初始化公共Cookies池")
|
||||
await self.refresh()
|
||||
logger.success("刷新公共Cookies池成功")
|
||||
|
||||
async def refresh(self):
|
||||
"""刷新公共Cookies 定时任务
|
||||
:return:
|
||||
"""
|
||||
user_list: List[int] = []
|
||||
cookies_list = await self._repository.get_all_by_region(RegionEnum.HYPERION) # 从数据库获取2
|
||||
for cookies in cookies_list:
|
||||
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
|
||||
user_list.append(cookies.user_id)
|
||||
if len(user_list) > 0:
|
||||
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HYPERION)
|
||||
logger.info("国服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
|
||||
user_list.clear()
|
||||
cookies_list = await self._repository.get_all_by_region(RegionEnum.HOYOLAB)
|
||||
for cookies in cookies_list:
|
||||
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
|
||||
user_list.append(cookies.user_id)
|
||||
if len(user_list) > 0:
|
||||
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HOYOLAB)
|
||||
logger.info("国际服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
|
||||
|
||||
async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL):
|
||||
"""获取公共Cookies
|
||||
:param user_id: 用户ID
|
||||
:param region: 注册的服务器
|
||||
:return:
|
||||
"""
|
||||
user_times = await self._cache.incr_by_user_times(user_id)
|
||||
if int(user_times) > self.user_times_limiter:
|
||||
logger.warning("用户 %s 使用公共Cookies次数已经到达上限", user_id)
|
||||
raise TooManyRequestPublicCookies(user_id)
|
||||
while True:
|
||||
public_id, count = await self._cache.get_public_cookies(region)
|
||||
cookies = await self._repository.get(public_id, region=region)
|
||||
if cookies is None:
|
||||
await self._cache.delete_public_cookies(public_id, region)
|
||||
continue
|
||||
class PublicCookiesService(BaseService, BasePublicCookiesService):
|
||||
async def check_public_cookie(self, region: RegionEnum, cookies: Cookies, public_id: int):
|
||||
if region == RegionEnum.HYPERION:
|
||||
client = GenshinClient(cookies=cookies.data, region=Region.CHINESE)
|
||||
elif region == RegionEnum.HOYOLAB:
|
||||
@ -116,13 +49,13 @@ class PublicCookiesService(BaseService):
|
||||
cookies.status = CookiesStatusEnum.INVALID_COOKIES
|
||||
await self._repository.update(cookies)
|
||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||
continue
|
||||
raise NeedContinue
|
||||
except TooManyRequests:
|
||||
logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id)
|
||||
cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS
|
||||
await self._repository.update(cookies)
|
||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||
continue
|
||||
raise NeedContinue
|
||||
except SimnetBadRequest as exc:
|
||||
if "invalid content type" in exc.message:
|
||||
raise exc
|
||||
@ -132,28 +65,16 @@ class PublicCookiesService(BaseService):
|
||||
logger.warning("用户 [%s] 获取账号信息发生错误,错误信息为", public_id)
|
||||
logger.exception(exc)
|
||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||
continue
|
||||
raise NeedContinue
|
||||
except RuntimeError as exc:
|
||||
if "account_id not found" in str(exc):
|
||||
cookies.status = CookiesStatusEnum.INVALID_COOKIES
|
||||
await self._repository.update(cookies)
|
||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||
continue
|
||||
raise NeedContinue
|
||||
raise exc
|
||||
except Exception as exc:
|
||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||
raise exc
|
||||
finally:
|
||||
await client.shutdown()
|
||||
logger.info("用户 user_id[%s] 请求用户 user_id[%s] 的公共Cookies 该Cookies使用次数为%s次 ", user_id, public_id, count)
|
||||
return cookies
|
||||
|
||||
async def undo(self, user_id: int, cookies: Optional[Cookies] = None, status: Optional[CookiesStatusEnum] = None):
|
||||
await self._cache.incr_by_user_times(user_id, -1)
|
||||
if cookies is not None and status is not None:
|
||||
cookies.status = status
|
||||
await self._repository.update(cookies)
|
||||
await self._cache.delete_public_cookies(cookies.user_id, cookies.region)
|
||||
logger.info("用户 user_id[%s] 反馈用户 user_id[%s] 的Cookies状态为 %s", user_id, cookies.user_id, status.name)
|
||||
else:
|
||||
logger.info("用户 user_id[%s] 撤销一次公共Cookies计数", user_id)
|
||||
|
@ -1,23 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import SQLModel, Field, Column, Integer, BigInteger
|
||||
from gram_core.services.devices.models import Devices, DevicesDataBase
|
||||
|
||||
__all__ = ("Devices", "DevicesDataBase")
|
||||
|
||||
|
||||
class Devices(SQLModel):
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
|
||||
account_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger(),
|
||||
),
|
||||
)
|
||||
device_id: str = Field()
|
||||
device_fp: str = Field()
|
||||
device_name: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class DevicesDataBase(Devices, table=True):
|
||||
__tablename__ = "devices"
|
||||
|
@ -1,41 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.database import Database
|
||||
from core.services.devices.models import DevicesDataBase as Devices
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
from gram_core.services.devices.repositories import DevicesRepository
|
||||
|
||||
__all__ = ("DevicesRepository",)
|
||||
|
||||
|
||||
class DevicesRepository(BaseService.Component):
|
||||
def __init__(self, database: Database):
|
||||
self.engine = database.engine
|
||||
|
||||
async def get(
|
||||
self,
|
||||
account_id: int,
|
||||
) -> Optional[Devices]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Devices).where(Devices.account_id == account_id)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def add(self, devices: Devices) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(devices)
|
||||
await session.commit()
|
||||
|
||||
async def update(self, devices: Devices) -> Devices:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(devices)
|
||||
await session.commit()
|
||||
await session.refresh(devices)
|
||||
return devices
|
||||
|
||||
async def delete(self, devices: Devices) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(devices)
|
||||
await session.commit()
|
||||
|
@ -1,25 +1,3 @@
|
||||
from typing import Optional
|
||||
from gram_core.services.devices.services import DevicesService
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.services.devices.repositories import DevicesRepository
|
||||
from core.services.devices.models import DevicesDataBase as Devices
|
||||
|
||||
|
||||
class DevicesService(BaseService):
|
||||
def __init__(self, devices_repository: DevicesRepository) -> None:
|
||||
self._repository: DevicesRepository = devices_repository
|
||||
|
||||
async def update(self, devices: Devices):
|
||||
await self._repository.update(devices)
|
||||
|
||||
async def add(self, devices: Devices):
|
||||
await self._repository.add(devices)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
account_id: int,
|
||||
) -> Optional[Devices]:
|
||||
return await self._repository.get(account_id)
|
||||
|
||||
async def delete(self, devices: Devices) -> None:
|
||||
return await self._repository.delete(devices)
|
||||
__all__ = ("DevicesService",)
|
||||
|
@ -1,2 +1,3 @@
|
||||
class PlayerNotFoundError(Exception):
|
||||
pass
|
||||
from gram_core.services.players.error import PlayerNotFoundError
|
||||
|
||||
__all__ = ("PlayerNotFoundError",)
|
||||
|
@ -1,96 +1,8 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, BaseSettings
|
||||
from sqlalchemy import TypeDecorator
|
||||
from sqlmodel import Boolean, Column, Enum, Field, SQLModel, Integer, Index, BigInteger, VARCHAR, func, DateTime
|
||||
|
||||
from core.basemodel import RegionEnum
|
||||
|
||||
try:
|
||||
import ujson as jsonlib
|
||||
except ImportError:
|
||||
import json as jsonlib
|
||||
from gram_core.services.players.models import (
|
||||
Player,
|
||||
PlayersDataBase,
|
||||
PlayerInfo,
|
||||
PlayerInfoSQLModel,
|
||||
)
|
||||
|
||||
__all__ = ("Player", "PlayersDataBase", "PlayerInfo", "PlayerInfoSQLModel")
|
||||
|
||||
|
||||
class Player(SQLModel):
|
||||
__table_args__ = (
|
||||
Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True),
|
||||
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
||||
)
|
||||
id: Optional[int] = Field(
|
||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||
)
|
||||
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||
account_id: int = Field(default=None, primary_key=True, sa_column=Column(BigInteger()))
|
||||
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
|
||||
is_chosen: Optional[bool] = Field(sa_column=Column(Boolean))
|
||||
|
||||
|
||||
class PlayersDataBase(Player, table=True):
|
||||
__tablename__ = "players"
|
||||
|
||||
|
||||
class ExtraPlayerInfo(BaseModel):
|
||||
class Config(BaseSettings.Config):
|
||||
json_loads = jsonlib.loads
|
||||
json_dumps = jsonlib.dumps
|
||||
|
||||
waifu_id: Optional[int] = None
|
||||
|
||||
|
||||
class ExtraPlayerType(TypeDecorator): # pylint: disable=W0223
|
||||
impl = VARCHAR(length=521)
|
||||
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
"""
|
||||
:param value: ExtraPlayerInfo | obj | None
|
||||
:param dialect:
|
||||
:return:
|
||||
"""
|
||||
if value is not None:
|
||||
if isinstance(value, ExtraPlayerInfo):
|
||||
return value.json()
|
||||
raise TypeError
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
"""
|
||||
:param value: str | obj | None
|
||||
:param dialect:
|
||||
:return:
|
||||
"""
|
||||
if value is not None:
|
||||
return ExtraPlayerInfo.parse_raw(value)
|
||||
return None
|
||||
|
||||
|
||||
class PlayerInfo(SQLModel):
|
||||
__table_args__ = (
|
||||
Index("index_user_account_player", "user_id", "player_id", unique=True),
|
||||
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
||||
)
|
||||
id: Optional[int] = Field(
|
||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||
)
|
||||
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||
nickname: Optional[str] = Field()
|
||||
signature: Optional[str] = Field()
|
||||
hand_image: Optional[int] = Field()
|
||||
name_card: Optional[int] = Field()
|
||||
extra_data: Optional[ExtraPlayerInfo] = Field(sa_column=Column(ExtraPlayerType))
|
||||
create_time: Optional[datetime] = Field(
|
||||
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
|
||||
)
|
||||
last_save_time: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
|
||||
is_update: Optional[bool] = Field(sa_column=Column(Boolean))
|
||||
|
||||
|
||||
class PlayerInfoSQLModel(PlayerInfo, table=True):
|
||||
__tablename__ = "players_info"
|
||||
|
@ -1,110 +1,3 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlmodel import select, delete
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.basemodel import RegionEnum
|
||||
from core.dependence.database import Database
|
||||
from core.services.players.models import PlayerInfoSQLModel
|
||||
from core.services.players.models import PlayersDataBase as Player
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
from gram_core.services.players.repositories import PlayersRepository, PlayerInfoRepository
|
||||
|
||||
__all__ = ("PlayersRepository", "PlayerInfoRepository")
|
||||
|
||||
|
||||
class PlayersRepository(BaseService.Component):
|
||||
def __init__(self, database: Database):
|
||||
self.engine = database.engine
|
||||
|
||||
async def get(
|
||||
self,
|
||||
user_id: int,
|
||||
player_id: Optional[int] = None,
|
||||
account_id: Optional[int] = None,
|
||||
region: Optional[RegionEnum] = None,
|
||||
is_chosen: Optional[bool] = None,
|
||||
) -> Optional[Player]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Player).where(Player.user_id == user_id)
|
||||
if player_id is not None:
|
||||
statement = statement.where(Player.player_id == player_id)
|
||||
if account_id is not None:
|
||||
statement = statement.where(Player.account_id == account_id)
|
||||
if region is not None:
|
||||
statement = statement.where(Player.region == region)
|
||||
if is_chosen is not None:
|
||||
statement = statement.where(Player.is_chosen == is_chosen)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def add(self, player: Player) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(player)
|
||||
await session.commit()
|
||||
await session.refresh(player)
|
||||
|
||||
async def delete(self, player: Player) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(player)
|
||||
await session.commit()
|
||||
|
||||
async def update(self, player: Player) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(player)
|
||||
await session.commit()
|
||||
await session.refresh(player)
|
||||
|
||||
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Player).where(Player.user_id == user_id)
|
||||
results = await session.exec(statement)
|
||||
players = results.all()
|
||||
return players
|
||||
|
||||
|
||||
class PlayerInfoRepository(BaseService.Component):
|
||||
def __init__(self, database: Database):
|
||||
self.engine = database.engine
|
||||
|
||||
async def get(
|
||||
self,
|
||||
user_id: int,
|
||||
player_id: int,
|
||||
) -> Optional[PlayerInfoSQLModel]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = (
|
||||
select(PlayerInfoSQLModel)
|
||||
.where(PlayerInfoSQLModel.player_id == player_id)
|
||||
.where(PlayerInfoSQLModel.user_id == user_id)
|
||||
)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def add(self, player: PlayerInfoSQLModel) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(player)
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, player: PlayerInfoSQLModel) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(player)
|
||||
await session.commit()
|
||||
|
||||
async def delete_by_id(
|
||||
self,
|
||||
user_id: int,
|
||||
player_id: int,
|
||||
) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = (
|
||||
delete(PlayerInfoSQLModel)
|
||||
.where(PlayerInfoSQLModel.player_id == player_id)
|
||||
.where(PlayerInfoSQLModel.user_id == user_id)
|
||||
)
|
||||
await session.execute(statement)
|
||||
|
||||
async def update(self, player: PlayerInfoSQLModel) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(player)
|
||||
await session.commit()
|
||||
await session.refresh(player)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from aiohttp import ClientConnectorError
|
||||
from enkanetwork import (
|
||||
@ -11,53 +11,19 @@ from enkanetwork import (
|
||||
)
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.basemodel import RegionEnum
|
||||
from core.config import config
|
||||
from core.dependence.redisdb import RedisDB
|
||||
from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel, PlayerInfo
|
||||
from core.services.players.repositories import PlayersRepository, PlayerInfoRepository
|
||||
from core.services.players.repositories import PlayerInfoRepository
|
||||
from utils.enkanetwork import RedisCache
|
||||
from utils.log import logger
|
||||
from utils.patch.aiohttp import AioHttpTimeoutException
|
||||
|
||||
from gram_core.services.players.services import PlayersService
|
||||
|
||||
__all__ = ("PlayersService", "PlayerInfoService")
|
||||
|
||||
|
||||
class PlayersService(BaseService):
|
||||
def __init__(self, players_repository: PlayersRepository) -> None:
|
||||
self._repository = players_repository
|
||||
|
||||
async def get(
|
||||
self,
|
||||
user_id: int,
|
||||
player_id: Optional[int] = None,
|
||||
account_id: Optional[int] = None,
|
||||
region: Optional[RegionEnum] = None,
|
||||
is_chosen: Optional[bool] = None,
|
||||
) -> Optional[Player]:
|
||||
return await self._repository.get(user_id, player_id, account_id, region, is_chosen)
|
||||
|
||||
async def get_player(self, user_id: int, region: Optional[RegionEnum] = None) -> Optional[Player]:
|
||||
return await self._repository.get(user_id, region=region, is_chosen=True)
|
||||
|
||||
async def add(self, player: Player) -> None:
|
||||
await self._repository.add(player)
|
||||
|
||||
async def update(self, player: Player) -> None:
|
||||
await self._repository.update(player)
|
||||
|
||||
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
|
||||
return await self._repository.get_all_by_user_id(user_id)
|
||||
|
||||
async def remove_all_by_user_id(self, user_id: int):
|
||||
players = await self._repository.get_all_by_user_id(user_id)
|
||||
for player in players:
|
||||
await self._repository.delete(player)
|
||||
|
||||
async def delete(self, player: Player):
|
||||
await self._repository.delete(player)
|
||||
|
||||
|
||||
class PlayerInfoService(BaseService):
|
||||
def __init__(self, redis: RedisDB, players_info_repository: PlayerInfoRepository):
|
||||
self.cache = redis.client
|
||||
|
@ -1,44 +1,3 @@
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from sqlalchemy import func, BigInteger, JSON
|
||||
from sqlmodel import Column, DateTime, Enum, Field, SQLModel, Integer
|
||||
from gram_core.services.task.models import Task, TaskStatusEnum, TaskTypeEnum
|
||||
|
||||
__all__ = ("Task", "TaskStatusEnum", "TaskTypeEnum")
|
||||
|
||||
|
||||
class TaskStatusEnum(int, enum.Enum):
|
||||
STATUS_SUCCESS = 0 # 任务执行成功
|
||||
INVALID_COOKIES = 1 # Cookie无效
|
||||
ALREADY_CLAIMED = 2 # 已经获取奖励
|
||||
NEED_CHALLENGE = 3 # 需要验证码
|
||||
GENSHIN_EXCEPTION = 4 # API异常
|
||||
TIMEOUT_ERROR = 5 # 请求超时
|
||||
BAD_REQUEST = 6 # 请求失败
|
||||
FORBIDDEN = 7 # 这错误一般为通知失败 机器人被用户BAN
|
||||
|
||||
|
||||
class TaskTypeEnum(int, enum.Enum):
|
||||
SIGN = 0 # 签到
|
||||
RESIN = 1 # 体力
|
||||
REALM = 2 # 洞天宝钱
|
||||
EXPEDITION = 3 # 委托
|
||||
TRANSFORMER = 4 # 参量质变仪
|
||||
CARD = 5 # 生日画片
|
||||
|
||||
|
||||
class Task(SQLModel, table=True):
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
id: Optional[int] = Field(
|
||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||
)
|
||||
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger(), index=True))
|
||||
chat_id: Optional[int] = Field(default=None, sa_column=Column(BigInteger()))
|
||||
time_created: Optional[datetime] = Field(
|
||||
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
|
||||
)
|
||||
time_updated: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
|
||||
type: TaskTypeEnum = Field(primary_key=True, sa_column=Column(Enum(TaskTypeEnum)))
|
||||
status: Optional[TaskStatusEnum] = Field(sa_column=Column(Enum(TaskStatusEnum)))
|
||||
data: Optional[Dict[str, Any]] = Field(sa_column=Column(JSON))
|
||||
|
@ -1,50 +1,3 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.database import Database
|
||||
from core.services.task.models import Task, TaskTypeEnum
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
from gram_core.services.task.repositories import TaskRepository
|
||||
|
||||
__all__ = ("TaskRepository",)
|
||||
|
||||
|
||||
class TaskRepository(BaseService.Component):
|
||||
def __init__(self, database: Database):
|
||||
self.engine = database.engine
|
||||
|
||||
async def add(self, task: Task):
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(task)
|
||||
await session.commit()
|
||||
|
||||
async def remove(self, task: Task):
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(task)
|
||||
await session.commit()
|
||||
|
||||
async def update(self, task: Task) -> Task:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(task)
|
||||
await session.commit()
|
||||
await session.refresh(task)
|
||||
return task
|
||||
|
||||
async def get_by_user_id(self, user_id: int, task_type: TaskTypeEnum) -> Optional[Task]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Task).where(Task.user_id == user_id).where(Task.type == task_type)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def get_by_chat_id(self, chat_id: int, task_type: TaskTypeEnum) -> Optional[List[Task]]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Task).where(Task.chat_id == chat_id).where(Task.type == task_type)
|
||||
results = await session.exec(statement)
|
||||
return results.all()
|
||||
|
||||
async def get_all(self, task_type: TaskTypeEnum) -> List[Task]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
query = select(Task).where(Task.type == task_type)
|
||||
results = await session.exec(query)
|
||||
return results.all()
|
||||
|
@ -1,9 +1,11 @@
|
||||
import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.services.task.models import Task, TaskTypeEnum
|
||||
from core.services.task.repositories import TaskRepository
|
||||
from gram_core.services.task.services import (
|
||||
TaskServices,
|
||||
SignServices,
|
||||
TaskCardServices,
|
||||
TaskResinServices,
|
||||
TaskRealmServices,
|
||||
TaskExpeditionServices,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TaskServices",
|
||||
@ -13,201 +15,3 @@ __all__ = [
|
||||
"TaskRealmServices",
|
||||
"TaskExpeditionServices",
|
||||
]
|
||||
|
||||
|
||||
class TaskServices(BaseService):
|
||||
TASK_TYPE: TaskTypeEnum
|
||||
|
||||
def __init__(self, task_repository: TaskRepository) -> None:
|
||||
self._repository: TaskRepository = task_repository
|
||||
|
||||
async def add(self, task: Task):
|
||||
return await self._repository.add(task)
|
||||
|
||||
async def remove(self, task: Task):
|
||||
return await self._repository.remove(task)
|
||||
|
||||
async def update(self, task: Task):
|
||||
task.time_updated = datetime.datetime.now()
|
||||
return await self._repository.update(task)
|
||||
|
||||
async def get_by_user_id(self, user_id: int):
|
||||
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
|
||||
|
||||
async def get_all(self):
|
||||
return await self._repository.get_all(self.TASK_TYPE)
|
||||
|
||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
||||
return Task(
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
time_created=datetime.datetime.now(),
|
||||
status=status,
|
||||
type=self.TASK_TYPE,
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
class SignServices(BaseService):
|
||||
TASK_TYPE = TaskTypeEnum.SIGN
|
||||
|
||||
def __init__(self, task_repository: TaskRepository) -> None:
|
||||
self._repository: TaskRepository = task_repository
|
||||
|
||||
async def add(self, task: Task):
|
||||
return await self._repository.add(task)
|
||||
|
||||
async def remove(self, task: Task):
|
||||
return await self._repository.remove(task)
|
||||
|
||||
async def update(self, task: Task):
|
||||
task.time_updated = datetime.datetime.now()
|
||||
return await self._repository.update(task)
|
||||
|
||||
async def get_by_user_id(self, user_id: int):
|
||||
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
|
||||
|
||||
async def get_all(self):
|
||||
return await self._repository.get_all(self.TASK_TYPE)
|
||||
|
||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
||||
return Task(
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
time_created=datetime.datetime.now(),
|
||||
status=status,
|
||||
type=self.TASK_TYPE,
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
class TaskCardServices(BaseService):
|
||||
TASK_TYPE = TaskTypeEnum.CARD
|
||||
|
||||
def __init__(self, task_repository: TaskRepository) -> None:
|
||||
self._repository: TaskRepository = task_repository
|
||||
|
||||
async def add(self, task: Task):
|
||||
return await self._repository.add(task)
|
||||
|
||||
async def remove(self, task: Task):
|
||||
return await self._repository.remove(task)
|
||||
|
||||
async def update(self, task: Task):
|
||||
task.time_updated = datetime.datetime.now()
|
||||
return await self._repository.update(task)
|
||||
|
||||
async def get_by_user_id(self, user_id: int):
|
||||
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
|
||||
|
||||
async def get_all(self):
|
||||
return await self._repository.get_all(self.TASK_TYPE)
|
||||
|
||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
||||
return Task(
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
time_created=datetime.datetime.now(),
|
||||
status=status,
|
||||
type=self.TASK_TYPE,
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
class TaskResinServices(BaseService):
|
||||
TASK_TYPE = TaskTypeEnum.RESIN
|
||||
|
||||
def __init__(self, task_repository: TaskRepository) -> None:
|
||||
self._repository: TaskRepository = task_repository
|
||||
|
||||
async def add(self, task: Task):
|
||||
return await self._repository.add(task)
|
||||
|
||||
async def remove(self, task: Task):
|
||||
return await self._repository.remove(task)
|
||||
|
||||
async def update(self, task: Task):
|
||||
task.time_updated = datetime.datetime.now()
|
||||
return await self._repository.update(task)
|
||||
|
||||
async def get_by_user_id(self, user_id: int):
|
||||
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
|
||||
|
||||
async def get_all(self):
|
||||
return await self._repository.get_all(self.TASK_TYPE)
|
||||
|
||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
||||
return Task(
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
time_created=datetime.datetime.now(),
|
||||
status=status,
|
||||
type=self.TASK_TYPE,
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
class TaskRealmServices(BaseService):
|
||||
TASK_TYPE = TaskTypeEnum.REALM
|
||||
|
||||
def __init__(self, task_repository: TaskRepository) -> None:
|
||||
self._repository: TaskRepository = task_repository
|
||||
|
||||
async def add(self, task: Task):
|
||||
return await self._repository.add(task)
|
||||
|
||||
async def remove(self, task: Task):
|
||||
return await self._repository.remove(task)
|
||||
|
||||
async def update(self, task: Task):
|
||||
task.time_updated = datetime.datetime.now()
|
||||
return await self._repository.update(task)
|
||||
|
||||
async def get_by_user_id(self, user_id: int):
|
||||
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
|
||||
|
||||
async def get_all(self):
|
||||
return await self._repository.get_all(self.TASK_TYPE)
|
||||
|
||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
||||
return Task(
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
time_created=datetime.datetime.now(),
|
||||
status=status,
|
||||
type=self.TASK_TYPE,
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
class TaskExpeditionServices(BaseService):
|
||||
TASK_TYPE = TaskTypeEnum.EXPEDITION
|
||||
|
||||
def __init__(self, task_repository: TaskRepository) -> None:
|
||||
self._repository: TaskRepository = task_repository
|
||||
|
||||
async def add(self, task: Task):
|
||||
return await self._repository.add(task)
|
||||
|
||||
async def remove(self, task: Task):
|
||||
return await self._repository.remove(task)
|
||||
|
||||
async def update(self, task: Task):
|
||||
task.time_updated = datetime.datetime.now()
|
||||
return await self._repository.update(task)
|
||||
|
||||
async def get_by_user_id(self, user_id: int):
|
||||
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
|
||||
|
||||
async def get_all(self):
|
||||
return await self._repository.get_all(self.TASK_TYPE)
|
||||
|
||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
||||
return Task(
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
time_created=datetime.datetime.now(),
|
||||
status=status,
|
||||
type=self.TASK_TYPE,
|
||||
data=data,
|
||||
)
|
||||
|
@ -1,58 +1,3 @@
|
||||
import gzip
|
||||
import pickle # nosec B403
|
||||
from hashlib import sha256
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.redisdb import RedisDB
|
||||
from gram_core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache
|
||||
|
||||
__all__ = ["TemplatePreviewCache", "HtmlToFileIdCache"]
|
||||
|
||||
|
||||
class TemplatePreviewCache(BaseService.Component):
|
||||
"""暂存渲染模板的数据用于预览"""
|
||||
|
||||
def __init__(self, redis: RedisDB):
|
||||
self.client = redis.client
|
||||
self.qname = "bot:template:preview"
|
||||
|
||||
async def get_data(self, key: str) -> Any:
|
||||
data = await self.client.get(self.cache_key(key))
|
||||
if data:
|
||||
# skipcq: BAN-B301
|
||||
return pickle.loads(gzip.decompress(data)) # nosec B301
|
||||
|
||||
async def set_data(self, key: str, data: Any, ttl: int = 8 * 60 * 60):
|
||||
ck = self.cache_key(key)
|
||||
await self.client.set(ck, gzip.compress(pickle.dumps(data)))
|
||||
if ttl != -1:
|
||||
await self.client.expire(ck, ttl)
|
||||
|
||||
def cache_key(self, key: str) -> str:
|
||||
return f"{self.qname}:{key}"
|
||||
|
||||
|
||||
class HtmlToFileIdCache(BaseService.Component):
|
||||
"""html to file_id 的缓存"""
|
||||
|
||||
def __init__(self, redis: RedisDB):
|
||||
self.client = redis.client
|
||||
self.qname = "bot:template:html-to-file-id"
|
||||
|
||||
async def get_data(self, html: str, file_type: str) -> Optional[str]:
|
||||
data = await self.client.get(self.cache_key(html, file_type))
|
||||
if data:
|
||||
return data.decode()
|
||||
|
||||
async def set_data(self, html: str, file_type: str, file_id: str, ttl: int = 24 * 60 * 60):
|
||||
ck = self.cache_key(html, file_type)
|
||||
await self.client.set(ck, file_id)
|
||||
if ttl != -1:
|
||||
await self.client.expire(ck, ttl)
|
||||
|
||||
async def delete_data(self, html: str, file_type: str) -> bool:
|
||||
return await self.client.delete(self.cache_key(html, file_type))
|
||||
|
||||
def cache_key(self, html: str, file_type: str) -> str:
|
||||
key = sha256(html.encode()).hexdigest()
|
||||
return f"{self.qname}:{file_type}:{key}"
|
||||
|
@ -1,14 +1,8 @@
|
||||
class TemplateException(Exception):
|
||||
pass
|
||||
from gram_core.services.template.error import (
|
||||
ErrorFileType,
|
||||
FileIdNotFound,
|
||||
QuerySelectorNotFound,
|
||||
TemplateException,
|
||||
)
|
||||
|
||||
|
||||
class QuerySelectorNotFound(TemplateException):
|
||||
pass
|
||||
|
||||
|
||||
class ErrorFileType(TemplateException):
|
||||
pass
|
||||
|
||||
|
||||
class FileIdNotFound(TemplateException):
|
||||
pass
|
||||
__all__ = ("TemplateException", "QuerySelectorNotFound", "ErrorFileType", "FileIdNotFound")
|
||||
|
@ -1,146 +1,3 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from telegram import InputMediaDocument, InputMediaPhoto, Message
|
||||
from telegram.error import BadRequest
|
||||
|
||||
from core.services.template.cache import HtmlToFileIdCache
|
||||
from core.services.template.error import ErrorFileType, FileIdNotFound
|
||||
from gram_core.services.template.models import FileType, RenderResult, RenderGroupResult
|
||||
|
||||
__all__ = ["FileType", "RenderResult", "RenderGroupResult"]
|
||||
|
||||
|
||||
class FileType(Enum):
|
||||
PHOTO = 1
|
||||
DOCUMENT = 2
|
||||
|
||||
@staticmethod
|
||||
def media_type(file_type: "FileType"):
|
||||
"""对应的 Telegram media 类型"""
|
||||
if file_type == FileType.PHOTO:
|
||||
return InputMediaPhoto
|
||||
if file_type == FileType.DOCUMENT:
|
||||
return InputMediaDocument
|
||||
raise ErrorFileType
|
||||
|
||||
|
||||
class RenderResult:
|
||||
"""渲染结果"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
html: str,
|
||||
photo: Union[bytes, str],
|
||||
file_type: FileType,
|
||||
cache: HtmlToFileIdCache,
|
||||
ttl: int = 24 * 60 * 60,
|
||||
caption: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
filename: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
`html`: str 渲染生成的 html
|
||||
`photo`: Union[bytes, str] 渲染生成的图片。bytes 表示是图片,str 则为 file_id
|
||||
"""
|
||||
self.caption = caption
|
||||
self.parse_mode = parse_mode
|
||||
self.filename = filename
|
||||
self.html = html
|
||||
self.photo = photo
|
||||
self.file_type = file_type
|
||||
self._cache = cache
|
||||
self.ttl = ttl
|
||||
|
||||
async def reply_photo(self, message: Message, *args, **kwargs):
|
||||
"""是 `message.reply_photo` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用"""
|
||||
if self.file_type != FileType.PHOTO:
|
||||
raise ErrorFileType
|
||||
|
||||
try:
|
||||
reply = await message.reply_photo(photo=self.photo, *args, **kwargs)
|
||||
except BadRequest as exc:
|
||||
if "Wrong file identifier" in exc.message and isinstance(self.photo, str):
|
||||
await self._cache.delete_data(self.html, self.file_type.name)
|
||||
raise BadRequest(message="Wrong file identifier specified")
|
||||
raise exc
|
||||
|
||||
await self.cache_file_id(reply)
|
||||
|
||||
return reply
|
||||
|
||||
async def reply_document(self, message: Message, *args, **kwargs):
|
||||
"""是 `message.reply_document` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用"""
|
||||
if self.file_type != FileType.DOCUMENT:
|
||||
raise ErrorFileType
|
||||
|
||||
try:
|
||||
reply = await message.reply_document(document=self.photo, *args, **kwargs)
|
||||
except BadRequest as exc:
|
||||
if "Wrong file identifier" in exc.message and isinstance(self.photo, str):
|
||||
await self._cache.delete_data(self.html, self.file_type.name)
|
||||
raise BadRequest(message="Wrong file identifier specified")
|
||||
raise exc
|
||||
|
||||
await self.cache_file_id(reply)
|
||||
|
||||
return reply
|
||||
|
||||
async def edit_media(self, message: Message, *args, **kwargs):
|
||||
"""是 `message.edit_media` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用"""
|
||||
if self.file_type != FileType.PHOTO:
|
||||
raise ErrorFileType
|
||||
|
||||
media = InputMediaPhoto(
|
||||
media=self.photo, caption=self.caption, parse_mode=self.parse_mode, filename=self.filename
|
||||
)
|
||||
|
||||
try:
|
||||
edit_media = await message.edit_media(media, *args, **kwargs)
|
||||
except BadRequest as exc:
|
||||
if "Wrong file identifier" in exc.message and isinstance(self.photo, str):
|
||||
await self._cache.delete_data(self.html, self.file_type.name)
|
||||
raise BadRequest(message="Wrong file identifier specified")
|
||||
raise exc
|
||||
|
||||
await self.cache_file_id(edit_media)
|
||||
|
||||
return edit_media
|
||||
|
||||
async def cache_file_id(self, reply: Message):
|
||||
"""缓存 telegram 返回的 file_id"""
|
||||
if self.is_file_id():
|
||||
return
|
||||
|
||||
if self.file_type == FileType.PHOTO and reply.photo:
|
||||
file_id = reply.photo[0].file_id
|
||||
elif self.file_type == FileType.DOCUMENT and reply.document:
|
||||
file_id = reply.document.file_id
|
||||
else:
|
||||
raise FileIdNotFound
|
||||
await self._cache.set_data(self.html, self.file_type.name, file_id, self.ttl)
|
||||
|
||||
def is_file_id(self) -> bool:
|
||||
return isinstance(self.photo, str)
|
||||
|
||||
|
||||
class RenderGroupResult:
|
||||
def __init__(self, results: List[RenderResult]):
|
||||
self.results = results
|
||||
|
||||
async def reply_media_group(self, message: Message, *args, **kwargs):
|
||||
"""是 `message.reply_media_group` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用"""
|
||||
|
||||
reply = await message.reply_media_group(
|
||||
media=[
|
||||
FileType.media_type(result.file_type)(
|
||||
media=result.photo, caption=result.caption, parse_mode=result.parse_mode, filename=result.filename
|
||||
)
|
||||
for result in self.results
|
||||
],
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for index, value in enumerate(reply):
|
||||
result = self.results[index]
|
||||
await result.cache_file_id(value)
|
||||
|
@ -1,207 +1,3 @@
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode, urljoin, urlsplit
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from jinja2 import Environment, FileSystemLoader, Template
|
||||
from playwright.async_api import ViewportSize
|
||||
|
||||
from core.application import Application
|
||||
from core.base_service import BaseService
|
||||
from core.config import config as application_config
|
||||
from core.dependence.aiobrowser import AioBrowser
|
||||
from core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache
|
||||
from core.services.template.error import QuerySelectorNotFound
|
||||
from core.services.template.models import FileType, RenderResult
|
||||
from utils.const import PROJECT_ROOT
|
||||
from utils.log import logger
|
||||
from gram_core.services.template.services import TemplateService, TemplatePreviewer
|
||||
|
||||
__all__ = ("TemplateService", "TemplatePreviewer")
|
||||
|
||||
|
||||
class TemplateService(BaseService):
|
||||
def __init__(
|
||||
self,
|
||||
app: Application,
|
||||
browser: AioBrowser,
|
||||
html_to_file_id_cache: HtmlToFileIdCache,
|
||||
preview_cache: TemplatePreviewCache,
|
||||
template_dir: str = "resources",
|
||||
):
|
||||
self._browser = browser
|
||||
self.template_dir = PROJECT_ROOT / template_dir
|
||||
|
||||
self._jinja2_env = Environment(
|
||||
loader=FileSystemLoader(template_dir),
|
||||
enable_async=True,
|
||||
autoescape=True,
|
||||
auto_reload=application_config.debug,
|
||||
)
|
||||
self.using_preview = application_config.debug and application_config.webserver.enable
|
||||
|
||||
if self.using_preview:
|
||||
self.previewer = TemplatePreviewer(self, preview_cache, app.web_app)
|
||||
|
||||
self.html_to_file_id_cache = html_to_file_id_cache
|
||||
|
||||
def get_template(self, template_name: str) -> Template:
|
||||
return self._jinja2_env.get_template(template_name)
|
||||
|
||||
async def render_async(self, template_name: str, template_data: dict) -> str:
|
||||
"""模板渲染
|
||||
:param template_name: 模板文件名
|
||||
:param template_data: 模板数据
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
start_time = loop.time()
|
||||
template = self.get_template(template_name)
|
||||
html = await template.render_async(**template_data)
|
||||
logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
|
||||
return html
|
||||
|
||||
async def render(
|
||||
self,
|
||||
template_name: str,
|
||||
template_data: dict,
|
||||
viewport: Optional[ViewportSize] = None,
|
||||
full_page: bool = True,
|
||||
evaluate: Optional[str] = None,
|
||||
query_selector: Optional[str] = None,
|
||||
file_type: FileType = FileType.PHOTO,
|
||||
ttl: int = 24 * 60 * 60,
|
||||
caption: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
filename: Optional[str] = None,
|
||||
) -> RenderResult:
|
||||
"""模板渲染成图片
|
||||
:param template_name: 模板文件名
|
||||
:param template_data: 模板数据
|
||||
:param viewport: 截图大小
|
||||
:param full_page: 是否长截图
|
||||
:param evaluate: 页面加载后运行的 js
|
||||
:param query_selector: 截图选择器
|
||||
:param file_type: 缓存的文件类型
|
||||
:param ttl: 缓存时间
|
||||
:param caption: 图片描述
|
||||
:param parse_mode: 图片描述解析模式
|
||||
:param filename: 文件名字
|
||||
:return:
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
start_time = loop.time()
|
||||
template = self.get_template(template_name)
|
||||
|
||||
if self.using_preview:
|
||||
preview_url = await self.previewer.get_preview_url(template_name, template_data)
|
||||
logger.debug("调试模板 URL: \n%s", preview_url)
|
||||
|
||||
html = await template.render_async(**template_data)
|
||||
logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
|
||||
|
||||
file_id = await self.html_to_file_id_cache.get_data(html, file_type.name)
|
||||
if file_id and not application_config.debug:
|
||||
logger.debug("%s 命中缓存,返回 file_id[%s]", template_name, file_id)
|
||||
return RenderResult(
|
||||
html=html,
|
||||
photo=file_id,
|
||||
file_type=file_type,
|
||||
cache=self.html_to_file_id_cache,
|
||||
ttl=ttl,
|
||||
caption=caption,
|
||||
parse_mode=parse_mode,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
browser = await self._browser.get_browser()
|
||||
start_time = loop.time()
|
||||
page = await browser.new_page(viewport=viewport)
|
||||
uri = (PROJECT_ROOT / template.filename).as_uri()
|
||||
await page.goto(uri)
|
||||
await page.set_content(html, wait_until="networkidle")
|
||||
if evaluate:
|
||||
await page.evaluate(evaluate)
|
||||
clip = None
|
||||
if query_selector:
|
||||
try:
|
||||
card = await page.query_selector(query_selector)
|
||||
if not card:
|
||||
raise QuerySelectorNotFound
|
||||
clip = await card.bounding_box()
|
||||
if not clip:
|
||||
raise QuerySelectorNotFound
|
||||
except QuerySelectorNotFound:
|
||||
logger.warning("未找到 %s 元素", query_selector)
|
||||
png_data = await page.screenshot(clip=clip, full_page=full_page)
|
||||
await page.close()
|
||||
logger.debug("%s 图片渲染使用了 %s", template_name, str(loop.time() - start_time))
|
||||
return RenderResult(
|
||||
html=html,
|
||||
photo=png_data,
|
||||
file_type=file_type,
|
||||
cache=self.html_to_file_id_cache,
|
||||
ttl=ttl,
|
||||
caption=caption,
|
||||
parse_mode=parse_mode,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
|
||||
class TemplatePreviewer(BaseService, load=application_config.webserver.enable and application_config.debug):
|
||||
def __init__(
|
||||
self,
|
||||
template_service: TemplateService,
|
||||
cache: TemplatePreviewCache,
|
||||
web_app: FastAPI,
|
||||
):
|
||||
self.web_app = web_app
|
||||
self.template_service = template_service
|
||||
self.cache = cache
|
||||
self.register_routes()
|
||||
|
||||
async def get_preview_url(self, template: str, data: dict):
|
||||
"""获取预览 URL"""
|
||||
components = urlsplit(application_config.webserver.url)
|
||||
path = urljoin("/preview/", template)
|
||||
query = {}
|
||||
|
||||
# 如果有数据,暂存在 redis 中
|
||||
if data:
|
||||
key = str(uuid4())
|
||||
await self.cache.set_data(key, data)
|
||||
query["key"] = key
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
return components._replace(path=path, query=urlencode(query)).geturl()
|
||||
|
||||
def register_routes(self):
|
||||
"""注册预览用到的路由"""
|
||||
|
||||
@self.web_app.get("/preview/{path:path}")
|
||||
async def preview_template(path: str, key: Optional[str] = None): # pylint: disable=W0612
|
||||
# 如果是 /preview/ 开头的静态文件,直接返回内容。比如使用相对链接 ../ 引入的静态资源
|
||||
if not path.endswith((".html", ".jinja2")):
|
||||
full_path = self.template_service.template_dir / path
|
||||
if not full_path.is_file():
|
||||
raise HTTPException(status_code=404, detail=f"Template '{path}' not found")
|
||||
return FileResponse(full_path)
|
||||
|
||||
# 取回暂存的渲染数据
|
||||
data = await self.cache.get_data(key) if key else {}
|
||||
if key and data is None:
|
||||
raise HTTPException(status_code=404, detail=f"Template data {key} not found")
|
||||
|
||||
# 渲染 jinja2 模板
|
||||
html = await self.template_service.render_async(path, data)
|
||||
# 将本地 URL file:// 修改为 HTTP url,因为浏览器内不允许加载本地文件
|
||||
# file:///project_dir/cache/image.jpg => /cache/image.jpg
|
||||
html = html.replace(PROJECT_ROOT.as_uri(), "")
|
||||
return HTMLResponse(html)
|
||||
|
||||
# 其他静态资源
|
||||
for name in ["cache", "resources"]:
|
||||
directory = PROJECT_ROOT / name
|
||||
directory.mkdir(exist_ok=True)
|
||||
self.web_app.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name)
|
||||
|
@ -1,24 +1,3 @@
|
||||
from typing import List
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.redisdb import RedisDB
|
||||
from gram_core.services.users.cache import UserAdminCache
|
||||
|
||||
__all__ = ("UserAdminCache",)
|
||||
|
||||
|
||||
class UserAdminCache(BaseService.Component):
|
||||
def __init__(self, redis: RedisDB):
|
||||
self.client = redis.client
|
||||
self.qname = "users:admin"
|
||||
|
||||
async def ismember(self, user_id: int) -> bool:
|
||||
return await self.client.sismember(self.qname, user_id)
|
||||
|
||||
async def get_all(self) -> List[int]:
|
||||
return [int(str_data) for str_data in await self.client.smembers(self.qname)]
|
||||
|
||||
async def set(self, user_id: int) -> bool:
|
||||
return await self.client.sadd(self.qname, user_id)
|
||||
|
||||
async def remove(self, user_id: int) -> bool:
|
||||
return await self.client.srem(self.qname, user_id)
|
||||
|
@ -1,34 +1,7 @@
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import SQLModel, Field, DateTime, Column, Enum, BigInteger, Integer
|
||||
from gram_core.services.users.models import User, UserDataBase, PermissionsEnum
|
||||
|
||||
__all__ = (
|
||||
"User",
|
||||
"UserDataBase",
|
||||
"PermissionsEnum",
|
||||
)
|
||||
|
||||
|
||||
class PermissionsEnum(int, enum.Enum):
|
||||
OWNER = 1
|
||||
ADMIN = 2
|
||||
PUBLIC = 3
|
||||
|
||||
|
||||
class User(SQLModel):
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
id: Optional[int] = Field(
|
||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||
)
|
||||
user_id: int = Field(unique=True, sa_column=Column(BigInteger()))
|
||||
permissions: Optional[PermissionsEnum] = Field(sa_column=Column(Enum(PermissionsEnum)))
|
||||
locale: Optional[str] = Field()
|
||||
ban_end_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
|
||||
ban_start_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
|
||||
is_banned: Optional[int] = Field()
|
||||
|
||||
|
||||
class UserDataBase(User, table=True):
|
||||
__tablename__ = "users"
|
||||
|
@ -1,44 +1,3 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.database import Database
|
||||
from core.services.users.models import UserDataBase as User
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
from gram_core.services.users.repositories import UserRepository
|
||||
|
||||
__all__ = ("UserRepository",)
|
||||
|
||||
|
||||
class UserRepository(BaseService.Component):
|
||||
def __init__(self, database: Database):
|
||||
self.engine = database.engine
|
||||
|
||||
async def get_by_user_id(self, user_id: int) -> Optional[User]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(User).where(User.user_id == user_id)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def add(self, user: User):
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
async def update(self, user: User) -> User:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
async def remove(self, user: User):
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
|
||||
async def get_all(self) -> List[User]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(User)
|
||||
results = await session.exec(statement)
|
||||
return results.all()
|
||||
|
@ -1,83 +1,3 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.config import config
|
||||
from core.services.users.cache import UserAdminCache
|
||||
from core.services.users.models import PermissionsEnum, UserDataBase as User
|
||||
from core.services.users.repositories import UserRepository
|
||||
from gram_core.services.users.services import UserService, UserAdminService
|
||||
|
||||
__all__ = ("UserService", "UserAdminService")
|
||||
|
||||
from utils.log import logger
|
||||
|
||||
|
||||
class UserService(BaseService):
|
||||
def __init__(self, user_repository: UserRepository) -> None:
|
||||
self._repository: UserRepository = user_repository
|
||||
|
||||
async def get_user_by_id(self, user_id: int) -> Optional[User]:
|
||||
"""从数据库获取用户信息
|
||||
:param user_id:用户ID
|
||||
:return: User
|
||||
"""
|
||||
return await self._repository.get_by_user_id(user_id)
|
||||
|
||||
async def remove(self, user: User):
|
||||
return await self._repository.remove(user)
|
||||
|
||||
async def update_user(self, user: User):
|
||||
return await self._repository.add(user)
|
||||
|
||||
|
||||
class UserAdminService(BaseService):
|
||||
def __init__(self, user_repository: UserRepository, cache: UserAdminCache):
|
||||
self.user_repository = user_repository
|
||||
self._cache = cache
|
||||
|
||||
async def initialize(self):
|
||||
owner = config.owner
|
||||
if owner:
|
||||
user = await self.user_repository.get_by_user_id(owner)
|
||||
if user:
|
||||
if user.permissions != PermissionsEnum.OWNER:
|
||||
user.permissions = PermissionsEnum.OWNER
|
||||
await self._cache.set(user.user_id)
|
||||
await self.user_repository.update(user)
|
||||
else:
|
||||
user = User(user_id=owner, permissions=PermissionsEnum.OWNER)
|
||||
await self._cache.set(user.user_id)
|
||||
await self.user_repository.add(user)
|
||||
else:
|
||||
logger.warning("检测到未配置Bot所有者 会导无法正常使用管理员权限")
|
||||
users = await self.user_repository.get_all()
|
||||
for user in users:
|
||||
await self._cache.set(user.user_id)
|
||||
|
||||
async def is_admin(self, user_id: int) -> bool:
|
||||
return await self._cache.ismember(user_id)
|
||||
|
||||
async def get_admin_list(self) -> List[int]:
|
||||
return await self._cache.get_all()
|
||||
|
||||
async def add_admin(self, user_id: int) -> bool:
|
||||
user = await self.user_repository.get_by_user_id(user_id)
|
||||
if user:
|
||||
if user.permissions == PermissionsEnum.OWNER:
|
||||
return False
|
||||
if user.permissions != PermissionsEnum.ADMIN:
|
||||
user.permissions = PermissionsEnum.ADMIN
|
||||
await self.user_repository.update(user)
|
||||
else:
|
||||
user = User(user_id=user_id, permissions=PermissionsEnum.ADMIN)
|
||||
await self.user_repository.add(user)
|
||||
return await self._cache.set(user_id)
|
||||
|
||||
async def delete_admin(self, user_id: int) -> bool:
|
||||
user = await self.user_repository.get_by_user_id(user_id)
|
||||
if user:
|
||||
if user.permissions == PermissionsEnum.OWNER:
|
||||
return True # 假装移除成功
|
||||
user.permissions = PermissionsEnum.PUBLIC
|
||||
await self.user_repository.update(user)
|
||||
return await self._cache.remove(user.user_id)
|
||||
return False
|
||||
|
@ -1,118 +1,3 @@
|
||||
from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload
|
||||
|
||||
from sqlalchemy import util
|
||||
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
|
||||
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
|
||||
from sqlalchemy.sql.base import Executable as _Executable
|
||||
from sqlmodel.engine.result import Result, ScalarResult
|
||||
from sqlmodel.orm.session import Session
|
||||
from sqlmodel.sql.base import Executable
|
||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||
from typing_extensions import Literal
|
||||
|
||||
_TSelectParam = TypeVar("_TSelectParam")
|
||||
from gram_core.sqlmodel.session import AsyncSession
|
||||
|
||||
__all__ = ("AsyncSession",)
|
||||
|
||||
|
||||
class AsyncSession(_AsyncSession): # pylint: disable=W0223
|
||||
sync_session_class = Session
|
||||
sync_session: Session
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
|
||||
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
|
||||
sync_session_class: Type[Session] = Session,
|
||||
**kw: Any,
|
||||
):
|
||||
super().__init__(
|
||||
bind=bind,
|
||||
binds=binds,
|
||||
sync_session_class=sync_session_class,
|
||||
**kw,
|
||||
)
|
||||
|
||||
@overload
|
||||
async def exec(
|
||||
self,
|
||||
statement: Select[_TSelectParam],
|
||||
*,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
**kw: Any,
|
||||
) -> Result[_TSelectParam]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def exec(
|
||||
self,
|
||||
statement: SelectOfScalar[_TSelectParam],
|
||||
*,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
**kw: Any,
|
||||
) -> ScalarResult[_TSelectParam]:
|
||||
...
|
||||
|
||||
async def exec(
|
||||
self,
|
||||
statement: Union[
|
||||
Select[_TSelectParam],
|
||||
SelectOfScalar[_TSelectParam],
|
||||
Executable[_TSelectParam],
|
||||
],
|
||||
*,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
**kw: Any,
|
||||
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
|
||||
results = super().execute(
|
||||
statement,
|
||||
params=params,
|
||||
execution_options=execution_options,
|
||||
bind_arguments=bind_arguments,
|
||||
**kw,
|
||||
)
|
||||
if isinstance(statement, SelectOfScalar):
|
||||
return (await results).scalars() # type: ignore
|
||||
return await results # type: ignore
|
||||
|
||||
async def execute( # pylint: disable=W0221
|
||||
self,
|
||||
statement: _Executable,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
**kw: Any,
|
||||
) -> Result[Any]:
|
||||
return await super().execute( # type: ignore
|
||||
statement=statement,
|
||||
params=params,
|
||||
execution_options=execution_options,
|
||||
bind_arguments=bind_arguments,
|
||||
**kw,
|
||||
)
|
||||
|
||||
async def get( # pylint: disable=W0221
|
||||
self,
|
||||
entity: Type[_TSelectParam],
|
||||
ident: Any,
|
||||
options: Optional[Sequence[Any]] = None,
|
||||
populate_existing: bool = False,
|
||||
with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
|
||||
identity_token: Optional[Any] = None,
|
||||
execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT,
|
||||
) -> Optional[_TSelectParam]:
|
||||
return await super().get(
|
||||
entity=entity,
|
||||
ident=ident,
|
||||
options=options,
|
||||
populate_existing=populate_existing,
|
||||
with_for_update=with_for_update,
|
||||
identity_token=identity_token,
|
||||
execution_options=execution_options,
|
||||
)
|
||||
|
17
poetry.lock
generated
17
poetry.lock
generated
@ -920,6 +920,21 @@ files = [
|
||||
[package.dependencies]
|
||||
gitdb = ">=4.0.1,<5"
|
||||
|
||||
[[package]]
|
||||
name = "gram-core"
|
||||
version = "0.1.0"
|
||||
description = "telegram robot base core."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/PaiGramTeam/GramCore.git"
|
||||
reference = "HEAD"
|
||||
resolved_reference = "7fb5d4c0e01731e6901829fe317be19023c2c4c7"
|
||||
|
||||
[[package]]
|
||||
name = "greenlet"
|
||||
version = "1.1.3"
|
||||
@ -2800,4 +2815,4 @@ test = ["flaky", "pytest", "pytest-asyncio"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.8"
|
||||
content-hash = "844402570472817cfd7937dd651a35f01e45f2e7e0d0381e540b357175b73899"
|
||||
content-hash = "ea8005da6cf4ff3c982a5f7b98491328c507c97dcfcd46500f14c74a3621f788"
|
||||
|
@ -45,6 +45,7 @@ pillow = "^10.0.0"
|
||||
playwright = "^1.27.1"
|
||||
aiosqlite = { extras = ["sqlite"], version = "^0.19.0" }
|
||||
simnet = { git = "https://github.com/PaiGramTeam/SIMNet" }
|
||||
gram-core = {git = "https://github.com/PaiGramTeam/GramCore.git"}
|
||||
|
||||
[tool.poetry.extras]
|
||||
pyro = ["Pyrogram", "TgCrypto"]
|
||||
|
@ -8,7 +8,7 @@ anyio==3.7.1 ; python_version >= "3.8" and python_version < "4.0"
|
||||
appdirs==1.4.4 ; python_version >= "3.8" and python_version < "4.0"
|
||||
apscheduler==3.10.1 ; python_version >= "3.8" and python_version < "4.0"
|
||||
arko-wrapper==0.2.8 ; python_version >= "3.8" and python_version < "4.0"
|
||||
async-lru==2.0.3 ; python_version >= "3.8" and python_version < "4.0"
|
||||
async-lru==2.0.4 ; python_version >= "3.8" and python_version < "4.0"
|
||||
async-timeout==4.0.2 ; python_version >= "3.8" and python_version < "4.0"
|
||||
asyncmy==0.2.8 ; python_version >= "3.8" and python_version < "4.0"
|
||||
attrs==23.1.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
@ -26,12 +26,13 @@ cryptography==41.0.2 ; python_version >= "3.8" and python_version < "4.0"
|
||||
enkanetwork-py @ git+https://github.com/mrwan200/EnkaNetwork.py@master ; python_version >= "3.8" and python_version < "4.0"
|
||||
et-xmlfile==1.1.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
exceptiongroup==1.1.2 ; python_version >= "3.8" and python_version < "3.11"
|
||||
fakeredis==2.16.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
fakeredis==2.17.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
fastapi==0.99.1 ; python_version >= "3.8" and python_version < "4.0"
|
||||
flaky==3.7.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
frozenlist==1.4.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
gitdb==4.0.10 ; python_version >= "3.8" and python_version < "4.0"
|
||||
gitpython==3.1.32 ; python_version >= "3.8" and python_version < "4.0"
|
||||
gram-core @ git+https://github.com/PaiGramTeam/GramCore.git@main ; python_version >= "3.8" and python_version < "4.0"
|
||||
greenlet==1.1.3 ; python_version >= "3.8" and python_version < "4.0"
|
||||
h11==0.14.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
httpcore==0.17.3 ; python_version >= "3.8" and python_version < "4.0"
|
||||
@ -73,8 +74,8 @@ pytz==2023.3 ; python_version >= "3.8" and python_version < "4.0"
|
||||
pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "4.0"
|
||||
qrcode==7.4.2 ; python_version >= "3.8" and python_version < "4.0"
|
||||
redis==4.6.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
rich==13.4.2 ; python_version >= "3.8" and python_version < "4.0"
|
||||
sentry-sdk==1.28.1 ; python_version >= "3.8" and python_version < "4.0"
|
||||
rich==13.5.1 ; python_version >= "3.8" and python_version < "4.0"
|
||||
sentry-sdk==1.29.2 ; python_version >= "3.8" and python_version < "4.0"
|
||||
setuptools==68.0.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
simnet @ git+https://github.com/PaiGramTeam/SIMNet@main ; python_version >= "3.8" and python_version < "4.0"
|
||||
six==1.16.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
@ -96,7 +97,7 @@ tzdata==2023.3 ; python_version >= "3.8" and python_version < "4.0" and platform
|
||||
tzlocal==5.0.1 ; python_version >= "3.8" and python_version < "4.0"
|
||||
ujson==5.8.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
urllib3==1.26.16 ; python_version >= "3.8" and python_version < "4.0"
|
||||
uvicorn[standard]==0.22.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
uvicorn[standard]==0.23.1 ; python_version >= "3.8" and python_version < "4.0"
|
||||
uvloop==0.17.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.8" and python_version < "4.0"
|
||||
watchfiles==0.19.0 ; python_version >= "3.8" and python_version < "4.0"
|
||||
websockets==10.4 ; python_version >= "3.8" and python_version < "4.0"
|
||||
|
Loading…
Reference in New Issue
Block a user