♻️ separate core code

Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
This commit is contained in:
omg-xtao 2023-08-02 20:11:35 +08:00 committed by GitHub
parent af2e9bdb9b
commit 865f29bd77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
56 changed files with 159 additions and 4849 deletions

View File

@ -1,289 +1,5 @@
"""BOT""" """BOT"""
import asyncio
import signal
from functools import wraps
from signal import SIGABRT, SIGINT, SIGTERM, signal as signal_func
from ssl import SSLZeroReturnError
from typing import Callable, List, Optional, TYPE_CHECKING, TypeVar
import pytz from gram_core.application import Application
import uvicorn
from fastapi import FastAPI
from telegram import Bot, Update
from telegram.error import NetworkError, TelegramError, TimedOut
from telegram.ext import (
Application as TelegramApplication,
ApplicationBuilder as TelegramApplicationBuilder,
Defaults,
JobQueue,
)
from typing_extensions import ParamSpec
from uvicorn import Server
from core.config import config as application_config
from core.handler.limiterhandler import LimiterHandler
from core.manager import Managers
from core.override.telegram import HTTPXRequest
from core.ratelimiter import RateLimiter
from utils.const import WRAPPER_ASSIGNMENTS
from utils.log import logger
from utils.models.signal import Singleton
if TYPE_CHECKING:
from asyncio import Task
from types import FrameType
__all__ = ("Application",) __all__ = ("Application",)
R = TypeVar("R")
T = TypeVar("T")
P = ParamSpec("P")
class Application(Singleton):
"""Application"""
_web_server_task: Optional["Task"] = None
_startup_funcs: List[Callable] = []
_shutdown_funcs: List[Callable] = []
def __init__(self, managers: "Managers", telegram: "TelegramApplication", web_server: "Server") -> None:
self._running = False
self.managers = managers
self.telegram = telegram
self.web_server = web_server
self.managers.set_application(application=self) # 给 managers 设置 application
self.managers.build_executor("Application")
@classmethod
def build(cls):
managers = Managers()
telegram = (
TelegramApplicationBuilder()
.get_updates_read_timeout(application_config.update_read_timeout)
.get_updates_write_timeout(application_config.update_write_timeout)
.get_updates_connect_timeout(application_config.update_connect_timeout)
.get_updates_pool_timeout(application_config.update_pool_timeout)
.defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai")))
.token(application_config.bot_token)
.request(
HTTPXRequest(
connection_pool_size=application_config.connection_pool_size,
proxy_url=application_config.proxy_url,
read_timeout=application_config.read_timeout,
write_timeout=application_config.write_timeout,
connect_timeout=application_config.connect_timeout,
pool_timeout=application_config.pool_timeout,
)
)
.rate_limiter(RateLimiter())
.build()
)
web_server = Server(
uvicorn.Config(
app=FastAPI(debug=application_config.debug),
port=application_config.webserver.port,
host=application_config.webserver.host,
log_config=None,
)
)
return cls(managers, telegram, web_server)
@property
def running(self) -> bool:
"""bot 是否正在运行"""
with self._lock:
return self._running
@property
def web_app(self) -> FastAPI:
"""fastapi app"""
return self.web_server.config.app
@property
def bot(self) -> Optional[Bot]:
return self.telegram.bot
@property
def job_queue(self) -> Optional[JobQueue]:
return self.telegram.job_queue
async def _on_startup(self) -> None:
for func in self._startup_funcs:
await self.managers.executor(func, block=getattr(func, "block", False))
async def _on_shutdown(self) -> None:
for func in self._shutdown_funcs:
await self.managers.executor(func, block=getattr(func, "block", False))
async def initialize(self):
"""BOT 初始化"""
self.telegram.add_handler(LimiterHandler(limit_time=10), group=-1) # 启用入口洪水限制
await self.managers.start_dependency() # 启动基础服务
await self.managers.init_components() # 实例化组件
await self.managers.start_services() # 启动其他服务
await self.managers.install_plugins() # 安装插件
async def shutdown(self):
"""BOT 关闭"""
await self.managers.uninstall_plugins() # 卸载插件
await self.managers.stop_services() # 终止其他服务
await self.managers.stop_dependency() # 终止基础服务
async def start(self) -> None:
"""启动 BOT"""
logger.info("正在启动 BOT 中...")
def error_callback(exc: TelegramError) -> None:
"""错误信息回调"""
self.telegram.create_task(self.telegram.process_error(error=exc, update=None))
await self.telegram.initialize()
logger.info("[blue]Telegram[/] 初始化成功", extra={"markup": True})
if application_config.webserver.enable: # 如果使用 web app
server_config = self.web_server.config
server_config.setup_event_loop()
if not server_config.loaded:
server_config.load()
self.web_server.lifespan = server_config.lifespan_class(server_config)
try:
await self.web_server.startup()
except OSError as e:
if e.errno == 10048:
logger.error("Web Server 端口被占用:%s", e)
logger.error("Web Server 启动失败,正在退出")
raise SystemExit from None
if self.web_server.should_exit:
logger.error("Web Server 启动失败,正在退出")
raise SystemExit from None
logger.success("Web Server 启动成功")
self._web_server_task = asyncio.create_task(self.web_server.main_loop())
for _ in range(5): # 连接至 telegram 服务器
try:
await self.telegram.updater.start_polling(
error_callback=error_callback, allowed_updates=Update.ALL_TYPES
)
break
except TimedOut:
logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
continue
except NetworkError as e:
logger.exception()
if isinstance(e, SSLZeroReturnError):
logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
else:
logger.error("网络连接出现问题, 请检查您的网络状况.")
raise SystemExit from e
await self.initialize()
logger.success("BOT 初始化成功")
logger.debug("BOT 开始启动")
await self._on_startup()
await self.telegram.start()
self._running = True
logger.success("BOT 启动成功")
def stop_signal_handler(self, signum: int):
"""终止信号处理"""
signals = {k: v for v, k in signal.__dict__.items() if v.startswith("SIG") and not v.startswith("SIG_")}
logger.debug("接收到了终止信号 %s 正在退出...", signals[signum])
if self._web_server_task:
self._web_server_task.cancel()
async def idle(self) -> None:
"""在接收到中止信号之前堵塞loop"""
task = None
def stop_handler(signum: int, _: "FrameType") -> None:
self.stop_signal_handler(signum)
task.cancel()
for s in (SIGINT, SIGTERM, SIGABRT):
signal_func(s, stop_handler)
while True:
task = asyncio.create_task(asyncio.sleep(600))
try:
await task
except asyncio.CancelledError:
break
async def stop(self) -> None:
"""关闭"""
logger.info("BOT 正在关闭")
self._running = False
await self._on_shutdown()
if self.telegram.updater.running:
await self.telegram.updater.stop()
await self.shutdown()
if self.telegram.running:
await self.telegram.stop()
await self.telegram.shutdown()
if self.web_server is not None:
try:
await self.web_server.shutdown()
logger.info("Web Server 已经关闭")
except AttributeError:
pass
logger.success("BOT 关闭成功")
def launch(self) -> None:
"""启动"""
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(self.start())
loop.run_until_complete(self.idle())
except (SystemExit, KeyboardInterrupt) as exc:
logger.debug("接收到了终止信号BOT 即将关闭", exc_info=exc) # 接收到了终止信号
except NetworkError as e:
if isinstance(e, SSLZeroReturnError):
logger.critical("代理服务出现异常, 请检查您的代理服务是否配置成功.")
else:
logger.critical("网络连接出现问题, 请检查您的网络状况.")
except Exception as e:
logger.critical("遇到了未知错误: %s", {type(e)}, exc_info=e)
finally:
loop.run_until_complete(self.stop())
if application_config.reload:
raise SystemExit from None
def on_startup(self, func: Callable[P, R]) -> Callable[P, R]:
"""注册一个在 BOT 启动时执行的函数"""
if func not in self._startup_funcs:
self._startup_funcs.append(func)
# noinspection PyTypeChecker
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)
return wrapper
def on_shutdown(self, func: Callable[P, R]) -> Callable[P, R]:
"""注册一个在 BOT 停止时执行的函数"""
if func not in self._shutdown_funcs:
self._shutdown_funcs.append(func)
# noinspection PyTypeChecker
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)
return wrapper

View File

@ -1,60 +1,3 @@
from abc import ABC from gram_core.base_service import BaseService, BaseServiceType, DependenceType, ComponentType, get_all_services
from itertools import chain
from typing import ClassVar, Iterable, Type, TypeVar
from typing_extensions import Self
from utils.helpers import isabstract
__all__ = ("BaseService", "BaseServiceType", "DependenceType", "ComponentType", "get_all_services") __all__ = ("BaseService", "BaseServiceType", "DependenceType", "ComponentType", "get_all_services")
class _BaseService:
"""服务基类"""
_is_component: ClassVar[bool] = False
_is_dependence: ClassVar[bool] = False
def __init_subclass__(cls, load: bool = True, **kwargs):
cls.is_dependence = cls._is_dependence
cls.is_component = cls._is_component
cls.load = load
async def __aenter__(self) -> Self:
await self.initialize()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()
async def initialize(self) -> None:
"""Initialize resources used by this service"""
async def shutdown(self) -> None:
"""Stop & clear resources used by this service"""
class _Dependence(_BaseService, ABC):
_is_dependence: ClassVar[bool] = True
class _Component(_BaseService, ABC):
_is_component: ClassVar[bool] = True
class BaseService(_BaseService, ABC):
Dependence: Type[_BaseService] = _Dependence
Component: Type[_BaseService] = _Component
BaseServiceType = TypeVar("BaseServiceType", bound=_BaseService)
DependenceType = TypeVar("DependenceType", bound=_Dependence)
ComponentType = TypeVar("ComponentType", bound=_Component)
# noinspection PyProtectedMember
def get_all_services() -> Iterable[Type[_BaseService]]:
return filter(
lambda x: x.__name__[0] != "_" and x.load and not isabstract(x),
chain(BaseService.__subclasses__(), _Dependence.__subclasses__(), _Component.__subclasses__()),
)

View File

@ -1,29 +1,3 @@
import enum from gram_core.basemodel import RegionEnum, Settings
try:
import ujson as jsonlib
except ImportError:
import json as jsonlib
from pydantic import BaseSettings
__all__ = ("RegionEnum", "Settings") __all__ = ("RegionEnum", "Settings")
class RegionEnum(int, enum.Enum):
"""账号数据所在服务器"""
NULL = 0
HYPERION = 1 # 米忽悠国服 hyperion
HOYOLAB = 2 # 米忽悠国际服 hoyolab
class Settings(BaseSettings):
def __new__(cls, *args, **kwargs):
cls.update_forward_refs()
return super(Settings, cls).__new__(cls) # pylint: disable=E1120
class Config(BaseSettings.Config):
case_sensitive = False
json_loads = jsonlib.loads
json_dumps = jsonlib.dumps

View File

@ -1 +0,0 @@
"""bot builtins"""

View File

@ -1,38 +0,0 @@
"""上下文管理"""
from contextlib import contextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from telegram.ext import CallbackContext
from telegram import Update
__all__ = [
"CallbackContextCV",
"UpdateCV",
"handler_contexts",
"job_contexts",
]
CallbackContextCV: ContextVar["CallbackContext"] = ContextVar("TelegramContextCallback")
UpdateCV: ContextVar["Update"] = ContextVar("TelegramUpdate")
@contextmanager
def handler_contexts(update: "Update", context: "CallbackContext") -> None:
context_token = CallbackContextCV.set(context)
update_token = UpdateCV.set(update)
try:
yield
finally:
CallbackContextCV.reset(context_token)
UpdateCV.reset(update_token)
@contextmanager
def job_contexts(context: "CallbackContext") -> None:
token = CallbackContextCV.set(context)
try:
yield
finally:
CallbackContextCV.reset(token)

View File

@ -1,309 +0,0 @@
"""参数分发器"""
import asyncio
import inspect
from abc import ABC, abstractmethod
from asyncio import AbstractEventLoop
from functools import cached_property, lru_cache, partial, wraps
from inspect import Parameter, Signature
from itertools import chain
from types import GenericAlias, MethodType
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
Union,
)
from arkowrapper import ArkoWrapper
from fastapi import FastAPI
from telegram import Bot as TelegramBot, Chat, Message, Update, User
from telegram.ext import Application as TelegramApplication, CallbackContext, Job
from typing_extensions import ParamSpec
from uvicorn import Server
from core.application import Application
from utils.const import WRAPPER_ASSIGNMENTS
from utils.typedefs import R, T
__all__ = (
"catch",
"AbstractDispatcher",
"BaseDispatcher",
"HandlerDispatcher",
"JobDispatcher",
"dispatched",
)
P = ParamSpec("P")
TargetType = Union[Type, str, Callable[[Any], bool]]
_CATCH_TARGET_ATTR = "_catch_targets"
def catch(*targets: Union[str, Type]) -> Callable[[Callable[P, R]], Callable[P, R]]:
def decorate(func: Callable[P, R]) -> Callable[P, R]:
setattr(func, _CATCH_TARGET_ATTR, targets)
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)
return wrapper
return decorate
@lru_cache(64)
def get_signature(func: Union[type, Callable]) -> Signature:
if isinstance(func, type):
return inspect.signature(func.__init__)
return inspect.signature(func)
class AbstractDispatcher(ABC):
"""参数分发器"""
IGNORED_ATTRS = []
_args: List[Any] = []
_kwargs: Dict[Union[str, Type], Any] = {}
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
return self._application
def __init__(self, *args, **kwargs) -> None:
self._args = list(args)
self._kwargs = dict(kwargs)
for _, value in kwargs.items():
type_arg = type(value)
if type_arg != str:
self._kwargs[type_arg] = value
for arg in args:
type_arg = type(arg)
if type_arg != str:
self._kwargs[type_arg] = arg
@cached_property
def catch_funcs(self) -> List[MethodType]:
# noinspection PyTypeChecker
return list(
ArkoWrapper(dir(self))
.filter(lambda x: not x.startswith("_"))
.filter(
lambda x: x not in self.IGNORED_ATTRS + ["dispatch", "catch_funcs", "catch_func_map", "dispatch_funcs"]
)
.map(lambda x: getattr(self, x))
.filter(lambda x: isinstance(x, MethodType))
.filter(lambda x: hasattr(x, "_catch_targets"))
)
@cached_property
def catch_func_map(self) -> Dict[Union[str, Type[T]], Callable[..., T]]:
result = {}
for catch_func in self.catch_funcs:
catch_targets = getattr(catch_func, _CATCH_TARGET_ATTR)
for catch_target in catch_targets:
result[catch_target] = catch_func
return result
@cached_property
def dispatch_funcs(self) -> List[MethodType]:
return list(
ArkoWrapper(dir(self))
.filter(lambda x: x.startswith("dispatch_by_"))
.map(lambda x: getattr(self, x))
.filter(lambda x: isinstance(x, MethodType))
)
@abstractmethod
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
"""默认的 dispatch 方法"""
@abstractmethod
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
"""使用 catch_func 获取并分配参数"""
def dispatch(self, func: Callable[P, R]) -> Callable[..., R]:
"""将参数分配给函数,从而合成一个无需参数即可执行的函数"""
params = {}
signature = get_signature(func)
parameters: Dict[str, Parameter] = dict(signature.parameters)
for name, parameter in list(parameters.items()):
parameter: Parameter
if any(
[
name == "self" and isinstance(func, (type, MethodType)),
parameter.kind in [Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL],
]
):
del parameters[name]
continue
for dispatch_func in self.dispatch_funcs:
parameters[name] = dispatch_func(parameter)
for name, parameter in parameters.items():
if parameter.default != Parameter.empty:
params[name] = parameter.default
else:
params[name] = None
return partial(func, **params)
@catch(Application)
def catch_application(self) -> Application:
return self.application
class BaseDispatcher(AbstractDispatcher):
"""默认参数分发器"""
_instances: Sequence[Any]
def _get_kwargs(self) -> Dict[Type[T], T]:
result = self._get_default_kwargs()
result[AbstractDispatcher] = self
result.update(self._kwargs)
return result
def _get_default_kwargs(self) -> Dict[Type[T], T]:
application = self.application
_default_kwargs = {
FastAPI: application.web_app,
Server: application.web_server,
TelegramApplication: application.telegram,
TelegramBot: application.telegram.bot,
}
if not application.running:
for obj in chain(
application.managers.dependency,
application.managers.components,
application.managers.services,
application.managers.plugins,
):
_default_kwargs[type(obj)] = obj
return {k: v for k, v in _default_kwargs.items() if v is not None}
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
annotation = parameter.annotation
# noinspection PyTypeChecker
if isinstance(annotation, type) and (value := self._get_kwargs().get(annotation, None)) is not None:
parameter._default = value # pylint: disable=W0212
return parameter
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
annotation = parameter.annotation
if annotation != Any and isinstance(annotation, GenericAlias):
return parameter
catch_func = self.catch_func_map.get(annotation) or self.catch_func_map.get(parameter.name)
if catch_func is not None:
# noinspection PyUnresolvedReferences,PyProtectedMember
parameter._default = catch_func() # pylint: disable=W0212
return parameter
@catch(AbstractEventLoop)
def catch_loop(self) -> AbstractEventLoop:
return asyncio.get_event_loop()
class HandlerDispatcher(BaseDispatcher):
"""Handler 参数分发器"""
def __init__(self, update: Optional[Update] = None, context: Optional[CallbackContext] = None, **kwargs) -> None:
super().__init__(update=update, context=context, **kwargs)
self._update = update
self._context = context
def dispatch(
self, func: Callable[P, R], *, update: Optional[Update] = None, context: Optional[CallbackContext] = None
) -> Callable[..., R]:
self._update = update or self._update
self._context = context or self._context
if self._update is None:
from core.builtins.contexts import UpdateCV
self._update = UpdateCV.get()
if self._context is None:
from core.builtins.contexts import CallbackContextCV
self._context = CallbackContextCV.get()
return super().dispatch(func)
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
"""HandlerDispatcher 默认不使用 dispatch_by_default"""
return parameter
@catch(Update)
def catch_update(self) -> Update:
return self._update
@catch(CallbackContext)
def catch_context(self) -> CallbackContext:
return self._context
@catch(Message)
def catch_message(self) -> Message:
return self._update.effective_message
@catch(User)
def catch_user(self) -> User:
return self._update.effective_user
@catch(Chat)
def catch_chat(self) -> Chat:
return self._update.effective_chat
class JobDispatcher(BaseDispatcher):
"""Job 参数分发器"""
def __init__(self, context: Optional[CallbackContext] = None, **kwargs) -> None:
super().__init__(context=context, **kwargs)
self._context = context
def dispatch(self, func: Callable[P, R], *, context: Optional[CallbackContext] = None) -> Callable[..., R]:
self._context = context or self._context
if self._context is None:
from core.builtins.contexts import CallbackContextCV
self._context = CallbackContextCV.get()
return super().dispatch(func)
@catch("data")
def catch_data(self) -> Any:
return self._context.job.data
@catch(Job)
def catch_job(self) -> Job:
return self._context.job
@catch(CallbackContext)
def catch_context(self) -> CallbackContext:
return self._context
def dispatched(dispatcher: Type[AbstractDispatcher] = BaseDispatcher):
def decorate(func: Callable[P, R]) -> Callable[P, R]:
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return dispatcher().dispatch(func)(*args, **kwargs)
return wrapper
return decorate

View File

@ -1,131 +0,0 @@
"""执行器"""
import inspect
from functools import cached_property
from multiprocessing import RLock as Lock
from typing import Callable, ClassVar, Dict, Generic, Optional, TYPE_CHECKING, Type, TypeVar
from telegram import Update
from telegram.ext import CallbackContext
from typing_extensions import ParamSpec, Self
from core.builtins.contexts import handler_contexts, job_contexts
if TYPE_CHECKING:
from core.application import Application
from core.builtins.dispatcher import AbstractDispatcher, HandlerDispatcher
from multiprocessing.synchronize import RLock as LockType
__all__ = ("BaseExecutor", "Executor", "HandlerExecutor", "JobExecutor")
T = TypeVar("T")
R = TypeVar("R")
P = ParamSpec("P")
class BaseExecutor:
"""执行器
Args:
name(str): 该执行器的名称执行器的名称是唯一的
只支持执行只拥有 POSITIONAL_OR_KEYWORD KEYWORD_ONLY 两种参数类型的函数
"""
_lock: ClassVar["LockType"] = Lock()
_instances: ClassVar[Dict[str, Self]] = {}
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
return self._application
def __new__(cls: Type[T], name: str, *args, **kwargs) -> T:
with cls._lock:
if (instance := cls._instances.get(name)) is None:
instance = object.__new__(cls)
instance.__init__(name, *args, **kwargs)
cls._instances.update({name: instance})
return instance
@cached_property
def name(self) -> str:
"""当前执行器的名称"""
return self._name
def __init__(self, name: str, dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
self._name = name
self._dispatcher = dispatcher
class Executor(BaseExecutor, Generic[P, R]):
async def __call__(
self,
target: Callable[P, R],
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
**kwargs,
) -> R:
dispatcher = self._dispatcher or dispatcher
dispatcher_instance = dispatcher(**kwargs)
dispatcher_instance.set_application(application=self.application)
dispatched_func = dispatcher_instance.dispatch(target) # 分发参数,组成新函数
# 执行
if inspect.iscoroutinefunction(target):
result = await dispatched_func()
else:
result = dispatched_func()
return result
class HandlerExecutor(BaseExecutor, Generic[P, R]):
"""Handler专用执行器"""
_callback: Callable[P, R]
_dispatcher: "HandlerDispatcher"
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["HandlerDispatcher"]] = None) -> None:
if dispatcher is None:
from core.builtins.dispatcher import HandlerDispatcher
dispatcher = HandlerDispatcher
super().__init__("handler", dispatcher)
self._callback = func
self._dispatcher = dispatcher()
def set_application(self, application: "Application") -> None:
self._application = application
if self._dispatcher is not None:
self._dispatcher.set_application(application)
async def __call__(self, update: Update, context: CallbackContext) -> R:
with handler_contexts(update, context):
dispatched_func = self._dispatcher.dispatch(self._callback, update=update, context=context)
return await dispatched_func()
class JobExecutor(BaseExecutor):
"""Job 专用执行器"""
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
if dispatcher is None:
from core.builtins.dispatcher import JobDispatcher
dispatcher = JobDispatcher
super().__init__("job", dispatcher)
self._callback = func
self._dispatcher = dispatcher()
def set_application(self, application: "Application") -> None:
self._application = application
if self._dispatcher is not None:
self._dispatcher.set_application(application)
async def __call__(self, context: CallbackContext) -> R:
with job_contexts(context):
dispatched_func = self._dispatcher.dispatch(self._callback, context=context)
return await dispatched_func()

View File

@ -1,185 +0,0 @@
import inspect
import multiprocessing
import os
import signal
import threading
from pathlib import Path
from typing import Callable, Iterator, List, Optional, TYPE_CHECKING
from watchfiles import watch
from utils.const import HANDLED_SIGNALS, PROJECT_ROOT
from utils.log import logger
from utils.typedefs import StrOrPath
if TYPE_CHECKING:
from multiprocessing.process import BaseProcess
__all__ = ("Reloader",)
multiprocessing.allow_connection_pickling()
spawn = multiprocessing.get_context("spawn")
class FileFilter:
"""监控文件过滤"""
def __init__(self, includes: List[str], excludes: List[str]) -> None:
default_includes = ["*.py"]
self.includes = [default for default in default_includes if default not in excludes]
self.includes.extend(includes)
self.includes = list(set(self.includes))
default_excludes = [".*", ".py[cod]", ".sw.*", "~*", __file__]
self.excludes = [default for default in default_excludes if default not in includes]
self.exclude_dirs = []
for e in excludes:
p = Path(e)
try:
is_dir = p.is_dir()
except OSError:
is_dir = False
if is_dir:
self.exclude_dirs.append(p)
else:
self.excludes.append(e)
self.excludes = list(set(self.excludes))
def __call__(self, path: Path) -> bool:
for include_pattern in self.includes:
if path.match(include_pattern):
for exclude_dir in self.exclude_dirs:
if exclude_dir in path.parents:
return False
for exclude_pattern in self.excludes:
if path.match(exclude_pattern):
return False
return True
return False
class Reloader:
_target: Callable[..., None]
_process: "BaseProcess"
@property
def process(self) -> "BaseProcess":
return self._process
@property
def target(self) -> Callable[..., None]:
return self._target
def __init__(
self,
target: Callable[..., None],
*,
reload_delay: float = 0.25,
reload_dirs: List[StrOrPath] = None,
reload_includes: List[str] = None,
reload_excludes: List[str] = None,
):
if inspect.iscoroutinefunction(target):
raise ValueError("不支持异步函数")
self._target = target
self.reload_delay = reload_delay
_reload_dirs = []
for reload_dir in reload_dirs or []:
_reload_dirs.append(PROJECT_ROOT.joinpath(Path(reload_dir)))
self.reload_dirs = []
for reload_dir in _reload_dirs:
append = True
for parent in reload_dir.parents:
if parent in _reload_dirs:
append = False
break
if append:
self.reload_dirs.append(reload_dir)
if not self.reload_dirs:
logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"})
self._should_exit = threading.Event()
frame = inspect.currentframe().f_back
self.watch_filter = FileFilter(reload_includes or [], (reload_excludes or []) + [frame.f_globals["__file__"]])
self.watcher = watch(
*self.reload_dirs,
watch_filter=None,
stop_event=self._should_exit,
yield_on_timeout=True,
)
def get_changes(self) -> Optional[List[Path]]:
if not self._process.is_alive():
logger.info("目标进程已经关闭", extra={"tag": "Reloader"})
self._should_exit.set()
try:
changes = next(self.watcher)
except StopIteration:
return None
if changes:
unique_paths = {Path(c[1]) for c in changes}
return [p for p in unique_paths if self.watch_filter(p)]
return None
def __iter__(self) -> Iterator[Optional[List[Path]]]:
return self
def __next__(self) -> Optional[List[Path]]:
return self.get_changes()
def run(self) -> None:
self.startup()
for changes in self:
if changes:
logger.warning(
"检测到文件 %s 发生改变, 正在重载...",
[str(c.relative_to(PROJECT_ROOT)).replace(os.sep, "/") for c in changes],
extra={"tag": "Reloader"},
)
self.restart()
self.shutdown()
def signal_handler(self, *_) -> None:
"""当接收到结束信号量时"""
self._process.join(3)
if self._process.is_alive():
self._process.terminate()
self._process.join()
self._should_exit.set()
def startup(self) -> None:
"""启动进程"""
logger.info("目标进程正在启动", extra={"tag": "Reloader"})
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.signal_handler)
self._process = spawn.Process(target=self._target)
self._process.start()
logger.success("目标进程启动成功", extra={"tag": "Reloader"})
def restart(self) -> None:
"""重启进程"""
self._process.terminate()
self._process.join(10)
self._process = spawn.Process(target=self._target)
self._process.start()
logger.info("目标进程已经重载", extra={"tag": "Reloader"})
def shutdown(self) -> None:
"""关闭进程"""
self._process.terminate()
self._process.join(10)
logger.info("重载器已经关闭", extra={"tag": "Reloader"})

View File

@ -1,165 +1,3 @@
from enum import Enum from gram_core.config import ApplicationConfig, config, JoinGroups
from pathlib import Path
from typing import List, Optional, Union
import dotenv
from pydantic import AnyUrl, Field
from core.basemodel import Settings
from utils.const import PROJECT_ROOT
from utils.typedefs import NaturalNumber
__all__ = ("ApplicationConfig", "config", "JoinGroups") __all__ = ("ApplicationConfig", "config", "JoinGroups")
dotenv.load_dotenv()
class JoinGroups(str, Enum):
NO_ALLOW = "NO_ALLOW"
ALLOW_AUTH_USER = "ALLOW_AUTH_USER"
ALLOW_USER = "ALLOW_USER"
ALLOW_ALL = "ALLOW_ALL"
class DatabaseConfig(Settings):
driver_name: str = "mysql+asyncmy"
host: Optional[str] = None
port: Optional[int] = None
username: Optional[str] = None
password: Optional[str] = None
database: Optional[str] = None
class Config(Settings.Config):
env_prefix = "db_"
class RedisConfig(Settings):
host: str = "127.0.0.1"
port: int = 6379
database: int = Field(default=0, env="redis_db")
password: Optional[str] = None
class Config(Settings.Config):
env_prefix = "redis_"
class LoggerConfig(Settings):
name: str = "PaiGram"
width: Optional[int] = None
time_format: str = "[%Y-%m-%d %X]"
traceback_max_frames: int = 20
path: Path = PROJECT_ROOT / "logs"
render_keywords: List[str] = ["BOT"]
locals_max_length: int = 10
locals_max_string: int = 80
locals_max_depth: Optional[NaturalNumber] = None
filtered_names: List[str] = ["uvicorn"]
class Config(Settings.Config):
env_prefix = "logger_"
class MTProtoConfig(Settings):
api_id: Optional[int] = None
api_hash: Optional[str] = None
class WebServerConfig(Settings):
enable: bool = False
"""是否启用WebServer"""
host: str = "localhost"
port: int = 8080
class Config(Settings.Config):
env_prefix = "web_"
@property
def url(self) -> str:
# noinspection HttpUrlsUsage
return "http://" + self.host + ":" + str(self.port)
class ErrorConfig(Settings):
pb_url: str = ""
pb_sunset: int = 43200
pb_max_lines: int = 1000
sentry_dsn: str = ""
notification_chat_id: Optional[str] = None
class Config(Settings.Config):
env_prefix = "error_"
class ReloadConfig(Settings):
delay: float = 0.25
dirs: List[str] = []
include: List[str] = []
exclude: List[str] = []
class Config(Settings.Config):
env_prefix = "reload_"
class NoticeConfig(Settings):
user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!"
class Config(Settings.Config):
env_prefix = "notice_"
class ApplicationConfig(Settings):
debug: bool = False
"""debug 开关"""
retry: int = 5
"""重试次数"""
auto_reload: bool = False
"""自动重载"""
proxy_url: Optional[AnyUrl] = None
"""代理链接"""
upload_bbs_host: Optional[AnyUrl] = "https://upload-bbs.miyoushe.com"
bot_token: str = ""
"""BOT的token"""
owner: Optional[int] = None
channels: List[int] = []
"""文章推送群组"""
verify_groups: List[Union[int, str]] = []
"""启用群验证功能的群组"""
join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW
"""是否允许机器人被邀请到其它群组"""
timeout: int = 10
connection_pool_size: int = 256
read_timeout: Optional[float] = None
write_timeout: Optional[float] = None
connect_timeout: Optional[float] = None
pool_timeout: Optional[float] = None
update_read_timeout: Optional[float] = None
update_write_timeout: Optional[float] = None
update_connect_timeout: Optional[float] = None
update_pool_timeout: Optional[float] = None
genshin_ttl: Optional[int] = None
enka_network_api_agent: str = ""
pass_challenge_api: str = ""
pass_challenge_app_key: str = ""
pass_challenge_user_web: str = ""
reload: ReloadConfig = ReloadConfig()
database: DatabaseConfig = DatabaseConfig()
logger: LoggerConfig = LoggerConfig()
webserver: WebServerConfig = WebServerConfig()
redis: RedisConfig = RedisConfig()
mtproto: MTProtoConfig = MTProtoConfig()
error: ErrorConfig = ErrorConfig()
notice: NoticeConfig = NoticeConfig()
ApplicationConfig.update_forward_refs()
config = ApplicationConfig()

View File

@ -1,56 +1,3 @@
from typing import Optional, TYPE_CHECKING from gram_core.dependence.aiobrowser import AioBrowser
from playwright.async_api import Error, async_playwright
from core.base_service import BaseService
from utils.log import logger
if TYPE_CHECKING:
from playwright.async_api import Playwright as AsyncPlaywright, Browser
__all__ = ("AioBrowser",) __all__ = ("AioBrowser",)
class AioBrowser(BaseService.Dependence):
@property
def browser(self):
return self._browser
def __init__(self, loop=None):
self._browser: Optional["Browser"] = None
self._playwright: Optional["AsyncPlaywright"] = None
self._loop = loop
async def get_browser(self):
if self._browser is None:
await self.initialize()
return self._browser
async def initialize(self):
if self._playwright is None:
logger.info("正在尝试启动 [blue]Playwright[/]", extra={"markup": True})
self._playwright = await async_playwright().start()
logger.success("[blue]Playwright[/] 启动成功", extra={"markup": True})
if self._browser is None:
logger.info("正在尝试启动 [blue]Browser[/]", extra={"markup": True})
try:
self._browser = await self._playwright.chromium.launch(timeout=5000)
logger.success("[blue]Browser[/] 启动成功", extra={"markup": True})
except Error as err:
if "playwright install" in str(err):
logger.error(
"检查到 [blue]playwright[/] 刚刚安装或者未升级\n"
"请运行以下命令下载新浏览器\n"
"[blue bold]playwright install chromium[/]",
extra={"markup": True},
)
raise RuntimeError("检查到 playwright 刚刚安装或者未升级\n请运行以下命令下载新浏览器\nplaywright install chromium")
raise err
return self._browser
async def shutdown(self):
if self._browser is not None:
await self._browser.close()
if self._playwright is not None:
self._playwright.stop()

View File

@ -1,16 +0,0 @@
from asyncio import AbstractEventLoop
from playwright.async_api import Browser, Playwright as AsyncPlaywright
from core.base_service import BaseService
__all__ = ("AioBrowser",)
class AioBrowser(BaseService.Dependence):
_browser: Browser | None
_playwright: AsyncPlaywright | None
_loop: AbstractEventLoop
@property
def browser(self) -> Browser | None: ...
async def get_browser(self) -> Browser: ...

View File

@ -1,51 +1,3 @@
import contextlib from gram_core.dependence.database import Database
from typing import Optional
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from typing_extensions import Self
from core.base_service import BaseService
from core.config import ApplicationConfig
from core.sqlmodel.session import AsyncSession
__all__ = ("Database",) __all__ = ("Database",)
class Database(BaseService.Dependence):
@classmethod
def from_config(cls, config: ApplicationConfig) -> Self:
return cls(**config.database.dict())
def __init__(
self,
driver_name: str,
host: Optional[str] = None,
port: Optional[int] = None,
username: Optional[str] = None,
password: Optional[str] = None,
database: Optional[str] = None,
):
self.database = database # skipcq: PTC-W0052
self.password = password
self.username = username
self.port = port
self.host = host
self.url = URL.create(
driver_name,
username=self.username,
password=self.password,
host=self.host,
port=self.port,
database=self.database,
)
self.engine = create_async_engine(self.url)
self.Session = sessionmaker(bind=self.engine, class_=AsyncSession)
@contextlib.asynccontextmanager
async def session(self) -> AsyncSession:
yield self.Session()
async def shutdown(self):
self.Session.close_all()

View File

@ -1,67 +1,3 @@
import os from gram_core.dependence.mtproto import MTProto
from typing import Optional
from urllib.parse import urlparse
import aiofiles __all__ = ("MTProto",)
from core.base_service import BaseService
from core.config import config as bot_config
from utils.log import logger
try:
from pyrogram import Client
from pyrogram.session import session
session.log.debug = lambda *args, **kwargs: None # 关闭日记
PYROGRAM_AVAILABLE = True
except ImportError:
Client = None
session = None
PYROGRAM_AVAILABLE = False
class MTProto(BaseService.Dependence):
async def get_session(self):
async with aiofiles.open(self.session_path, mode="r") as f:
return await f.read()
async def set_session(self, b: str):
async with aiofiles.open(self.session_path, mode="w+") as f:
await f.write(b)
def session_exists(self):
return os.path.exists(self.session_path)
def __init__(self):
self.name = "paigram"
current_dir = os.getcwd()
self.session_path = os.path.join(current_dir, "paigram.session")
self.client: Optional[Client] = None
self.proxy: Optional[dict] = None
http_proxy = os.environ.get("HTTP_PROXY")
if http_proxy is not None:
http_proxy_url = urlparse(http_proxy)
self.proxy = {"scheme": "http", "hostname": http_proxy_url.hostname, "port": http_proxy_url.port}
async def initialize(self): # pylint: disable=W0221
if not PYROGRAM_AVAILABLE:
logger.info("MTProto 服务需要的 pyrogram 模块未导入 本次服务 client 为 None")
return
if bot_config.mtproto.api_id is None:
logger.info("MTProto 服务需要的 api_id 未配置 本次服务 client 为 None")
return
if bot_config.mtproto.api_hash is None:
logger.info("MTProto 服务需要的 api_hash 未配置 本次服务 client 为 None")
return
self.client = Client(
api_id=bot_config.mtproto.api_id,
api_hash=bot_config.mtproto.api_hash,
name=self.name,
bot_token=bot_config.bot_token,
proxy=self.proxy,
)
await self.client.start()
async def shutdown(self): # pylint: disable=W0221
if self.client is not None:
await self.client.stop(block=False)

View File

@ -1,31 +0,0 @@
from __future__ import annotations
from typing import TypedDict
from core.base_service import BaseService
try:
from pyrogram import Client
from pyrogram.session import session
PYROGRAM_AVAILABLE = True
except ImportError:
Client = None
session = None
PYROGRAM_AVAILABLE = False
__all__ = ("MTProto",)
class _ProxyType(TypedDict):
scheme: str
hostname: str | None
port: int | None
class MTProto(BaseService.Dependence):
name: str
session_path: str
client: Client | None
proxy: _ProxyType | None
async def get_session(self) -> str: ...
async def set_session(self, b: str) -> None: ...
def session_exists(self) -> bool: ...

View File

@ -1,50 +1,3 @@
from typing import Optional, Union from gram_core.dependence.redisdb import RedisDB
import fakeredis.aioredis
from redis import asyncio as aioredis
from redis.exceptions import ConnectionError as RedisConnectionError, TimeoutError as RedisTimeoutError
from typing_extensions import Self
from core.base_service import BaseService
from core.config import ApplicationConfig
from utils.log import logger
__all__ = ["RedisDB"] __all__ = ["RedisDB"]
class RedisDB(BaseService.Dependence):
@classmethod
def from_config(cls, config: ApplicationConfig) -> Self:
return cls(**config.redis.dict())
def __init__(
self, host: str = "127.0.0.1", port: int = 6379, database: Union[str, int] = 0, password: Optional[str] = None
):
self.client = aioredis.Redis(host=host, port=port, db=database, password=password)
self.ttl = 600
async def ping(self):
# noinspection PyUnresolvedReferences
if await self.client.ping():
logger.info("连接 [red]Redis[/] 成功", extra={"markup": True})
else:
logger.info("连接 [red]Redis[/] 失败", extra={"markup": True})
raise RuntimeError("连接 Redis 失败")
async def start_fake_redis(self):
self.client = fakeredis.aioredis.FakeRedis()
await self.ping()
async def initialize(self):
logger.info("正在尝试建立与 [red]Redis[/] 连接", extra={"markup": True})
try:
await self.ping()
except (RedisTimeoutError, RedisConnectionError) as exc:
if isinstance(exc, RedisTimeoutError):
logger.warning("连接 [red]Redis[/] 超时,使用 [red]fakeredis[/] 模拟", extra={"markup": True})
if isinstance(exc, RedisConnectionError):
logger.warning("连接 [red]Redis[/] 失败,使用 [red]fakeredis[/] 模拟", extra={"markup": True})
await self.start_fake_redis()
async def shutdown(self):
await self.client.close()

View File

@ -1,7 +1,4 @@
"""此模块包含核心模块的错误的基类""" """此模块包含核心模块的错误的基类"""
from typing import Union from gram_core.error import ServiceNotFoundError
__all__ = ("ServiceNotFoundError",)
class ServiceNotFoundError(Exception):
def __init__(self, name: Union[str, type]):
super().__init__(f"No service named '{name if isinstance(name, str) else name.__name__}'")

View File

@ -1,59 +1,3 @@
import asyncio from gram_core.handler.adminhandler import AdminHandler
from typing import TypeVar, TYPE_CHECKING, Any, Optional
from telegram import Update __all__ = ("AdminHandler",)
from telegram.ext import ApplicationHandlerStop, BaseHandler
from core.error import ServiceNotFoundError
from core.services.users.services import UserAdminService
from utils.log import logger
if TYPE_CHECKING:
from core.application import Application
from telegram.ext import Application as TelegramApplication
RT = TypeVar("RT")
UT = TypeVar("UT")
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
class AdminHandler(BaseHandler[Update, CCT]):
_lock = asyncio.Lock()
def __init__(self, handler: BaseHandler[Update, CCT], application: "Application") -> None:
self.handler = handler
self.application = application
self.user_service: Optional["UserAdminService"] = None
super().__init__(self.handler.callback, self.handler.block)
def check_update(self, update: object) -> bool:
if not isinstance(update, Update):
return False
return self.handler.check_update(update)
async def _user_service(self) -> "UserAdminService":
async with self._lock:
if self.user_service is not None:
return self.user_service
user_service: UserAdminService = self.application.managers.services_map.get(UserAdminService, None)
if user_service is None:
raise ServiceNotFoundError("UserAdminService")
self.user_service = user_service
return self.user_service
async def handle_update(
self,
update: "UT",
application: "TelegramApplication[Any, CCT, Any, Any, Any, Any]",
check_result: Any,
context: "CCT",
) -> RT:
user_service = await self._user_service()
user = update.effective_user
if await user_service.is_admin(user.id):
return await self.handler.handle_update(update, application, check_result, context)
message = update.effective_message
logger.warning("用户 %s[%s] 触发尝试调用Admin命令但权限不足", user.full_name, user.id)
await message.reply_text("权限不足")
raise ApplicationHandlerStop

View File

@ -1,62 +1,3 @@
import asyncio from gram_core.handler.callbackqueryhandler import CallbackQueryHandler, OverlappingException, OverlappingContext
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import TypeVar, TYPE_CHECKING, Any, Optional, Type
from telegram.ext import CallbackQueryHandler as BaseCallbackQueryHandler, ApplicationHandlerStop __all__ = ("CallbackQueryHandler", "OverlappingException", "OverlappingContext")
from utils.log import logger
if TYPE_CHECKING:
from telegram.ext import Application
RT = TypeVar("RT")
UT = TypeVar("UT")
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
class OverlappingException(Exception):
pass
class OverlappingContext(AbstractAsyncContextManager):
_lock = asyncio.Lock()
def __init__(self, context: "CCT"):
self.context = context
async def __aenter__(self) -> None:
async with self._lock:
flag = self.context.user_data.get("overlapping", False)
if flag:
raise OverlappingException
self.context.user_data["overlapping"] = True
return None
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
async with self._lock:
del self.context.user_data["overlapping"]
return None
class CallbackQueryHandler(BaseCallbackQueryHandler):
async def handle_update(
self,
update: "UT",
application: "Application[Any, CCT, Any, Any, Any, Any]",
check_result: Any,
context: "CCT",
) -> RT:
self.collect_additional_context(context, update, application, check_result)
try:
async with OverlappingContext(context):
return await self.callback(update, context)
except OverlappingException as exc:
user = update.effective_user
logger.warning("用户 %s[%s] 触发 overlapping 该次命令已忽略", user.full_name, user.id)
raise ApplicationHandlerStop from exc

View File

@ -1,71 +1,3 @@
import asyncio from gram_core.handler.limiterhandler import LimiterHandler
from typing import TypeVar, Optional
from telegram import Update __all__ = ("LimiterHandler",)
from telegram.ext import ContextTypes, ApplicationHandlerStop, TypeHandler
from utils.log import logger
UT = TypeVar("UT")
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
class LimiterHandler(TypeHandler[UT, CCT]):
_lock = asyncio.Lock()
def __init__(
self, max_rate: float = 5, time_period: float = 10, amount: float = 1, limit_time: Optional[float] = None
):
"""Limiter Handler 通过
`Leaky bucket algorithm <https://en.wikipedia.org/wiki/Leaky_bucket>`_
实现对用户的输入的精确控制
输入超过一定速率后代码会抛出
:class:`telegram.ext.ApplicationHandlerStop`
异常并在一段时间内防止用户执行任何其他操作
:param max_rate: 在抛出异常之前最多允许 频率/ 的速度
:param time_period: 在限制速率的时间段的持续时间
:param amount: 提供的容量
:param limit_time: 限制时间 如果不提供限制时间为 max_rate / time_period * amount
"""
self.max_rate = max_rate
self.amount = amount
self._rate_per_sec = max_rate / time_period
self.limit_time = limit_time
super().__init__(Update, self.limiter_callback)
async def limiter_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
if update.inline_query is not None:
return
loop = asyncio.get_running_loop()
async with self._lock:
time = loop.time()
user_data = context.user_data
if user_data is None:
return
user_limit_time = user_data.get("limit_time")
if user_limit_time is not None:
if time >= user_limit_time:
del user_data["limit_time"]
else:
raise ApplicationHandlerStop
last_task_time = user_data.get("last_task_time", 0)
if last_task_time:
task_level = user_data.get("task_level", 0)
elapsed = time - last_task_time
decrement = elapsed * self._rate_per_sec
task_level = max(task_level - decrement, 0)
user_data["task_level"] = task_level
if not task_level + self.amount <= self.max_rate:
if self.limit_time:
limit_time = self.limit_time
else:
limit_time = 1 / self._rate_per_sec * self.amount
user_data["limit_time"] = time + limit_time
user = update.effective_user
logger.warning("用户 %s[%s] 触发洪水限制 已被限制 %s", user.full_name, user.id, limit_time)
raise ApplicationHandlerStop
user_data["last_task_time"] = time
task_level = user_data.get("task_level", 0)
user_data["task_level"] = task_level + self.amount

View File

@ -1,286 +0,0 @@
import asyncio
from importlib import import_module
from pathlib import Path
from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar
from arkowrapper import ArkoWrapper
from async_timeout import timeout
from typing_extensions import ParamSpec
from core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services
from core.config import config as bot_config
from utils.const import PLUGIN_DIR, PROJECT_ROOT
from utils.helpers import gen_pkg
from utils.log import logger
if TYPE_CHECKING:
from core.application import Application
from core.plugin import PluginType
from core.builtins.executor import Executor
__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers")
R = TypeVar("R")
T = TypeVar("T")
P = ParamSpec("P")
def _load_module(path: Path) -> None:
for pkg in gen_pkg(path):
try:
logger.debug('正在导入 "%s"', pkg)
import_module(pkg)
except Exception as e:
logger.exception(
'在导入 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
)
raise SystemExit from e
class Manager(Generic[T]):
"""生命周期控制基类"""
_executor: Optional["Executor"] = None
_lib: Dict[Type[T], T] = {}
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
return self._application
@property
def executor(self) -> "Executor":
"""执行器"""
if self._executor is None:
raise RuntimeError(f"No executor was set for this {self.__class__.__name__}.")
return self._executor
def build_executor(self, name: str):
from core.builtins.executor import Executor
from core.builtins.dispatcher import BaseDispatcher
self._executor = Executor(name, dispatcher=BaseDispatcher)
self._executor.set_application(self.application)
class DependenceManager(Manager[DependenceType]):
"""基础依赖管理"""
_dependency: Dict[Type[DependenceType], DependenceType] = {}
@property
def dependency(self) -> List[DependenceType]:
return list(self._dependency.values())
@property
def dependency_map(self) -> Dict[Type[DependenceType], DependenceType]:
return self._dependency
async def start_dependency(self) -> None:
_load_module(PROJECT_ROOT / "core/dependence")
for dependence in filter(lambda x: x.is_dependence, get_all_services()):
dependence: Type[DependenceType]
instance: DependenceType
try:
if hasattr(dependence, "from_config"): # 如果有 from_config 方法
instance = dependence.from_config(bot_config) # 用 from_config 实例化服务
else:
instance = await self.executor(dependence)
await instance.initialize()
logger.success('基础服务 "%s" 启动成功', dependence.__name__)
self._lib[dependence] = instance
self._dependency[dependence] = instance
except Exception as e:
logger.exception('基础服务 "%s" 初始化失败BOT 将自动关闭', dependence.__name__)
raise SystemExit from e
async def stop_dependency(self) -> None:
async def task(d):
try:
async with timeout(5):
await d.shutdown()
logger.debug('基础服务 "%s" 关闭成功', d.__class__.__name__)
except asyncio.TimeoutError:
logger.warning('基础服务 "%s" 关闭超时', d.__class__.__name__)
except Exception as e:
logger.error('基础服务 "%s" 关闭错误', d.__class__.__name__, exc_info=e)
tasks = []
for dependence in self._dependency.values():
tasks.append(asyncio.create_task(task(dependence)))
await asyncio.gather(*tasks)
class ComponentManager(Manager[ComponentType]):
"""组件管理"""
_components: Dict[Type[ComponentType], ComponentType] = {}
@property
def components(self) -> List[ComponentType]:
return list(self._components.values())
@property
def components_map(self) -> Dict[Type[ComponentType], ComponentType]:
return self._components
async def init_components(self):
for path in filter(
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
):
_load_module(path)
components = ArkoWrapper(get_all_services()).filter(lambda x: x.is_component)
retry_times = 0
max_retry_times = len(components)
while components:
start_len = len(components)
for component in list(components):
component: Type[ComponentType]
instance: ComponentType
try:
instance = await self.executor(component)
self._lib[component] = instance
self._components[component] = instance
components = components.remove(component)
except Exception as e: # pylint: disable=W0703
logger.debug('组件 "%s" 初始化失败: [red]%s[/]', component.__name__, e, extra={"markup": True})
end_len = len(list(components))
if start_len == end_len:
retry_times += 1
if retry_times == max_retry_times and components:
for component in components:
logger.error('组件 "%s" 初始化失败', component.__name__)
raise SystemExit
class ServiceManager(Manager[BaseServiceType]):
"""服务控制类"""
_services: Dict[Type[BaseServiceType], BaseServiceType] = {}
@property
def services(self) -> List[BaseServiceType]:
return list(self._services.values())
@property
def services_map(self) -> Dict[Type[BaseServiceType], BaseServiceType]:
return self._services
async def _initialize_service(self, target: Type[BaseServiceType]) -> BaseServiceType:
instance: BaseServiceType
try:
if hasattr(target, "from_config"): # 如果有 from_config 方法
instance = target.from_config(bot_config) # 用 from_config 实例化服务
else:
instance = await self.executor(target)
await instance.initialize()
logger.success('服务 "%s" 启动成功', target.__name__)
return instance
except Exception as e: # pylint: disable=W0703
logger.exception('服务 "%s" 初始化失败BOT 将自动关闭', target.__name__)
raise SystemExit from e
async def start_services(self) -> None:
for path in filter(
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
):
_load_module(path)
for service in filter(lambda x: not x.is_component and not x.is_dependence, get_all_services()): # 遍历所有服务类
instance = await self._initialize_service(service)
self._lib[service] = instance
self._services[service] = instance
async def stop_services(self) -> None:
"""关闭服务"""
if not self._services:
return
async def task(s):
try:
async with timeout(5):
await s.shutdown()
logger.success('服务 "%s" 关闭成功', s.__class__.__name__)
except asyncio.TimeoutError:
logger.warning('服务 "%s" 关闭超时', s.__class__.__name__)
except Exception as e:
logger.warning('服务 "%s" 关闭失败', s.__class__.__name__, exc_info=e)
logger.info("正在关闭服务")
tasks = []
for service in self._services.values():
tasks.append(asyncio.create_task(task(service)))
await asyncio.gather(*tasks)
class PluginManager(Manager["PluginType"]):
"""插件管理"""
_plugins: Dict[Type["PluginType"], "PluginType"] = {}
@property
def plugins(self) -> List["PluginType"]:
"""所有已经加载的插件"""
return list(self._plugins.values())
@property
def plugins_map(self) -> Dict[Type["PluginType"], "PluginType"]:
return self._plugins
async def install_plugins(self) -> None:
"""安装所有插件"""
from core.plugin import get_all_plugins
for path in filter(lambda x: x.is_dir(), PLUGIN_DIR.iterdir()):
_load_module(path)
for plugin in get_all_plugins():
plugin: Type["PluginType"]
try:
instance: "PluginType" = await self.executor(plugin)
except Exception as e: # pylint: disable=W0703
logger.error('插件 "%s" 初始化失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
continue
self._plugins[plugin] = instance
if self._application is not None:
instance.set_application(self._application)
await asyncio.create_task(self.plugin_install_task(plugin, instance))
@staticmethod
async def plugin_install_task(plugin: Type["PluginType"], instance: "PluginType"):
try:
await instance.install()
logger.success('插件 "%s" 安装成功', f"{plugin.__module__}.{plugin.__name__}")
except Exception as e: # pylint: disable=W0703
logger.error('插件 "%s" 安装失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
async def uninstall_plugins(self) -> None:
for plugin in self._plugins.values():
try:
await plugin.uninstall()
except Exception as e: # pylint: disable=W0703
logger.error('插件 "%s" 卸载失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
class Managers(DependenceManager, ComponentManager, ServiceManager, PluginManager):
"""BOT 除自身外的生命周期管理类"""

View File

@ -1,117 +0,0 @@
"""重写 telegram.request.HTTPXRequest 使其使用 ujson 库进行 json 序列化"""
from typing import Any, AsyncIterable, Optional
import httpcore
from httpx import (
AsyncByteStream,
AsyncHTTPTransport as DefaultAsyncHTTPTransport,
Limits,
Response as DefaultResponse,
Timeout,
)
from telegram.request import HTTPXRequest as DefaultHTTPXRequest
try:
import ujson as jsonlib
except ImportError:
import json as jsonlib
__all__ = ("HTTPXRequest",)
class Response(DefaultResponse):
def json(self, **kwargs: Any) -> Any:
# noinspection PyProtectedMember
from httpx._utils import guess_json_utf
if self.charset_encoding is None and self.content and len(self.content) > 3:
encoding = guess_json_utf(self.content)
if encoding is not None:
return jsonlib.loads(self.content.decode(encoding), **kwargs)
return jsonlib.loads(self.text, **kwargs)
# noinspection PyProtectedMember
class AsyncHTTPTransport(DefaultAsyncHTTPTransport):
async def handle_async_request(self, request) -> Response:
from httpx._transports.default import (
map_httpcore_exceptions,
AsyncResponseStream,
)
if not isinstance(request.stream, AsyncByteStream):
raise AssertionError
req = httpcore.Request(
method=request.method,
url=httpcore.URL(
scheme=request.url.raw_scheme,
host=request.url.raw_host,
port=request.url.port,
target=request.url.raw_path,
),
headers=request.headers.raw,
content=request.stream,
extensions=request.extensions,
)
with map_httpcore_exceptions():
resp = await self._pool.handle_async_request(req)
if not isinstance(resp.stream, AsyncIterable):
raise AssertionError
return Response(
status_code=resp.status,
headers=resp.headers,
stream=AsyncResponseStream(resp.stream),
extensions=resp.extensions,
)
class HTTPXRequest(DefaultHTTPXRequest):
def __init__( # pylint: disable=W0231
self,
connection_pool_size: int = 1,
proxy_url: str = None,
read_timeout: Optional[float] = 5.0,
write_timeout: Optional[float] = 5.0,
connect_timeout: Optional[float] = 5.0,
pool_timeout: Optional[float] = 1.0,
http_version: str = "1.1",
):
self._http_version = http_version
timeout = Timeout(
connect=connect_timeout,
read=read_timeout,
write=write_timeout,
pool=pool_timeout,
)
limits = Limits(
max_connections=connection_pool_size,
max_keepalive_connections=connection_pool_size,
)
if http_version not in ("1.1", "2"):
raise ValueError("`http_version` must be either '1.1' or '2'.")
http1 = http_version == "1.1"
self._client_kwargs = dict(
timeout=timeout,
proxies=proxy_url,
limits=limits,
transport=AsyncHTTPTransport(limits=limits),
http1=http1,
http2=not http1,
)
try:
self._client = self._build_client()
except ImportError as exc:
if "httpx[http2]" not in str(exc) and "httpx[socks]" not in str(exc):
raise exc
if "httpx[socks]" in str(exc):
raise RuntimeError(
"To use Socks5 proxies, PTB must be installed via `pip install " "python-telegram-bot[socks]`."
) from exc
raise RuntimeError(
"To use HTTP/2, PTB must be installed via `pip install " "python-telegram-bot[http2]`."
) from exc

View File

@ -1,8 +1,8 @@
"""插件""" """插件"""
from core.plugin._handler import conversation, error_handler, handler from gram_core.plugin._handler import conversation, error_handler, handler
from core.plugin._job import TimeType, job from gram_core.plugin._job import TimeType, job
from core.plugin._plugin import Plugin, PluginType, get_all_plugins from gram_core.plugin._plugin import Plugin, PluginType, get_all_plugins
__all__ = ( __all__ = (
"Plugin", "Plugin",

View File

@ -1,178 +0,0 @@
from pathlib import Path
from typing import List, Optional, Union, TYPE_CHECKING
import aiofiles
import httpx
from httpx import UnsupportedProtocol
from telegram import Chat, Message, ReplyKeyboardRemove, Update
from telegram.error import Forbidden, NetworkError
from telegram.ext import CallbackContext, ConversationHandler, Job
from core.dependence.redisdb import RedisDB
from core.plugin._handler import conversation, handler
from utils.const import CACHE_DIR, REQUEST_HEADERS
from utils.error import UrlResourcesNotFoundError
from utils.helpers import sha1
from utils.log import logger
if TYPE_CHECKING:
from core.application import Application
try:
import ujson as json
except ImportError:
import json
__all__ = (
"PluginFuncs",
"ConversationFuncs",
)
class PluginFuncs:
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError("No application was set for this PluginManager.")
return self._application
async def _delete_message(self, context: CallbackContext) -> None:
job = context.job
message_id = job.data
chat_info = f"chat_id[{job.chat_id}]"
try:
chat = await self.get_chat(job.chat_id)
full_name = chat.full_name
if full_name:
chat_info = f"{full_name}[{chat.id}]"
else:
chat_info = f"{chat.title}[{chat.id}]"
except (NetworkError, Forbidden) as exc:
logger.warning("获取 chat info 失败 %s", exc.message)
except Exception as exc:
logger.warning("获取 chat info 消息失败 %s", str(exc))
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
try:
# noinspection PyTypeChecker
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
except NetworkError as exc:
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
except Forbidden as exc:
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
except Exception as exc:
logger.error("删除消息 %s message_id[%s] 失败", chat_info, message_id, exc_info=exc)
async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, expire: int = 86400) -> Chat:
application = self.application
redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None)
if not redis_db:
return await application.bot.get_chat(chat_id)
qname = f"bot:chat:{chat_id}"
data = await redis_db.client.get(qname)
if data:
json_data = json.loads(data)
return Chat.de_json(json_data, application.telegram.bot)
chat_info = await application.telegram.bot.get_chat(chat_id)
await redis_db.client.set(qname, chat_info.to_json(), ex=expire)
return chat_info
def add_delete_message_job(
self,
message: Optional[Union[int, Message]] = None,
*,
delay: int = 60,
name: Optional[str] = None,
chat: Optional[Union[int, Chat]] = None,
context: Optional[CallbackContext] = None,
) -> Job:
"""延迟删除消息"""
if isinstance(message, Message):
if chat is None:
chat = message.chat_id
message = message.id
chat = chat.id if isinstance(chat, Chat) else chat
job_queue = self.application.job_queue or context.job_queue
if job_queue is None or chat is None:
raise RuntimeError
return job_queue.run_once(
callback=self._delete_message,
when=delay,
data=message,
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
chat_id=chat,
job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"},
)
@staticmethod
async def download_resource(url: str, return_path: bool = False) -> str:
url_sha1 = sha1(url) # url 的 hash 值
pathed_url = Path(url)
file_name = url_sha1 + pathed_url.suffix
file_path = CACHE_DIR.joinpath(file_name)
if not file_path.exists(): # 若文件不存在,则下载
async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client:
try:
response = await client.get(url)
except UnsupportedProtocol:
logger.error("链接不支持 url[%s]", url)
return ""
if response.is_error:
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
raise UrlResourcesNotFoundError(url)
if response.status_code != 200:
logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code)
raise UrlResourcesNotFoundError(url)
async with aiofiles.open(file_path, mode="wb") as f:
await f.write(response.content)
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
return file_path if return_path else Path(file_path).as_uri()
@staticmethod
def get_args(context: CallbackContext) -> List[str]:
args = context.args
match = context.match
if args is None:
if match is not None and (command := match.groups()[0]):
temp = []
command_parts = command.split(" ")
for command_part in command_parts:
if command_part:
temp.append(command_part)
return temp
return []
if len(args) >= 1:
return args
return []
class ConversationFuncs:
@conversation.fallback
@handler.command(command="cancel", block=False)
async def cancel(self, update: Update, _) -> int:
await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove())
return ConversationHandler.END

View File

@ -1,380 +0,0 @@
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from importlib import import_module
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Pattern,
TYPE_CHECKING,
Type,
TypeVar,
Union,
)
from pydantic import BaseModel
# noinspection PyProtectedMember
from telegram._utils.defaultvalue import DEFAULT_TRUE
# noinspection PyProtectedMember
from telegram._utils.types import DVInput
from telegram.ext import BaseHandler
from telegram.ext.filters import BaseFilter
from typing_extensions import ParamSpec
from core.handler.callbackqueryhandler import CallbackQueryHandler
from utils.const import WRAPPER_ASSIGNMENTS as _WRAPPER_ASSIGNMENTS
if TYPE_CHECKING:
from core.builtins.dispatcher import AbstractDispatcher
__all__ = (
"handler",
"conversation",
"ConversationDataType",
"ConversationData",
"HandlerData",
"ErrorHandlerData",
"error_handler",
)
P = ParamSpec("P")
T = TypeVar("T")
R = TypeVar("R")
UT = TypeVar("UT")
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
HandlerCls = Type[HandlerType]
Module = import_module("telegram.ext")
HANDLER_DATA_ATTR_NAME = "_handler_datas"
"""用于储存生成 handler 时所需要的参数(例如 block的属性名"""
ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block的属性名"""
WRAPPER_ASSIGNMENTS = list(
set(
_WRAPPER_ASSIGNMENTS
+ [
HANDLER_DATA_ATTR_NAME,
ERROR_HANDLER_ATTR_NAME,
CONVERSATION_HANDLER_ATTR_NAME,
]
)
)
@dataclass(init=True)
class HandlerData:
type: Type[HandlerType]
admin: bool
kwargs: Dict[str, Any]
dispatcher: Optional[Type["AbstractDispatcher"]] = None
class _Handler:
_type: Type["HandlerType"]
kwargs: Dict[str, Any] = {}
def __init_subclass__(cls, **kwargs) -> None:
"""用于获取 python-telegram-bot 中对应的 handler class"""
handler_name = f"{cls.__name__.strip('_')}Handler"
if handler_name == "CallbackQueryHandler":
cls._type = CallbackQueryHandler
return
cls._type = getattr(Module, handler_name, None)
def __init__(self, admin: bool = False, dispatcher: Optional[Type["AbstractDispatcher"]] = None, **kwargs) -> None:
self.dispatcher = dispatcher
self.admin = admin
self.kwargs = kwargs
def __call__(self, func: Callable[P, R]) -> Callable[P, R]:
"""decorator实现从 func 生成 Handler"""
handler_datas = getattr(func, HANDLER_DATA_ATTR_NAME, [])
handler_datas.append(
HandlerData(type=self._type, admin=self.admin, kwargs=self.kwargs, dispatcher=self.dispatcher)
)
setattr(func, HANDLER_DATA_ATTR_NAME, handler_datas)
return func
class _CallbackQuery(_Handler):
def __init__(
self,
pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None,
*,
block: DVInput[bool] = DEFAULT_TRUE,
admin: bool = False,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_CallbackQuery, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
class _ChatJoinRequest(_Handler):
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
super(_ChatJoinRequest, self).__init__(block=block, dispatcher=dispatcher)
class _ChatMember(_Handler):
def __init__(
self,
chat_member_types: int = -1,
*,
block: DVInput[bool] = DEFAULT_TRUE,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(chat_member_types=chat_member_types, block=block, dispatcher=dispatcher)
class _ChosenInlineResult(_Handler):
def __init__(
self,
block: DVInput[bool] = DEFAULT_TRUE,
*,
pattern: Union[str, Pattern] = None,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(block=block, pattern=pattern, dispatcher=dispatcher)
class _Command(_Handler):
def __init__(
self,
command: Union[str, List[str]],
filters: "BaseFilter" = None,
*,
block: DVInput[bool] = DEFAULT_TRUE,
admin: bool = False,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_Command, self).__init__(
command=command, filters=filters, block=block, admin=admin, dispatcher=dispatcher
)
class _InlineQuery(_Handler):
def __init__(
self,
pattern: Union[str, Pattern] = None,
chat_types: List[str] = None,
*,
block: DVInput[bool] = DEFAULT_TRUE,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_InlineQuery, self).__init__(pattern=pattern, block=block, chat_types=chat_types, dispatcher=dispatcher)
class _Message(_Handler):
def __init__(
self,
filters: BaseFilter,
*,
block: DVInput[bool] = DEFAULT_TRUE,
admin: bool = False,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
) -> None:
super(_Message, self).__init__(filters=filters, block=block, admin=admin, dispatcher=dispatcher)
class _PollAnswer(_Handler):
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
super(_PollAnswer, self).__init__(block=block, dispatcher=dispatcher)
class _Poll(_Handler):
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
super(_Poll, self).__init__(block=block, dispatcher=dispatcher)
class _PreCheckoutQuery(_Handler):
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
super(_PreCheckoutQuery, self).__init__(block=block, dispatcher=dispatcher)
class _Prefix(_Handler):
def __init__(
self,
prefix: str,
command: str,
filters: BaseFilter = None,
*,
block: DVInput[bool] = DEFAULT_TRUE,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_Prefix, self).__init__(
prefix=prefix, command=command, filters=filters, block=block, dispatcher=dispatcher
)
class _ShippingQuery(_Handler):
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
super(_ShippingQuery, self).__init__(block=block, dispatcher=dispatcher)
class _StringCommand(_Handler):
def __init__(
self,
command: str,
*,
admin: bool = False,
block: DVInput[bool] = DEFAULT_TRUE,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_StringCommand, self).__init__(command=command, block=block, admin=admin, dispatcher=dispatcher)
class _StringRegex(_Handler):
def __init__(
self,
pattern: Union[str, Pattern],
*,
block: DVInput[bool] = DEFAULT_TRUE,
admin: bool = False,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_StringRegex, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
class _Type(_Handler):
# noinspection PyShadowingBuiltins
def __init__(
self,
type: Type[UT], # pylint: disable=W0622
strict: bool = False,
*,
block: DVInput[bool] = DEFAULT_TRUE,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
): # pylint: disable=redefined-builtin
super(_Type, self).__init__(type=type, strict=strict, block=block, dispatcher=dispatcher)
# noinspection PyPep8Naming
class handler(_Handler):
callback_query = _CallbackQuery
chat_join_request = _ChatJoinRequest
chat_member = _ChatMember
chosen_inline_result = _ChosenInlineResult
command = _Command
inline_query = _InlineQuery
message = _Message
poll_answer = _PollAnswer
pool = _Poll
pre_checkout_query = _PreCheckoutQuery
prefix = _Prefix
shipping_query = _ShippingQuery
string_command = _StringCommand
string_regex = _StringRegex
type = _Type
def __init__(
self,
handler_type: Union[Callable[P, "HandlerType"], Type["HandlerType"]],
*,
admin: bool = False,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
**kwargs: P.kwargs,
) -> None:
self._type = handler_type
super().__init__(admin=admin, dispatcher=dispatcher, **kwargs)
class ConversationDataType(Enum):
"""conversation handler 的类型"""
Entry = "entry"
State = "state"
Fallback = "fallback"
class ConversationData(BaseModel):
"""用于储存 conversation handler 的数据"""
type: ConversationDataType
state: Optional[Any] = None
class _ConversationType:
_type: ClassVar[ConversationDataType]
def __init_subclass__(cls, **kwargs) -> None:
cls._type = ConversationDataType(cls.__name__.lstrip("_").lower())
def _entry(func: Callable[P, R]) -> Callable[P, R]:
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Entry))
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)
return wrapped
class _State(_ConversationType):
def __init__(self, state: Any) -> None:
self.state = state
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=self._type, state=self.state))
return func
def _fallback(func: Callable[P, R]) -> Callable[P, R]:
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Fallback))
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)
return wrapped
# noinspection PyPep8Naming
class conversation(_Handler):
entry_point = _entry
state = _State
fallback = _fallback
@dataclass(init=True)
class ErrorHandlerData:
block: bool
func: Optional[Callable] = None
# noinspection PyPep8Naming
class error_handler:
_func: Callable[P, R]
def __init__(
self,
*,
block: bool = DEFAULT_TRUE,
):
self._block = block
def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
self._func = func
wraps(func, assigned=WRAPPER_ASSIGNMENTS)(self)
handler_datas = getattr(func, ERROR_HANDLER_ATTR_NAME, [])
handler_datas.append(ErrorHandlerData(block=self._block))
setattr(self._func, ERROR_HANDLER_ATTR_NAME, handler_datas)
return self._func

View File

@ -1,173 +0,0 @@
"""插件"""
import datetime
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
# noinspection PyProtectedMember
from telegram._utils.types import JSONDict
# noinspection PyProtectedMember
from telegram.ext._utils.types import JobCallback
from typing_extensions import ParamSpec
if TYPE_CHECKING:
from core.builtins.dispatcher import AbstractDispatcher
__all__ = ["TimeType", "job", "JobData"]
P = ParamSpec("P")
T = TypeVar("T")
R = TypeVar("R")
TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time]
_JOB_ATTR_NAME = "_job_data"
@dataclass(init=True)
class JobData:
name: str
data: Any
chat_id: int
user_id: int
type: str
job_kwargs: JSONDict = field(default_factory=dict)
kwargs: JSONDict = field(default_factory=dict)
dispatcher: Optional[Type["AbstractDispatcher"]] = None
class _Job:
kwargs: Dict = {}
def __init__(
self,
name: str = None,
data: object = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
**kwargs,
):
self.name = name
self.data = data
self.chat_id = chat_id
self.user_id = user_id
self.job_kwargs = {} if job_kwargs is None else job_kwargs
self.kwargs = kwargs
if dispatcher is None:
from core.builtins.dispatcher import JobDispatcher
dispatcher = JobDispatcher
self.dispatcher = dispatcher
def __call__(self, func: JobCallback) -> JobCallback:
data = JobData(
name=self.name,
data=self.data,
chat_id=self.chat_id,
user_id=self.user_id,
job_kwargs=self.job_kwargs,
kwargs=self.kwargs,
type=re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"),
dispatcher=self.dispatcher,
)
if hasattr(func, _JOB_ATTR_NAME):
job_datas = getattr(func, _JOB_ATTR_NAME)
job_datas.append(data)
setattr(func, _JOB_ATTR_NAME, job_datas)
else:
setattr(func, _JOB_ATTR_NAME, [data])
return func
class _RunOnce(_Job):
def __init__(
self,
when: TimeType,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when)
class _RunRepeating(_Job):
def __init__(
self,
interval: Union[float, datetime.timedelta],
first: TimeType = None,
last: TimeType = None,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(
name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, interval=interval, first=first, last=last
)
class _RunMonthly(_Job):
def __init__(
self,
when: datetime.time,
day: int,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when, day=day)
class _RunDaily(_Job):
def __init__(
self,
time: datetime.time,
days: Tuple[int, ...] = tuple(range(7)),
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, time=time, days=days)
class _RunCustom(_Job):
def __init__(
self,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher)
# noinspection PyPep8Naming
class job:
run_once = _RunOnce
run_repeating = _RunRepeating
run_monthly = _RunMonthly
run_daily = _RunDaily
run_custom = _RunCustom

View File

@ -1,314 +0,0 @@
"""插件"""
import asyncio
from abc import ABC
from dataclasses import asdict
from datetime import timedelta
from functools import partial, wraps
from itertools import chain
from multiprocessing import RLock as Lock
from types import MethodType
from typing import (
Any,
ClassVar,
Dict,
Iterable,
List,
Optional,
TYPE_CHECKING,
Type,
TypeVar,
Union,
)
from pydantic import BaseModel
from telegram.ext import BaseHandler, ConversationHandler, Job, TypeHandler
from typing_extensions import ParamSpec
from core.handler.adminhandler import AdminHandler
from core.plugin._funcs import ConversationFuncs, PluginFuncs
from core.plugin._handler import ConversationDataType
from utils.const import WRAPPER_ASSIGNMENTS
from utils.helpers import isabstract
from utils.log import logger
if TYPE_CHECKING:
from core.application import Application
from core.plugin._handler import ConversationData, HandlerData, ErrorHandlerData
from core.plugin._job import JobData
from multiprocessing.synchronize import RLock as LockType
__all__ = ("Plugin", "PluginType", "get_all_plugins")
wraps = partial(wraps, assigned=WRAPPER_ASSIGNMENTS)
P = ParamSpec("P")
T = TypeVar("T")
R = TypeVar("R")
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
_HANDLER_DATA_ATTR_NAME = "_handler_datas"
"""用于储存生成 handler 时所需要的参数(例如 block的属性名"""
_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block的属性名"""
_ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
_JOB_ATTR_NAME = "_job_data"
_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"]
class _Plugin(PluginFuncs):
"""插件"""
_lock: ClassVar["LockType"] = Lock()
_asyncio_lock: ClassVar["LockType"] = asyncio.Lock()
_installed: bool = False
_handlers: Optional[List[HandlerType]] = None
_error_handlers: Optional[List["ErrorHandlerData"]] = None
_jobs: Optional[List[Job]] = None
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError("No application was set for this Plugin.")
return self._application
@property
def handlers(self) -> List[HandlerType]:
"""该插件的所有 handler"""
with self._lock:
if self._handlers is None:
self._handlers = []
for attr in dir(self):
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and isinstance(func := getattr(self, attr), MethodType)
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
):
for data in datas:
data: "HandlerData"
if data.admin:
self._handlers.append(
AdminHandler(
handler=data.type(
callback=func,
**data.kwargs,
),
application=self.application,
)
)
else:
self._handlers.append(
data.type(
callback=func,
**data.kwargs,
)
)
return self._handlers
@property
def error_handlers(self) -> List["ErrorHandlerData"]:
with self._lock:
if self._error_handlers is None:
self._error_handlers = []
for attr in dir(self):
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and isinstance(func := getattr(self, attr), MethodType)
and (datas := getattr(func, _ERROR_HANDLER_ATTR_NAME, []))
):
for data in datas:
data: "ErrorHandlerData"
data.func = func
self._error_handlers.append(data)
return self._error_handlers
def _install_jobs(self) -> None:
if self._jobs is None:
self._jobs = []
for attr in dir(self):
# noinspection PyUnboundLocalVariable
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and isinstance(func := getattr(self, attr), MethodType)
and (datas := getattr(func, _JOB_ATTR_NAME, []))
):
for data in datas:
data: "JobData"
self._jobs.append(
getattr(self.application.telegram.job_queue, data.type)(
callback=func,
**data.kwargs,
**{
key: value
for key, value in asdict(data).items()
if key not in ["type", "kwargs", "dispatcher"]
},
)
)
@property
def jobs(self) -> List[Job]:
with self._lock:
if self._jobs is None:
self._jobs = []
self._install_jobs()
return self._jobs
async def initialize(self) -> None:
"""初始化插件"""
async def shutdown(self) -> None:
"""销毁插件"""
async def install(self) -> None:
"""安装"""
group = id(self)
if not self._installed:
await self.initialize()
# initialize 必须先执行 如果出现异常不会执行 add_handler 以免出现问题
async with self._asyncio_lock:
self._install_jobs()
for h in self.handlers:
if not isinstance(h, TypeHandler):
self.application.telegram.add_handler(h, group)
else:
self.application.telegram.add_handler(h, -1)
for h in self.error_handlers:
self.application.telegram.add_error_handler(h.func, h.block)
self._installed = True
async def uninstall(self) -> None:
"""卸载"""
group = id(self)
with self._lock:
if self._installed:
if group in self.application.telegram.handlers:
del self.application.telegram.handlers[id(self)]
for h in self.handlers:
if isinstance(h, TypeHandler):
self.application.telegram.remove_handler(h, -1)
for h in self.error_handlers:
self.application.telegram.remove_error_handler(h.func)
for j in self.application.telegram.job_queue.jobs():
j.schedule_removal()
await self.shutdown()
self._installed = False
async def reload(self) -> None:
await self.uninstall()
await self.install()
class _Conversation(_Plugin, ConversationFuncs, ABC):
"""Conversation类"""
# noinspection SpellCheckingInspection
class Config(BaseModel):
allow_reentry: bool = False
per_chat: bool = True
per_user: bool = True
per_message: bool = False
conversation_timeout: Optional[Union[float, timedelta]] = None
name: Optional[str] = None
map_to_parent: Optional[Dict[object, object]] = None
block: bool = False
def __init_subclass__(cls, **kwargs):
cls._conversation_kwargs = kwargs
super(_Conversation, cls).__init_subclass__()
return cls
@property
def handlers(self) -> List[HandlerType]:
with self._lock:
if self._handlers is None:
self._handlers = []
entry_points: List[HandlerType] = []
states: Dict[Any, List[HandlerType]] = {}
fallbacks: List[HandlerType] = []
for attr in dir(self):
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and (func := getattr(self, attr, None)) is not None
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
):
conversation_data: "ConversationData"
handlers: List[HandlerType] = []
for data in datas:
if data.admin:
handlers.append(
AdminHandler(
handler=data.type(
callback=func,
**data.kwargs,
),
application=self.application,
)
)
else:
handlers.append(
data.type(
callback=func,
**data.kwargs,
)
)
if conversation_data := getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None):
if (_type := conversation_data.type) is ConversationDataType.Entry:
entry_points.extend(handlers)
elif _type is ConversationDataType.State:
if conversation_data.state in states:
states[conversation_data.state].extend(handlers)
else:
states[conversation_data.state] = handlers
elif _type is ConversationDataType.Fallback:
fallbacks.extend(handlers)
else:
self._handlers.extend(handlers)
else:
self._handlers.extend(handlers)
if entry_points and states and fallbacks:
kwargs = self._conversation_kwargs
kwargs.update(self.Config().dict())
self._handlers.append(ConversationHandler(entry_points, states, fallbacks, **kwargs))
else:
temp_dict = {"entry_points": entry_points, "states": states, "fallbacks": fallbacks}
reason = map(lambda x: f"'{x[0]}'", filter(lambda x: not x[1], temp_dict.items()))
logger.warning(
"'%s' 因缺少 '%s' 而生成无法生成 ConversationHandler", self.__class__.__name__, ", ".join(reason)
)
return self._handlers
class Plugin(_Plugin, ABC):
"""插件"""
Conversation = _Conversation
PluginType = TypeVar("PluginType", bound=_Plugin)
def get_all_plugins() -> Iterable[Type[PluginType]]:
"""获取所有 Plugin 的子类"""
return filter(
lambda x: x.__name__[0] != "_" and not isabstract(x),
chain(Plugin.__subclasses__(), _Conversation.__subclasses__()),
)

View File

@ -1,67 +0,0 @@
import asyncio
import contextlib
from typing import Callable, Coroutine, Any, Union, List, Dict, Optional, TypeVar, Type
from telegram.error import RetryAfter
from telegram.ext import BaseRateLimiter, ApplicationHandlerStop
from utils.log import logger
JSONDict: Type[dict[str, Any]] = Dict[str, Any]
RL_ARGS = TypeVar("RL_ARGS")
class RateLimiter(BaseRateLimiter[int]):
_lock = asyncio.Lock()
__slots__ = (
"_limiter_info",
"_retry_after_event",
)
def __init__(self):
self._limiter_info: Dict[Union[str, int], float] = {}
self._retry_after_event = asyncio.Event()
self._retry_after_event.set()
async def process_request(
self,
callback: Callable[..., Coroutine[Any, Any, Union[bool, JSONDict, List[JSONDict]]]],
args: Any,
kwargs: Dict[str, Any],
endpoint: str,
data: Dict[str, Any],
rate_limit_args: Optional[RL_ARGS],
) -> Union[bool, JSONDict, List[JSONDict]]:
chat_id = data.get("chat_id")
with contextlib.suppress(ValueError, TypeError):
chat_id = int(chat_id)
loop = asyncio.get_running_loop()
time = loop.time()
await self._retry_after_event.wait()
async with self._lock:
chat_limit_time = self._limiter_info.get(chat_id)
if chat_limit_time:
if time >= chat_limit_time:
raise ApplicationHandlerStop
del self._limiter_info[chat_id]
try:
return await callback(*args, **kwargs)
except RetryAfter as exc:
logger.warning("chat_id[%s] 触发洪水限制 当前被服务器限制 retry_after[%s]秒", chat_id, exc.retry_after)
self._limiter_info[chat_id] = time + (exc.retry_after * 2)
sleep = exc.retry_after + 0.1
self._retry_after_event.clear()
await asyncio.sleep(sleep)
finally:
self._retry_after_event.set()
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass

View File

@ -1,97 +1,3 @@
from typing import List, Union from gram_core.services.cookies.cache import PublicCookiesCache
from core.base_service import BaseService
from core.basemodel import RegionEnum
from core.dependence.redisdb import RedisDB
from core.services.cookies.error import CookiesCachePoolExhausted
from utils.error import RegionNotFoundError
__all__ = ("PublicCookiesCache",) __all__ = ("PublicCookiesCache",)
class PublicCookiesCache(BaseService.Component):
"""使用优先级(score)进行排序对使用次数最少的Cookies进行审核"""
def __init__(self, redis: RedisDB):
self.client = redis.client
self.score_qname = "cookie:public"
self.user_times_qname = "cookie:public:times"
self.end = 20
self.user_times_ttl = 60 * 60 * 24
def get_public_cookies_queue_name(self, region: RegionEnum):
if region == RegionEnum.HYPERION:
return f"{self.score_qname}:yuanshen"
if region == RegionEnum.HOYOLAB:
return f"{self.score_qname}:genshin"
raise RegionNotFoundError(region.name)
async def putback_public_cookies(self, uid: int, region: RegionEnum):
"""重新添加单个到缓存列表
:param uid:
:param region:
:return:
"""
qname = self.get_public_cookies_queue_name(region)
score_maps = {f"{uid}": 0}
result = await self.client.zrem(qname, f"{uid}")
if result == 1:
await self.client.zadd(qname, score_maps)
return result
async def add_public_cookies(self, uid: Union[List[int], int], region: RegionEnum):
"""单个或批量添加到缓存列表
:param uid:
:param region:
:return: 成功返回列表大小
"""
qname = self.get_public_cookies_queue_name(region)
if isinstance(uid, int):
score_maps = {f"{uid}": 0}
elif isinstance(uid, list):
score_maps = {f"{i}": 0 for i in uid}
else:
raise TypeError("uid variable type error")
async with self.client.pipeline(transaction=True) as pipe:
# nx:只添加新元素。不要更新已经存在的元素
await pipe.zadd(qname, score_maps, nx=True)
await pipe.zcard(qname)
add, count = await pipe.execute()
return int(add), count
async def get_public_cookies(self, region: RegionEnum):
"""从缓存列表获取
:param region:
:return:
"""
qname = self.get_public_cookies_queue_name(region)
scores = await self.client.zrange(qname, 0, self.end, withscores=True, score_cast_func=int)
if len(scores) <= 0:
raise CookiesCachePoolExhausted
key = scores[0][0]
score = scores[0][1]
async with self.client.pipeline(transaction=True) as pipe:
await pipe.zincrby(qname, 1, key)
await pipe.execute()
return int(key), score + 1
async def delete_public_cookies(self, uid: int, region: RegionEnum):
qname = self.get_public_cookies_queue_name(region)
async with self.client.pipeline(transaction=True) as pipe:
await pipe.zrem(qname, uid)
return await pipe.execute()
async def get_public_cookies_count(self, limit: bool = True):
async with self.client.pipeline(transaction=True) as pipe:
if limit:
await pipe.zcount(0, self.end)
else:
await pipe.zcard(self.score_qname)
return await pipe.execute()
async def incr_by_user_times(self, user_id: Union[List[int], int], amount: int = 1):
qname = f"{self.user_times_qname}:{user_id}"
times = await self.client.incrby(qname, amount)
if times <= 1:
await self.client.expire(qname, self.user_times_ttl)
return times

View File

@ -1,12 +1,3 @@
class CookieServiceError(Exception): from gram_core.services.cookies.error import CookieServiceError, CookiesCachePoolExhausted, TooManyRequestPublicCookies
pass
__all__ = ("CookieServiceError", "CookiesCachePoolExhausted", "TooManyRequestPublicCookies")
class CookiesCachePoolExhausted(CookieServiceError):
def __init__(self):
super().__init__("Cookies cache pool is exhausted")
class TooManyRequestPublicCookies(CookieServiceError):
def __init__(self, user_id):
super().__init__(f"{user_id} too many request public cookies")

View File

@ -1,39 +1,3 @@
import enum from gram_core.services.cookies.models import Cookies, CookiesDataBase, CookiesStatusEnum
from typing import Optional, Dict
from sqlmodel import SQLModel, Field, Boolean, Column, Enum, JSON, Integer, BigInteger, Index
from core.basemodel import RegionEnum
__all__ = ("Cookies", "CookiesDataBase", "CookiesStatusEnum") __all__ = ("Cookies", "CookiesDataBase", "CookiesStatusEnum")
class CookiesStatusEnum(int, enum.Enum):
STATUS_SUCCESS = 0
INVALID_COOKIES = 1
TOO_MANY_REQUESTS = 2
class Cookies(SQLModel):
__table_args__ = (
Index("index_user_account", "user_id", "account_id", unique=True),
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
)
id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
user_id: int = Field(
sa_column=Column(BigInteger()),
)
account_id: int = Field(
default=None,
sa_column=Column(
BigInteger(),
),
)
data: Optional[Dict[str, str]] = Field(sa_column=Column(JSON))
status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum)))
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
is_share: Optional[bool] = Field(sa_column=Column(Boolean))
class CookiesDataBase(Cookies, table=True):
__tablename__ = "cookies"

View File

@ -1,55 +1,3 @@
from typing import Optional, List from gram_core.services.cookies.repositories import CookiesRepository
from sqlmodel import select
from core.base_service import BaseService
from core.basemodel import RegionEnum
from core.dependence.database import Database
from core.services.cookies.models import CookiesDataBase as Cookies
from core.sqlmodel.session import AsyncSession
__all__ = ("CookiesRepository",) __all__ = ("CookiesRepository",)
class CookiesRepository(BaseService.Component):
def __init__(self, database: Database):
self.engine = database.engine
async def get(
self,
user_id: int,
account_id: Optional[int] = None,
region: Optional[RegionEnum] = None,
) -> Optional[Cookies]:
async with AsyncSession(self.engine) as session:
statement = select(Cookies).where(Cookies.user_id == user_id)
if account_id is not None:
statement = statement.where(Cookies.account_id == account_id)
if region is not None:
statement = statement.where(Cookies.region == region)
results = await session.exec(statement)
return results.first()
async def add(self, cookies: Cookies) -> None:
async with AsyncSession(self.engine) as session:
session.add(cookies)
await session.commit()
async def update(self, cookies: Cookies) -> Cookies:
async with AsyncSession(self.engine) as session:
session.add(cookies)
await session.commit()
await session.refresh(cookies)
return cookies
async def delete(self, cookies: Cookies) -> None:
async with AsyncSession(self.engine) as session:
await session.delete(cookies)
await session.commit()
async def get_all_by_region(self, region: RegionEnum) -> List[Cookies]:
async with AsyncSession(self.engine) as session:
statement = select(Cookies).where(Cookies.region == region)
results = await session.exec(statement)
cookies = results.all()
return cookies

View File

@ -1,159 +1,80 @@
from typing import List, Optional from gram_core.base_service import BaseService
from gram_core.basemodel import RegionEnum
from gram_core.services.cookies.error import CookieServiceError
from gram_core.services.cookies.models import CookiesStatusEnum, CookiesDataBase as Cookies
from gram_core.services.cookies.services import (
CookiesService,
PublicCookiesService as BasePublicCookiesService,
NeedContinue,
)
from simnet import GenshinClient, Region, Game from simnet import GenshinClient, Region, Game
from simnet.errors import InvalidCookies, BadRequest as SimnetBadRequest, TooManyRequests from simnet.errors import InvalidCookies, TooManyRequests, BadRequest as SimnetBadRequest
from core.base_service import BaseService
from core.basemodel import RegionEnum
from core.services.cookies.cache import PublicCookiesCache
from core.services.cookies.error import CookieServiceError, TooManyRequestPublicCookies
from core.services.cookies.models import CookiesDataBase as Cookies, CookiesStatusEnum
from core.services.cookies.repositories import CookiesRepository
from utils.log import logger from utils.log import logger
__all__ = ("CookiesService", "PublicCookiesService") __all__ = ("CookiesService", "PublicCookiesService")
class CookiesService(BaseService): class PublicCookiesService(BaseService, BasePublicCookiesService):
def __init__(self, cookies_repository: CookiesRepository) -> None: async def check_public_cookie(self, region: RegionEnum, cookies: Cookies, public_id: int):
self._repository: CookiesRepository = cookies_repository if region == RegionEnum.HYPERION:
client = GenshinClient(cookies=cookies.data, region=Region.CHINESE)
async def update(self, cookies: Cookies): elif region == RegionEnum.HOYOLAB:
await self._repository.update(cookies) client = GenshinClient(cookies=cookies.data, region=Region.OVERSEAS, lang="zh-cn")
else:
async def add(self, cookies: Cookies): raise CookieServiceError
await self._repository.add(cookies) try:
if client.account_id is None:
async def get( raise RuntimeError("account_id not found")
self, record_cards = await client.get_record_cards()
user_id: int, for record_card in record_cards:
account_id: Optional[int] = None, if record_card.game == Game.GENSHIN:
region: Optional[RegionEnum] = None, await client.get_partial_genshin_user(record_card.uid)
) -> Optional[Cookies]: break
return await self._repository.get(user_id, account_id, region)
async def delete(self, cookies: Cookies) -> None:
return await self._repository.delete(cookies)
class PublicCookiesService(BaseService):
def __init__(self, cookies_repository: CookiesRepository, public_cookies_cache: PublicCookiesCache):
self._cache = public_cookies_cache
self._repository: CookiesRepository = cookies_repository
self.count: int = 0
self.user_times_limiter = 3 * 3
async def initialize(self) -> None:
logger.info("正在初始化公共Cookies池")
await self.refresh()
logger.success("刷新公共Cookies池成功")
async def refresh(self):
"""刷新公共Cookies 定时任务
:return:
"""
user_list: List[int] = []
cookies_list = await self._repository.get_all_by_region(RegionEnum.HYPERION) # 从数据库获取2
for cookies in cookies_list:
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
user_list.append(cookies.user_id)
if len(user_list) > 0:
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HYPERION)
logger.info("国服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
user_list.clear()
cookies_list = await self._repository.get_all_by_region(RegionEnum.HOYOLAB)
for cookies in cookies_list:
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
user_list.append(cookies.user_id)
if len(user_list) > 0:
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HOYOLAB)
logger.info("国际服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL):
"""获取公共Cookies
:param user_id: 用户ID
:param region: 注册的服务器
:return:
"""
user_times = await self._cache.incr_by_user_times(user_id)
if int(user_times) > self.user_times_limiter:
logger.warning("用户 %s 使用公共Cookies次数已经到达上限", user_id)
raise TooManyRequestPublicCookies(user_id)
while True:
public_id, count = await self._cache.get_public_cookies(region)
cookies = await self._repository.get(public_id, region=region)
if cookies is None:
await self._cache.delete_public_cookies(public_id, region)
continue
if region == RegionEnum.HYPERION:
client = GenshinClient(cookies=cookies.data, region=Region.CHINESE)
elif region == RegionEnum.HOYOLAB:
client = GenshinClient(cookies=cookies.data, region=Region.OVERSEAS, lang="zh-cn")
else: else:
raise CookieServiceError accounts = await client.get_game_accounts()
try: for account in accounts:
if client.account_id is None: if account.game == Game.GENSHIN:
raise RuntimeError("account_id not found") await client.get_partial_genshin_user(account.uid)
record_cards = await client.get_record_cards()
for record_card in record_cards:
if record_card.game == Game.GENSHIN:
await client.get_partial_genshin_user(record_card.uid)
break break
else: except InvalidCookies as exc:
accounts = await client.get_game_accounts() if exc.ret_code in (10001, -100):
for account in accounts: logger.warning("用户 [%s] Cookies无效", public_id)
if account.game == Game.GENSHIN: elif exc.ret_code == 10103:
await client.get_partial_genshin_user(account.uid) logger.warning("用户 [%s] Cookies有效但没有绑定到游戏帐户", public_id)
break else:
except InvalidCookies as exc: logger.warning("Cookies无效 ")
if exc.ret_code in (10001, -100): logger.exception(exc)
logger.warning("用户 [%s] Cookies无效", public_id) cookies.status = CookiesStatusEnum.INVALID_COOKIES
elif exc.ret_code == 10103: await self._repository.update(cookies)
logger.warning("用户 [%s] Cookies有效但没有绑定到游戏帐户", public_id) await self._cache.delete_public_cookies(cookies.user_id, region)
else: raise NeedContinue
logger.warning("Cookies无效 ") except TooManyRequests:
logger.exception(exc) logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id)
cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS
await self._repository.update(cookies)
await self._cache.delete_public_cookies(cookies.user_id, region)
raise NeedContinue
except SimnetBadRequest as exc:
if "invalid content type" in exc.message:
raise exc
if exc.ret_code == 1034:
logger.warning("用户 [%s] 触发验证", public_id)
else:
logger.warning("用户 [%s] 获取账号信息发生错误,错误信息为", public_id)
logger.exception(exc)
await self._cache.delete_public_cookies(cookies.user_id, region)
raise NeedContinue
except RuntimeError as exc:
if "account_id not found" in str(exc):
cookies.status = CookiesStatusEnum.INVALID_COOKIES cookies.status = CookiesStatusEnum.INVALID_COOKIES
await self._repository.update(cookies) await self._repository.update(cookies)
await self._cache.delete_public_cookies(cookies.user_id, region) await self._cache.delete_public_cookies(cookies.user_id, region)
continue raise NeedContinue
except TooManyRequests: raise exc
logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id) except Exception as exc:
cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS await self._cache.delete_public_cookies(cookies.user_id, region)
await self._repository.update(cookies) raise exc
await self._cache.delete_public_cookies(cookies.user_id, region) finally:
continue await client.shutdown()
except SimnetBadRequest as exc:
if "invalid content type" in exc.message:
raise exc
if exc.ret_code == 1034:
logger.warning("用户 [%s] 触发验证", public_id)
else:
logger.warning("用户 [%s] 获取账号信息发生错误,错误信息为", public_id)
logger.exception(exc)
await self._cache.delete_public_cookies(cookies.user_id, region)
continue
except RuntimeError as exc:
if "account_id not found" in str(exc):
cookies.status = CookiesStatusEnum.INVALID_COOKIES
await self._repository.update(cookies)
await self._cache.delete_public_cookies(cookies.user_id, region)
continue
raise exc
except Exception as exc:
await self._cache.delete_public_cookies(cookies.user_id, region)
raise exc
finally:
await client.shutdown()
logger.info("用户 user_id[%s] 请求用户 user_id[%s] 的公共Cookies 该Cookies使用次数为%s", user_id, public_id, count)
return cookies
async def undo(self, user_id: int, cookies: Optional[Cookies] = None, status: Optional[CookiesStatusEnum] = None):
await self._cache.incr_by_user_times(user_id, -1)
if cookies is not None and status is not None:
cookies.status = status
await self._repository.update(cookies)
await self._cache.delete_public_cookies(cookies.user_id, cookies.region)
logger.info("用户 user_id[%s] 反馈用户 user_id[%s] 的Cookies状态为 %s", user_id, cookies.user_id, status.name)
else:
logger.info("用户 user_id[%s] 撤销一次公共Cookies计数", user_id)

View File

@ -1,23 +1,3 @@
from typing import Optional from gram_core.services.devices.models import Devices, DevicesDataBase
from sqlmodel import SQLModel, Field, Column, Integer, BigInteger
__all__ = ("Devices", "DevicesDataBase") __all__ = ("Devices", "DevicesDataBase")
class Devices(SQLModel):
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
account_id: int = Field(
default=None,
sa_column=Column(
BigInteger(),
),
)
device_id: str = Field()
device_fp: str = Field()
device_name: Optional[str] = Field(default=None)
class DevicesDataBase(Devices, table=True):
__tablename__ = "devices"

View File

@ -1,41 +1,3 @@
from typing import Optional from gram_core.services.devices.repositories import DevicesRepository
from sqlmodel import select
from core.base_service import BaseService
from core.dependence.database import Database
from core.services.devices.models import DevicesDataBase as Devices
from core.sqlmodel.session import AsyncSession
__all__ = ("DevicesRepository",) __all__ = ("DevicesRepository",)
class DevicesRepository(BaseService.Component):
def __init__(self, database: Database):
self.engine = database.engine
async def get(
self,
account_id: int,
) -> Optional[Devices]:
async with AsyncSession(self.engine) as session:
statement = select(Devices).where(Devices.account_id == account_id)
results = await session.exec(statement)
return results.first()
async def add(self, devices: Devices) -> None:
async with AsyncSession(self.engine) as session:
session.add(devices)
await session.commit()
async def update(self, devices: Devices) -> Devices:
async with AsyncSession(self.engine) as session:
session.add(devices)
await session.commit()
await session.refresh(devices)
return devices
async def delete(self, devices: Devices) -> None:
async with AsyncSession(self.engine) as session:
await session.delete(devices)
await session.commit()

View File

@ -1,25 +1,3 @@
from typing import Optional from gram_core.services.devices.services import DevicesService
from core.base_service import BaseService __all__ = ("DevicesService",)
from core.services.devices.repositories import DevicesRepository
from core.services.devices.models import DevicesDataBase as Devices
class DevicesService(BaseService):
def __init__(self, devices_repository: DevicesRepository) -> None:
self._repository: DevicesRepository = devices_repository
async def update(self, devices: Devices):
await self._repository.update(devices)
async def add(self, devices: Devices):
await self._repository.add(devices)
async def get(
self,
account_id: int,
) -> Optional[Devices]:
return await self._repository.get(account_id)
async def delete(self, devices: Devices) -> None:
return await self._repository.delete(devices)

View File

@ -1,2 +1,3 @@
class PlayerNotFoundError(Exception): from gram_core.services.players.error import PlayerNotFoundError
pass
__all__ = ("PlayerNotFoundError",)

View File

@ -1,96 +1,8 @@
from datetime import datetime from gram_core.services.players.models import (
from typing import Optional Player,
PlayersDataBase,
from pydantic import BaseModel, BaseSettings PlayerInfo,
from sqlalchemy import TypeDecorator PlayerInfoSQLModel,
from sqlmodel import Boolean, Column, Enum, Field, SQLModel, Integer, Index, BigInteger, VARCHAR, func, DateTime )
from core.basemodel import RegionEnum
try:
import ujson as jsonlib
except ImportError:
import json as jsonlib
__all__ = ("Player", "PlayersDataBase", "PlayerInfo", "PlayerInfoSQLModel") __all__ = ("Player", "PlayersDataBase", "PlayerInfo", "PlayerInfoSQLModel")
class Player(SQLModel):
__table_args__ = (
Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True),
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
)
id: Optional[int] = Field(
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
)
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
account_id: int = Field(default=None, primary_key=True, sa_column=Column(BigInteger()))
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
is_chosen: Optional[bool] = Field(sa_column=Column(Boolean))
class PlayersDataBase(Player, table=True):
__tablename__ = "players"
class ExtraPlayerInfo(BaseModel):
class Config(BaseSettings.Config):
json_loads = jsonlib.loads
json_dumps = jsonlib.dumps
waifu_id: Optional[int] = None
class ExtraPlayerType(TypeDecorator): # pylint: disable=W0223
impl = VARCHAR(length=521)
cache_ok = True
def process_bind_param(self, value, dialect):
"""
:param value: ExtraPlayerInfo | obj | None
:param dialect:
:return:
"""
if value is not None:
if isinstance(value, ExtraPlayerInfo):
return value.json()
raise TypeError
return value
def process_result_value(self, value, dialect):
"""
:param value: str | obj | None
:param dialect:
:return:
"""
if value is not None:
return ExtraPlayerInfo.parse_raw(value)
return None
class PlayerInfo(SQLModel):
__table_args__ = (
Index("index_user_account_player", "user_id", "player_id", unique=True),
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
)
id: Optional[int] = Field(
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
)
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
nickname: Optional[str] = Field()
signature: Optional[str] = Field()
hand_image: Optional[int] = Field()
name_card: Optional[int] = Field()
extra_data: Optional[ExtraPlayerInfo] = Field(sa_column=Column(ExtraPlayerType))
create_time: Optional[datetime] = Field(
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
)
last_save_time: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
is_update: Optional[bool] = Field(sa_column=Column(Boolean))
class PlayerInfoSQLModel(PlayerInfo, table=True):
__tablename__ = "players_info"

View File

@ -1,110 +1,3 @@
from typing import List, Optional from gram_core.services.players.repositories import PlayersRepository, PlayerInfoRepository
from sqlmodel import select, delete
from core.base_service import BaseService
from core.basemodel import RegionEnum
from core.dependence.database import Database
from core.services.players.models import PlayerInfoSQLModel
from core.services.players.models import PlayersDataBase as Player
from core.sqlmodel.session import AsyncSession
__all__ = ("PlayersRepository", "PlayerInfoRepository") __all__ = ("PlayersRepository", "PlayerInfoRepository")
class PlayersRepository(BaseService.Component):
def __init__(self, database: Database):
self.engine = database.engine
async def get(
self,
user_id: int,
player_id: Optional[int] = None,
account_id: Optional[int] = None,
region: Optional[RegionEnum] = None,
is_chosen: Optional[bool] = None,
) -> Optional[Player]:
async with AsyncSession(self.engine) as session:
statement = select(Player).where(Player.user_id == user_id)
if player_id is not None:
statement = statement.where(Player.player_id == player_id)
if account_id is not None:
statement = statement.where(Player.account_id == account_id)
if region is not None:
statement = statement.where(Player.region == region)
if is_chosen is not None:
statement = statement.where(Player.is_chosen == is_chosen)
results = await session.exec(statement)
return results.first()
async def add(self, player: Player) -> None:
async with AsyncSession(self.engine) as session:
session.add(player)
await session.commit()
await session.refresh(player)
async def delete(self, player: Player) -> None:
async with AsyncSession(self.engine) as session:
await session.delete(player)
await session.commit()
async def update(self, player: Player) -> None:
async with AsyncSession(self.engine) as session:
session.add(player)
await session.commit()
await session.refresh(player)
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
async with AsyncSession(self.engine) as session:
statement = select(Player).where(Player.user_id == user_id)
results = await session.exec(statement)
players = results.all()
return players
class PlayerInfoRepository(BaseService.Component):
def __init__(self, database: Database):
self.engine = database.engine
async def get(
self,
user_id: int,
player_id: int,
) -> Optional[PlayerInfoSQLModel]:
async with AsyncSession(self.engine) as session:
statement = (
select(PlayerInfoSQLModel)
.where(PlayerInfoSQLModel.player_id == player_id)
.where(PlayerInfoSQLModel.user_id == user_id)
)
results = await session.exec(statement)
return results.first()
async def add(self, player: PlayerInfoSQLModel) -> None:
async with AsyncSession(self.engine) as session:
session.add(player)
await session.commit()
async def delete(self, player: PlayerInfoSQLModel) -> None:
async with AsyncSession(self.engine) as session:
await session.delete(player)
await session.commit()
async def delete_by_id(
self,
user_id: int,
player_id: int,
) -> None:
async with AsyncSession(self.engine) as session:
statement = (
delete(PlayerInfoSQLModel)
.where(PlayerInfoSQLModel.player_id == player_id)
.where(PlayerInfoSQLModel.user_id == user_id)
)
await session.execute(statement)
async def update(self, player: PlayerInfoSQLModel) -> None:
async with AsyncSession(self.engine) as session:
session.add(player)
await session.commit()
await session.refresh(player)

View File

@ -1,5 +1,5 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Optional from typing import Optional
from aiohttp import ClientConnectorError from aiohttp import ClientConnectorError
from enkanetwork import ( from enkanetwork import (
@ -11,53 +11,19 @@ from enkanetwork import (
) )
from core.base_service import BaseService from core.base_service import BaseService
from core.basemodel import RegionEnum
from core.config import config from core.config import config
from core.dependence.redisdb import RedisDB from core.dependence.redisdb import RedisDB
from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel, PlayerInfo from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel, PlayerInfo
from core.services.players.repositories import PlayersRepository, PlayerInfoRepository from core.services.players.repositories import PlayerInfoRepository
from utils.enkanetwork import RedisCache from utils.enkanetwork import RedisCache
from utils.log import logger from utils.log import logger
from utils.patch.aiohttp import AioHttpTimeoutException from utils.patch.aiohttp import AioHttpTimeoutException
from gram_core.services.players.services import PlayersService
__all__ = ("PlayersService", "PlayerInfoService") __all__ = ("PlayersService", "PlayerInfoService")
class PlayersService(BaseService):
def __init__(self, players_repository: PlayersRepository) -> None:
self._repository = players_repository
async def get(
self,
user_id: int,
player_id: Optional[int] = None,
account_id: Optional[int] = None,
region: Optional[RegionEnum] = None,
is_chosen: Optional[bool] = None,
) -> Optional[Player]:
return await self._repository.get(user_id, player_id, account_id, region, is_chosen)
async def get_player(self, user_id: int, region: Optional[RegionEnum] = None) -> Optional[Player]:
return await self._repository.get(user_id, region=region, is_chosen=True)
async def add(self, player: Player) -> None:
await self._repository.add(player)
async def update(self, player: Player) -> None:
await self._repository.update(player)
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
return await self._repository.get_all_by_user_id(user_id)
async def remove_all_by_user_id(self, user_id: int):
players = await self._repository.get_all_by_user_id(user_id)
for player in players:
await self._repository.delete(player)
async def delete(self, player: Player):
await self._repository.delete(player)
class PlayerInfoService(BaseService): class PlayerInfoService(BaseService):
def __init__(self, redis: RedisDB, players_info_repository: PlayerInfoRepository): def __init__(self, redis: RedisDB, players_info_repository: PlayerInfoRepository):
self.cache = redis.client self.cache = redis.client

View File

@ -1,44 +1,3 @@
import enum from gram_core.services.task.models import Task, TaskStatusEnum, TaskTypeEnum
from datetime import datetime
from typing import Optional, Dict, Any
from sqlalchemy import func, BigInteger, JSON
from sqlmodel import Column, DateTime, Enum, Field, SQLModel, Integer
__all__ = ("Task", "TaskStatusEnum", "TaskTypeEnum") __all__ = ("Task", "TaskStatusEnum", "TaskTypeEnum")
class TaskStatusEnum(int, enum.Enum):
STATUS_SUCCESS = 0 # 任务执行成功
INVALID_COOKIES = 1 # Cookie无效
ALREADY_CLAIMED = 2 # 已经获取奖励
NEED_CHALLENGE = 3 # 需要验证码
GENSHIN_EXCEPTION = 4 # API异常
TIMEOUT_ERROR = 5 # 请求超时
BAD_REQUEST = 6 # 请求失败
FORBIDDEN = 7 # 这错误一般为通知失败 机器人被用户BAN
class TaskTypeEnum(int, enum.Enum):
SIGN = 0 # 签到
RESIN = 1 # 体力
REALM = 2 # 洞天宝钱
EXPEDITION = 3 # 委托
TRANSFORMER = 4 # 参量质变仪
CARD = 5 # 生日画片
class Task(SQLModel, table=True):
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: Optional[int] = Field(
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
)
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger(), index=True))
chat_id: Optional[int] = Field(default=None, sa_column=Column(BigInteger()))
time_created: Optional[datetime] = Field(
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
)
time_updated: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
type: TaskTypeEnum = Field(primary_key=True, sa_column=Column(Enum(TaskTypeEnum)))
status: Optional[TaskStatusEnum] = Field(sa_column=Column(Enum(TaskStatusEnum)))
data: Optional[Dict[str, Any]] = Field(sa_column=Column(JSON))

View File

@ -1,50 +1,3 @@
from typing import List, Optional from gram_core.services.task.repositories import TaskRepository
from sqlmodel import select
from core.base_service import BaseService
from core.dependence.database import Database
from core.services.task.models import Task, TaskTypeEnum
from core.sqlmodel.session import AsyncSession
__all__ = ("TaskRepository",) __all__ = ("TaskRepository",)
class TaskRepository(BaseService.Component):
def __init__(self, database: Database):
self.engine = database.engine
async def add(self, task: Task):
async with AsyncSession(self.engine) as session:
session.add(task)
await session.commit()
async def remove(self, task: Task):
async with AsyncSession(self.engine) as session:
await session.delete(task)
await session.commit()
async def update(self, task: Task) -> Task:
async with AsyncSession(self.engine) as session:
session.add(task)
await session.commit()
await session.refresh(task)
return task
async def get_by_user_id(self, user_id: int, task_type: TaskTypeEnum) -> Optional[Task]:
async with AsyncSession(self.engine) as session:
statement = select(Task).where(Task.user_id == user_id).where(Task.type == task_type)
results = await session.exec(statement)
return results.first()
async def get_by_chat_id(self, chat_id: int, task_type: TaskTypeEnum) -> Optional[List[Task]]:
async with AsyncSession(self.engine) as session:
statement = select(Task).where(Task.chat_id == chat_id).where(Task.type == task_type)
results = await session.exec(statement)
return results.all()
async def get_all(self, task_type: TaskTypeEnum) -> List[Task]:
async with AsyncSession(self.engine) as session:
query = select(Task).where(Task.type == task_type)
results = await session.exec(query)
return results.all()

View File

@ -1,9 +1,11 @@
import datetime from gram_core.services.task.services import (
from typing import Optional, Dict, Any TaskServices,
SignServices,
from core.base_service import BaseService TaskCardServices,
from core.services.task.models import Task, TaskTypeEnum TaskResinServices,
from core.services.task.repositories import TaskRepository TaskRealmServices,
TaskExpeditionServices,
)
__all__ = [ __all__ = [
"TaskServices", "TaskServices",
@ -13,201 +15,3 @@ __all__ = [
"TaskRealmServices", "TaskRealmServices",
"TaskExpeditionServices", "TaskExpeditionServices",
] ]
class TaskServices(BaseService):
TASK_TYPE: TaskTypeEnum
def __init__(self, task_repository: TaskRepository) -> None:
self._repository: TaskRepository = task_repository
async def add(self, task: Task):
return await self._repository.add(task)
async def remove(self, task: Task):
return await self._repository.remove(task)
async def update(self, task: Task):
task.time_updated = datetime.datetime.now()
return await self._repository.update(task)
async def get_by_user_id(self, user_id: int):
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
async def get_all(self):
return await self._repository.get_all(self.TASK_TYPE)
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
return Task(
user_id=user_id,
chat_id=chat_id,
time_created=datetime.datetime.now(),
status=status,
type=self.TASK_TYPE,
data=data,
)
class SignServices(BaseService):
TASK_TYPE = TaskTypeEnum.SIGN
def __init__(self, task_repository: TaskRepository) -> None:
self._repository: TaskRepository = task_repository
async def add(self, task: Task):
return await self._repository.add(task)
async def remove(self, task: Task):
return await self._repository.remove(task)
async def update(self, task: Task):
task.time_updated = datetime.datetime.now()
return await self._repository.update(task)
async def get_by_user_id(self, user_id: int):
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
async def get_all(self):
return await self._repository.get_all(self.TASK_TYPE)
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
return Task(
user_id=user_id,
chat_id=chat_id,
time_created=datetime.datetime.now(),
status=status,
type=self.TASK_TYPE,
data=data,
)
class TaskCardServices(BaseService):
TASK_TYPE = TaskTypeEnum.CARD
def __init__(self, task_repository: TaskRepository) -> None:
self._repository: TaskRepository = task_repository
async def add(self, task: Task):
return await self._repository.add(task)
async def remove(self, task: Task):
return await self._repository.remove(task)
async def update(self, task: Task):
task.time_updated = datetime.datetime.now()
return await self._repository.update(task)
async def get_by_user_id(self, user_id: int):
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
async def get_all(self):
return await self._repository.get_all(self.TASK_TYPE)
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
return Task(
user_id=user_id,
chat_id=chat_id,
time_created=datetime.datetime.now(),
status=status,
type=self.TASK_TYPE,
data=data,
)
class TaskResinServices(BaseService):
TASK_TYPE = TaskTypeEnum.RESIN
def __init__(self, task_repository: TaskRepository) -> None:
self._repository: TaskRepository = task_repository
async def add(self, task: Task):
return await self._repository.add(task)
async def remove(self, task: Task):
return await self._repository.remove(task)
async def update(self, task: Task):
task.time_updated = datetime.datetime.now()
return await self._repository.update(task)
async def get_by_user_id(self, user_id: int):
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
async def get_all(self):
return await self._repository.get_all(self.TASK_TYPE)
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
return Task(
user_id=user_id,
chat_id=chat_id,
time_created=datetime.datetime.now(),
status=status,
type=self.TASK_TYPE,
data=data,
)
class TaskRealmServices(BaseService):
TASK_TYPE = TaskTypeEnum.REALM
def __init__(self, task_repository: TaskRepository) -> None:
self._repository: TaskRepository = task_repository
async def add(self, task: Task):
return await self._repository.add(task)
async def remove(self, task: Task):
return await self._repository.remove(task)
async def update(self, task: Task):
task.time_updated = datetime.datetime.now()
return await self._repository.update(task)
async def get_by_user_id(self, user_id: int):
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
async def get_all(self):
return await self._repository.get_all(self.TASK_TYPE)
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
return Task(
user_id=user_id,
chat_id=chat_id,
time_created=datetime.datetime.now(),
status=status,
type=self.TASK_TYPE,
data=data,
)
class TaskExpeditionServices(BaseService):
TASK_TYPE = TaskTypeEnum.EXPEDITION
def __init__(self, task_repository: TaskRepository) -> None:
self._repository: TaskRepository = task_repository
async def add(self, task: Task):
return await self._repository.add(task)
async def remove(self, task: Task):
return await self._repository.remove(task)
async def update(self, task: Task):
task.time_updated = datetime.datetime.now()
return await self._repository.update(task)
async def get_by_user_id(self, user_id: int):
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE)
async def get_all(self):
return await self._repository.get_all(self.TASK_TYPE)
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
return Task(
user_id=user_id,
chat_id=chat_id,
time_created=datetime.datetime.now(),
status=status,
type=self.TASK_TYPE,
data=data,
)

View File

@ -1,58 +1,3 @@
import gzip from gram_core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache
import pickle # nosec B403
from hashlib import sha256
from typing import Any, Optional
from core.base_service import BaseService
from core.dependence.redisdb import RedisDB
__all__ = ["TemplatePreviewCache", "HtmlToFileIdCache"] __all__ = ["TemplatePreviewCache", "HtmlToFileIdCache"]
class TemplatePreviewCache(BaseService.Component):
"""暂存渲染模板的数据用于预览"""
def __init__(self, redis: RedisDB):
self.client = redis.client
self.qname = "bot:template:preview"
async def get_data(self, key: str) -> Any:
data = await self.client.get(self.cache_key(key))
if data:
# skipcq: BAN-B301
return pickle.loads(gzip.decompress(data)) # nosec B301
async def set_data(self, key: str, data: Any, ttl: int = 8 * 60 * 60):
ck = self.cache_key(key)
await self.client.set(ck, gzip.compress(pickle.dumps(data)))
if ttl != -1:
await self.client.expire(ck, ttl)
def cache_key(self, key: str) -> str:
return f"{self.qname}:{key}"
class HtmlToFileIdCache(BaseService.Component):
"""html to file_id 的缓存"""
def __init__(self, redis: RedisDB):
self.client = redis.client
self.qname = "bot:template:html-to-file-id"
async def get_data(self, html: str, file_type: str) -> Optional[str]:
data = await self.client.get(self.cache_key(html, file_type))
if data:
return data.decode()
async def set_data(self, html: str, file_type: str, file_id: str, ttl: int = 24 * 60 * 60):
ck = self.cache_key(html, file_type)
await self.client.set(ck, file_id)
if ttl != -1:
await self.client.expire(ck, ttl)
async def delete_data(self, html: str, file_type: str) -> bool:
return await self.client.delete(self.cache_key(html, file_type))
def cache_key(self, html: str, file_type: str) -> str:
key = sha256(html.encode()).hexdigest()
return f"{self.qname}:{file_type}:{key}"

View File

@ -1,14 +1,8 @@
class TemplateException(Exception): from gram_core.services.template.error import (
pass ErrorFileType,
FileIdNotFound,
QuerySelectorNotFound,
TemplateException,
)
__all__ = ("TemplateException", "QuerySelectorNotFound", "ErrorFileType", "FileIdNotFound")
class QuerySelectorNotFound(TemplateException):
pass
class ErrorFileType(TemplateException):
pass
class FileIdNotFound(TemplateException):
pass

View File

@ -1,146 +1,3 @@
from enum import Enum from gram_core.services.template.models import FileType, RenderResult, RenderGroupResult
from typing import List, Optional, Union
from telegram import InputMediaDocument, InputMediaPhoto, Message
from telegram.error import BadRequest
from core.services.template.cache import HtmlToFileIdCache
from core.services.template.error import ErrorFileType, FileIdNotFound
__all__ = ["FileType", "RenderResult", "RenderGroupResult"] __all__ = ["FileType", "RenderResult", "RenderGroupResult"]
class FileType(Enum):
PHOTO = 1
DOCUMENT = 2
@staticmethod
def media_type(file_type: "FileType"):
"""对应的 Telegram media 类型"""
if file_type == FileType.PHOTO:
return InputMediaPhoto
if file_type == FileType.DOCUMENT:
return InputMediaDocument
raise ErrorFileType
class RenderResult:
"""渲染结果"""
def __init__(
self,
html: str,
photo: Union[bytes, str],
file_type: FileType,
cache: HtmlToFileIdCache,
ttl: int = 24 * 60 * 60,
caption: Optional[str] = None,
parse_mode: Optional[str] = None,
filename: Optional[str] = None,
):
"""
`html`: str 渲染生成的 html
`photo`: Union[bytes, str] 渲染生成的图片bytes 表示是图片str 则为 file_id
"""
self.caption = caption
self.parse_mode = parse_mode
self.filename = filename
self.html = html
self.photo = photo
self.file_type = file_type
self._cache = cache
self.ttl = ttl
async def reply_photo(self, message: Message, *args, **kwargs):
"""是 `message.reply_photo` 的封装,上传成功后,缓存 telegram 返回的 file_id方便重复使用"""
if self.file_type != FileType.PHOTO:
raise ErrorFileType
try:
reply = await message.reply_photo(photo=self.photo, *args, **kwargs)
except BadRequest as exc:
if "Wrong file identifier" in exc.message and isinstance(self.photo, str):
await self._cache.delete_data(self.html, self.file_type.name)
raise BadRequest(message="Wrong file identifier specified")
raise exc
await self.cache_file_id(reply)
return reply
async def reply_document(self, message: Message, *args, **kwargs):
"""是 `message.reply_document` 的封装,上传成功后,缓存 telegram 返回的 file_id方便重复使用"""
if self.file_type != FileType.DOCUMENT:
raise ErrorFileType
try:
reply = await message.reply_document(document=self.photo, *args, **kwargs)
except BadRequest as exc:
if "Wrong file identifier" in exc.message and isinstance(self.photo, str):
await self._cache.delete_data(self.html, self.file_type.name)
raise BadRequest(message="Wrong file identifier specified")
raise exc
await self.cache_file_id(reply)
return reply
async def edit_media(self, message: Message, *args, **kwargs):
"""是 `message.edit_media` 的封装,上传成功后,缓存 telegram 返回的 file_id方便重复使用"""
if self.file_type != FileType.PHOTO:
raise ErrorFileType
media = InputMediaPhoto(
media=self.photo, caption=self.caption, parse_mode=self.parse_mode, filename=self.filename
)
try:
edit_media = await message.edit_media(media, *args, **kwargs)
except BadRequest as exc:
if "Wrong file identifier" in exc.message and isinstance(self.photo, str):
await self._cache.delete_data(self.html, self.file_type.name)
raise BadRequest(message="Wrong file identifier specified")
raise exc
await self.cache_file_id(edit_media)
return edit_media
async def cache_file_id(self, reply: Message):
"""缓存 telegram 返回的 file_id"""
if self.is_file_id():
return
if self.file_type == FileType.PHOTO and reply.photo:
file_id = reply.photo[0].file_id
elif self.file_type == FileType.DOCUMENT and reply.document:
file_id = reply.document.file_id
else:
raise FileIdNotFound
await self._cache.set_data(self.html, self.file_type.name, file_id, self.ttl)
def is_file_id(self) -> bool:
return isinstance(self.photo, str)
class RenderGroupResult:
def __init__(self, results: List[RenderResult]):
self.results = results
async def reply_media_group(self, message: Message, *args, **kwargs):
"""是 `message.reply_media_group` 的封装,上传成功后,缓存 telegram 返回的 file_id方便重复使用"""
reply = await message.reply_media_group(
media=[
FileType.media_type(result.file_type)(
media=result.photo, caption=result.caption, parse_mode=result.parse_mode, filename=result.filename
)
for result in self.results
],
*args,
**kwargs,
)
for index, value in enumerate(reply):
result = self.results[index]
await result.cache_file_id(value)

View File

@ -1,207 +1,3 @@
import asyncio from gram_core.services.template.services import TemplateService, TemplatePreviewer
from typing import Optional
from urllib.parse import urlencode, urljoin, urlsplit
from uuid import uuid4
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from jinja2 import Environment, FileSystemLoader, Template
from playwright.async_api import ViewportSize
from core.application import Application
from core.base_service import BaseService
from core.config import config as application_config
from core.dependence.aiobrowser import AioBrowser
from core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache
from core.services.template.error import QuerySelectorNotFound
from core.services.template.models import FileType, RenderResult
from utils.const import PROJECT_ROOT
from utils.log import logger
__all__ = ("TemplateService", "TemplatePreviewer") __all__ = ("TemplateService", "TemplatePreviewer")
class TemplateService(BaseService):
def __init__(
self,
app: Application,
browser: AioBrowser,
html_to_file_id_cache: HtmlToFileIdCache,
preview_cache: TemplatePreviewCache,
template_dir: str = "resources",
):
self._browser = browser
self.template_dir = PROJECT_ROOT / template_dir
self._jinja2_env = Environment(
loader=FileSystemLoader(template_dir),
enable_async=True,
autoescape=True,
auto_reload=application_config.debug,
)
self.using_preview = application_config.debug and application_config.webserver.enable
if self.using_preview:
self.previewer = TemplatePreviewer(self, preview_cache, app.web_app)
self.html_to_file_id_cache = html_to_file_id_cache
def get_template(self, template_name: str) -> Template:
return self._jinja2_env.get_template(template_name)
async def render_async(self, template_name: str, template_data: dict) -> str:
"""模板渲染
:param template_name: 模板文件名
:param template_data: 模板数据
"""
loop = asyncio.get_event_loop()
start_time = loop.time()
template = self.get_template(template_name)
html = await template.render_async(**template_data)
logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
return html
async def render(
self,
template_name: str,
template_data: dict,
viewport: Optional[ViewportSize] = None,
full_page: bool = True,
evaluate: Optional[str] = None,
query_selector: Optional[str] = None,
file_type: FileType = FileType.PHOTO,
ttl: int = 24 * 60 * 60,
caption: Optional[str] = None,
parse_mode: Optional[str] = None,
filename: Optional[str] = None,
) -> RenderResult:
"""模板渲染成图片
:param template_name: 模板文件名
:param template_data: 模板数据
:param viewport: 截图大小
:param full_page: 是否长截图
:param evaluate: 页面加载后运行的 js
:param query_selector: 截图选择器
:param file_type: 缓存的文件类型
:param ttl: 缓存时间
:param caption: 图片描述
:param parse_mode: 图片描述解析模式
:param filename: 文件名字
:return:
"""
loop = asyncio.get_event_loop()
start_time = loop.time()
template = self.get_template(template_name)
if self.using_preview:
preview_url = await self.previewer.get_preview_url(template_name, template_data)
logger.debug("调试模板 URL: \n%s", preview_url)
html = await template.render_async(**template_data)
logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
file_id = await self.html_to_file_id_cache.get_data(html, file_type.name)
if file_id and not application_config.debug:
logger.debug("%s 命中缓存,返回 file_id[%s]", template_name, file_id)
return RenderResult(
html=html,
photo=file_id,
file_type=file_type,
cache=self.html_to_file_id_cache,
ttl=ttl,
caption=caption,
parse_mode=parse_mode,
filename=filename,
)
browser = await self._browser.get_browser()
start_time = loop.time()
page = await browser.new_page(viewport=viewport)
uri = (PROJECT_ROOT / template.filename).as_uri()
await page.goto(uri)
await page.set_content(html, wait_until="networkidle")
if evaluate:
await page.evaluate(evaluate)
clip = None
if query_selector:
try:
card = await page.query_selector(query_selector)
if not card:
raise QuerySelectorNotFound
clip = await card.bounding_box()
if not clip:
raise QuerySelectorNotFound
except QuerySelectorNotFound:
logger.warning("未找到 %s 元素", query_selector)
png_data = await page.screenshot(clip=clip, full_page=full_page)
await page.close()
logger.debug("%s 图片渲染使用了 %s", template_name, str(loop.time() - start_time))
return RenderResult(
html=html,
photo=png_data,
file_type=file_type,
cache=self.html_to_file_id_cache,
ttl=ttl,
caption=caption,
parse_mode=parse_mode,
filename=filename,
)
class TemplatePreviewer(BaseService, load=application_config.webserver.enable and application_config.debug):
def __init__(
self,
template_service: TemplateService,
cache: TemplatePreviewCache,
web_app: FastAPI,
):
self.web_app = web_app
self.template_service = template_service
self.cache = cache
self.register_routes()
async def get_preview_url(self, template: str, data: dict):
"""获取预览 URL"""
components = urlsplit(application_config.webserver.url)
path = urljoin("/preview/", template)
query = {}
# 如果有数据,暂存在 redis 中
if data:
key = str(uuid4())
await self.cache.set_data(key, data)
query["key"] = key
# noinspection PyProtectedMember
return components._replace(path=path, query=urlencode(query)).geturl()
def register_routes(self):
"""注册预览用到的路由"""
@self.web_app.get("/preview/{path:path}")
async def preview_template(path: str, key: Optional[str] = None): # pylint: disable=W0612
# 如果是 /preview/ 开头的静态文件,直接返回内容。比如使用相对链接 ../ 引入的静态资源
if not path.endswith((".html", ".jinja2")):
full_path = self.template_service.template_dir / path
if not full_path.is_file():
raise HTTPException(status_code=404, detail=f"Template '{path}' not found")
return FileResponse(full_path)
# 取回暂存的渲染数据
data = await self.cache.get_data(key) if key else {}
if key and data is None:
raise HTTPException(status_code=404, detail=f"Template data {key} not found")
# 渲染 jinja2 模板
html = await self.template_service.render_async(path, data)
# 将本地 URL file:// 修改为 HTTP url因为浏览器内不允许加载本地文件
# file:///project_dir/cache/image.jpg => /cache/image.jpg
html = html.replace(PROJECT_ROOT.as_uri(), "")
return HTMLResponse(html)
# 其他静态资源
for name in ["cache", "resources"]:
directory = PROJECT_ROOT / name
directory.mkdir(exist_ok=True)
self.web_app.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name)

View File

@ -1,24 +1,3 @@
from typing import List from gram_core.services.users.cache import UserAdminCache
from core.base_service import BaseService
from core.dependence.redisdb import RedisDB
__all__ = ("UserAdminCache",) __all__ = ("UserAdminCache",)
class UserAdminCache(BaseService.Component):
def __init__(self, redis: RedisDB):
self.client = redis.client
self.qname = "users:admin"
async def ismember(self, user_id: int) -> bool:
return await self.client.sismember(self.qname, user_id)
async def get_all(self) -> List[int]:
return [int(str_data) for str_data in await self.client.smembers(self.qname)]
async def set(self, user_id: int) -> bool:
return await self.client.sadd(self.qname, user_id)
async def remove(self, user_id: int) -> bool:
return await self.client.srem(self.qname, user_id)

View File

@ -1,34 +1,7 @@
import enum from gram_core.services.users.models import User, UserDataBase, PermissionsEnum
from datetime import datetime
from typing import Optional
from sqlmodel import SQLModel, Field, DateTime, Column, Enum, BigInteger, Integer
__all__ = ( __all__ = (
"User", "User",
"UserDataBase", "UserDataBase",
"PermissionsEnum", "PermissionsEnum",
) )
class PermissionsEnum(int, enum.Enum):
OWNER = 1
ADMIN = 2
PUBLIC = 3
class User(SQLModel):
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: Optional[int] = Field(
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
)
user_id: int = Field(unique=True, sa_column=Column(BigInteger()))
permissions: Optional[PermissionsEnum] = Field(sa_column=Column(Enum(PermissionsEnum)))
locale: Optional[str] = Field()
ban_end_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
ban_start_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
is_banned: Optional[int] = Field()
class UserDataBase(User, table=True):
__tablename__ = "users"

View File

@ -1,44 +1,3 @@
from typing import Optional, List from gram_core.services.users.repositories import UserRepository
from sqlmodel import select
from core.base_service import BaseService
from core.dependence.database import Database
from core.services.users.models import UserDataBase as User
from core.sqlmodel.session import AsyncSession
__all__ = ("UserRepository",) __all__ = ("UserRepository",)
class UserRepository(BaseService.Component):
def __init__(self, database: Database):
self.engine = database.engine
async def get_by_user_id(self, user_id: int) -> Optional[User]:
async with AsyncSession(self.engine) as session:
statement = select(User).where(User.user_id == user_id)
results = await session.exec(statement)
return results.first()
async def add(self, user: User):
async with AsyncSession(self.engine) as session:
session.add(user)
await session.commit()
async def update(self, user: User) -> User:
async with AsyncSession(self.engine) as session:
session.add(user)
await session.commit()
await session.refresh(user)
return user
async def remove(self, user: User):
async with AsyncSession(self.engine) as session:
await session.delete(user)
await session.commit()
async def get_all(self) -> List[User]:
async with AsyncSession(self.engine) as session:
statement = select(User)
results = await session.exec(statement)
return results.all()

View File

@ -1,83 +1,3 @@
from typing import List, Optional from gram_core.services.users.services import UserService, UserAdminService
from core.base_service import BaseService
from core.config import config
from core.services.users.cache import UserAdminCache
from core.services.users.models import PermissionsEnum, UserDataBase as User
from core.services.users.repositories import UserRepository
__all__ = ("UserService", "UserAdminService") __all__ = ("UserService", "UserAdminService")
from utils.log import logger
class UserService(BaseService):
def __init__(self, user_repository: UserRepository) -> None:
self._repository: UserRepository = user_repository
async def get_user_by_id(self, user_id: int) -> Optional[User]:
"""从数据库获取用户信息
:param user_id:用户ID
:return: User
"""
return await self._repository.get_by_user_id(user_id)
async def remove(self, user: User):
return await self._repository.remove(user)
async def update_user(self, user: User):
return await self._repository.add(user)
class UserAdminService(BaseService):
def __init__(self, user_repository: UserRepository, cache: UserAdminCache):
self.user_repository = user_repository
self._cache = cache
async def initialize(self):
owner = config.owner
if owner:
user = await self.user_repository.get_by_user_id(owner)
if user:
if user.permissions != PermissionsEnum.OWNER:
user.permissions = PermissionsEnum.OWNER
await self._cache.set(user.user_id)
await self.user_repository.update(user)
else:
user = User(user_id=owner, permissions=PermissionsEnum.OWNER)
await self._cache.set(user.user_id)
await self.user_repository.add(user)
else:
logger.warning("检测到未配置Bot所有者 会导无法正常使用管理员权限")
users = await self.user_repository.get_all()
for user in users:
await self._cache.set(user.user_id)
async def is_admin(self, user_id: int) -> bool:
return await self._cache.ismember(user_id)
async def get_admin_list(self) -> List[int]:
return await self._cache.get_all()
async def add_admin(self, user_id: int) -> bool:
user = await self.user_repository.get_by_user_id(user_id)
if user:
if user.permissions == PermissionsEnum.OWNER:
return False
if user.permissions != PermissionsEnum.ADMIN:
user.permissions = PermissionsEnum.ADMIN
await self.user_repository.update(user)
else:
user = User(user_id=user_id, permissions=PermissionsEnum.ADMIN)
await self.user_repository.add(user)
return await self._cache.set(user_id)
async def delete_admin(self, user_id: int) -> bool:
user = await self.user_repository.get_by_user_id(user_id)
if user:
if user.permissions == PermissionsEnum.OWNER:
return True # 假装移除成功
user.permissions = PermissionsEnum.PUBLIC
await self.user_repository.update(user)
return await self._cache.remove(user.user_id)
return False

View File

@ -1,118 +1,3 @@
from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload from gram_core.sqlmodel.session import AsyncSession
from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.sql.base import Executable as _Executable
from sqlmodel.engine.result import Result, ScalarResult
from sqlmodel.orm.session import Session
from sqlmodel.sql.base import Executable
from sqlmodel.sql.expression import Select, SelectOfScalar
from typing_extensions import Literal
_TSelectParam = TypeVar("_TSelectParam")
__all__ = ("AsyncSession",) __all__ = ("AsyncSession",)
class AsyncSession(_AsyncSession): # pylint: disable=W0223
sync_session_class = Session
sync_session: Session
def __init__(
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
sync_session_class: Type[Session] = Session,
**kw: Any,
):
super().__init__(
bind=bind,
binds=binds,
sync_session_class=sync_session_class,
**kw,
)
@overload
async def exec(
self,
statement: Select[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> Result[_TSelectParam]:
...
@overload
async def exec(
self,
statement: SelectOfScalar[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> ScalarResult[_TSelectParam]:
...
async def exec(
self,
statement: Union[
Select[_TSelectParam],
SelectOfScalar[_TSelectParam],
Executable[_TSelectParam],
],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
results = super().execute(
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw,
)
if isinstance(statement, SelectOfScalar):
return (await results).scalars() # type: ignore
return await results # type: ignore
async def execute( # pylint: disable=W0221
self,
statement: _Executable,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> Result[Any]:
return await super().execute( # type: ignore
statement=statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw,
)
async def get( # pylint: disable=W0221
self,
entity: Type[_TSelectParam],
ident: Any,
options: Optional[Sequence[Any]] = None,
populate_existing: bool = False,
with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
identity_token: Optional[Any] = None,
execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT,
) -> Optional[_TSelectParam]:
return await super().get(
entity=entity,
ident=ident,
options=options,
populate_existing=populate_existing,
with_for_update=with_for_update,
identity_token=identity_token,
execution_options=execution_options,
)

17
poetry.lock generated
View File

@ -920,6 +920,21 @@ files = [
[package.dependencies] [package.dependencies]
gitdb = ">=4.0.1,<5" gitdb = ">=4.0.1,<5"
[[package]]
name = "gram-core"
version = "0.1.0"
description = "telegram robot base core."
optional = false
python-versions = ">=3.8"
files = []
develop = false
[package.source]
type = "git"
url = "https://github.com/PaiGramTeam/GramCore.git"
reference = "HEAD"
resolved_reference = "7fb5d4c0e01731e6901829fe317be19023c2c4c7"
[[package]] [[package]]
name = "greenlet" name = "greenlet"
version = "1.1.3" version = "1.1.3"
@ -2800,4 +2815,4 @@ test = ["flaky", "pytest", "pytest-asyncio"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.8" python-versions = "^3.8"
content-hash = "844402570472817cfd7937dd651a35f01e45f2e7e0d0381e540b357175b73899" content-hash = "ea8005da6cf4ff3c982a5f7b98491328c507c97dcfcd46500f14c74a3621f788"

View File

@ -45,6 +45,7 @@ pillow = "^10.0.0"
playwright = "^1.27.1" playwright = "^1.27.1"
aiosqlite = { extras = ["sqlite"], version = "^0.19.0" } aiosqlite = { extras = ["sqlite"], version = "^0.19.0" }
simnet = { git = "https://github.com/PaiGramTeam/SIMNet" } simnet = { git = "https://github.com/PaiGramTeam/SIMNet" }
gram-core = {git = "https://github.com/PaiGramTeam/GramCore.git"}
[tool.poetry.extras] [tool.poetry.extras]
pyro = ["Pyrogram", "TgCrypto"] pyro = ["Pyrogram", "TgCrypto"]

View File

@ -8,7 +8,7 @@ anyio==3.7.1 ; python_version >= "3.8" and python_version < "4.0"
appdirs==1.4.4 ; python_version >= "3.8" and python_version < "4.0" appdirs==1.4.4 ; python_version >= "3.8" and python_version < "4.0"
apscheduler==3.10.1 ; python_version >= "3.8" and python_version < "4.0" apscheduler==3.10.1 ; python_version >= "3.8" and python_version < "4.0"
arko-wrapper==0.2.8 ; python_version >= "3.8" and python_version < "4.0" arko-wrapper==0.2.8 ; python_version >= "3.8" and python_version < "4.0"
async-lru==2.0.3 ; python_version >= "3.8" and python_version < "4.0" async-lru==2.0.4 ; python_version >= "3.8" and python_version < "4.0"
async-timeout==4.0.2 ; python_version >= "3.8" and python_version < "4.0" async-timeout==4.0.2 ; python_version >= "3.8" and python_version < "4.0"
asyncmy==0.2.8 ; python_version >= "3.8" and python_version < "4.0" asyncmy==0.2.8 ; python_version >= "3.8" and python_version < "4.0"
attrs==23.1.0 ; python_version >= "3.8" and python_version < "4.0" attrs==23.1.0 ; python_version >= "3.8" and python_version < "4.0"
@ -26,12 +26,13 @@ cryptography==41.0.2 ; python_version >= "3.8" and python_version < "4.0"
enkanetwork-py @ git+https://github.com/mrwan200/EnkaNetwork.py@master ; python_version >= "3.8" and python_version < "4.0" enkanetwork-py @ git+https://github.com/mrwan200/EnkaNetwork.py@master ; python_version >= "3.8" and python_version < "4.0"
et-xmlfile==1.1.0 ; python_version >= "3.8" and python_version < "4.0" et-xmlfile==1.1.0 ; python_version >= "3.8" and python_version < "4.0"
exceptiongroup==1.1.2 ; python_version >= "3.8" and python_version < "3.11" exceptiongroup==1.1.2 ; python_version >= "3.8" and python_version < "3.11"
fakeredis==2.16.0 ; python_version >= "3.8" and python_version < "4.0" fakeredis==2.17.0 ; python_version >= "3.8" and python_version < "4.0"
fastapi==0.99.1 ; python_version >= "3.8" and python_version < "4.0" fastapi==0.99.1 ; python_version >= "3.8" and python_version < "4.0"
flaky==3.7.0 ; python_version >= "3.8" and python_version < "4.0" flaky==3.7.0 ; python_version >= "3.8" and python_version < "4.0"
frozenlist==1.4.0 ; python_version >= "3.8" and python_version < "4.0" frozenlist==1.4.0 ; python_version >= "3.8" and python_version < "4.0"
gitdb==4.0.10 ; python_version >= "3.8" and python_version < "4.0" gitdb==4.0.10 ; python_version >= "3.8" and python_version < "4.0"
gitpython==3.1.32 ; python_version >= "3.8" and python_version < "4.0" gitpython==3.1.32 ; python_version >= "3.8" and python_version < "4.0"
gram-core @ git+https://github.com/PaiGramTeam/GramCore.git@main ; python_version >= "3.8" and python_version < "4.0"
greenlet==1.1.3 ; python_version >= "3.8" and python_version < "4.0" greenlet==1.1.3 ; python_version >= "3.8" and python_version < "4.0"
h11==0.14.0 ; python_version >= "3.8" and python_version < "4.0" h11==0.14.0 ; python_version >= "3.8" and python_version < "4.0"
httpcore==0.17.3 ; python_version >= "3.8" and python_version < "4.0" httpcore==0.17.3 ; python_version >= "3.8" and python_version < "4.0"
@ -73,8 +74,8 @@ pytz==2023.3 ; python_version >= "3.8" and python_version < "4.0"
pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "4.0" pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "4.0"
qrcode==7.4.2 ; python_version >= "3.8" and python_version < "4.0" qrcode==7.4.2 ; python_version >= "3.8" and python_version < "4.0"
redis==4.6.0 ; python_version >= "3.8" and python_version < "4.0" redis==4.6.0 ; python_version >= "3.8" and python_version < "4.0"
rich==13.4.2 ; python_version >= "3.8" and python_version < "4.0" rich==13.5.1 ; python_version >= "3.8" and python_version < "4.0"
sentry-sdk==1.28.1 ; python_version >= "3.8" and python_version < "4.0" sentry-sdk==1.29.2 ; python_version >= "3.8" and python_version < "4.0"
setuptools==68.0.0 ; python_version >= "3.8" and python_version < "4.0" setuptools==68.0.0 ; python_version >= "3.8" and python_version < "4.0"
simnet @ git+https://github.com/PaiGramTeam/SIMNet@main ; python_version >= "3.8" and python_version < "4.0" simnet @ git+https://github.com/PaiGramTeam/SIMNet@main ; python_version >= "3.8" and python_version < "4.0"
six==1.16.0 ; python_version >= "3.8" and python_version < "4.0" six==1.16.0 ; python_version >= "3.8" and python_version < "4.0"
@ -96,7 +97,7 @@ tzdata==2023.3 ; python_version >= "3.8" and python_version < "4.0" and platform
tzlocal==5.0.1 ; python_version >= "3.8" and python_version < "4.0" tzlocal==5.0.1 ; python_version >= "3.8" and python_version < "4.0"
ujson==5.8.0 ; python_version >= "3.8" and python_version < "4.0" ujson==5.8.0 ; python_version >= "3.8" and python_version < "4.0"
urllib3==1.26.16 ; python_version >= "3.8" and python_version < "4.0" urllib3==1.26.16 ; python_version >= "3.8" and python_version < "4.0"
uvicorn[standard]==0.22.0 ; python_version >= "3.8" and python_version < "4.0" uvicorn[standard]==0.23.1 ; python_version >= "3.8" and python_version < "4.0"
uvloop==0.17.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.8" and python_version < "4.0" uvloop==0.17.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.8" and python_version < "4.0"
watchfiles==0.19.0 ; python_version >= "3.8" and python_version < "4.0" watchfiles==0.19.0 ; python_version >= "3.8" and python_version < "4.0"
websockets==10.4 ; python_version >= "3.8" and python_version < "4.0" websockets==10.4 ; python_version >= "3.8" and python_version < "4.0"

2
run.py
View File

@ -19,7 +19,7 @@ def run():
def main(): def main():
from core.builtins.reloader import Reloader from gram_core.builtins.reloader import Reloader
from core.config import config from core.config import config
if config.auto_reload: # 是否启动重载器 if config.auto_reload: # 是否启动重载器