diff --git a/manager.py b/manager.py index 3ea70fe..9efc2d3 100644 --- a/manager.py +++ b/manager.py @@ -1,10 +1,10 @@ import asyncio +import sys from importlib import import_module from pathlib import Path from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar from arkowrapper import ArkoWrapper -from async_timeout import timeout from typing_extensions import ParamSpec from gram_core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services @@ -18,6 +18,11 @@ if TYPE_CHECKING: from gram_core.plugin import PluginType from gram_core.builtins.executor import Executor +if sys.version_info >= (3, 11): + from asyncio import timeout +else: + from async_timeout import timeout + __all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers") R = TypeVar("R") diff --git a/plugin/_funcs.py b/plugin/_funcs.py index 742ceb7..b9c8a1e 100644 --- a/plugin/_funcs.py +++ b/plugin/_funcs.py @@ -1,27 +1,8 @@ -from pathlib import Path -from typing import List, Optional, Union, TYPE_CHECKING +from telegram import ReplyKeyboardRemove, Update +from telegram.ext import ConversationHandler -import aiofiles -import httpx -from httpx import UnsupportedProtocol -from telegram import Chat, Message, ReplyKeyboardRemove, Update -from telegram.error import Forbidden, NetworkError -from telegram.ext import CallbackContext, ConversationHandler, Job - -from gram_core.dependence.redisdb import RedisDB from gram_core.plugin._handler import conversation, handler -from utils.const import CACHE_DIR, REQUEST_HEADERS -from utils.error import UrlResourcesNotFoundError -from utils.helpers import sha1 -from utils.log import logger - -if TYPE_CHECKING: - from gram_core.application import Application - -try: - import ujson as jsonlib -except ImportError: - import json as jsonlib +from gram_core.plugin.methods import PluginFuncMethods __all__ = ( "PluginFuncs", @@ -29,145 +10,8 @@ __all__ = ( ) -class PluginFuncs: - _application: "Optional[Application]" = None - - def set_application(self, application: "Application") -> None: - self._application = application - - @property - def application(self) -> "Application": - if self._application is None: - raise RuntimeError("No application was set for this PluginManager.") - return self._application - - async def _delete_message(self, context: CallbackContext) -> None: - job = context.job - message_id = job.data - chat_info = f"chat_id[{job.chat_id}]" - - try: - chat = await self.get_chat(job.chat_id) - full_name = chat.full_name - if full_name: - chat_info = f"{full_name}[{chat.id}]" - else: - chat_info = f"{chat.title}[{chat.id}]" - except (NetworkError, Forbidden) as exc: - logger.warning("获取 chat info 失败 %s", exc.message) - except Exception as exc: - logger.warning("获取 chat info 消息失败 %s", str(exc)) - - logger.debug("删除消息 %s message_id[%s]", chat_info, message_id) - - try: - # noinspection PyTypeChecker - await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id) - except NetworkError as exc: - logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message) - except Forbidden as exc: - logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message) - except Exception as exc: - logger.error("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc_info=exc) - - async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, expire: int = 86400) -> Chat: - application = self.application - redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None) - - if not redis_db: - return await application.bot.get_chat(chat_id) - - qname = f"bot:chat:{chat_id}" - - data = await redis_db.client.get(qname) - if data: - json_data = jsonlib.loads(data) - return Chat.de_json(json_data, application.telegram.bot) - - chat_info = await application.telegram.bot.get_chat(chat_id) - await redis_db.client.set(qname, chat_info.to_json(), ex=expire) - return chat_info - - def add_delete_message_job( - self, - message: Optional[Union[int, Message]] = None, - *, - delay: int = 60, - name: Optional[str] = None, - chat: Optional[Union[int, Chat]] = None, - context: Optional[CallbackContext] = None, - ) -> Job: - """延迟删除消息""" - - if isinstance(message, Message): - if chat is None: - chat = message.chat_id - message = message.id - - chat = chat.id if isinstance(chat, Chat) else chat - - job_queue = self.application.job_queue or context.job_queue - - if job_queue is None or chat is None: - raise RuntimeError - - return job_queue.run_once( - callback=self._delete_message, - when=delay, - data=message, - name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message", - chat_id=chat, - job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"}, - ) - - @staticmethod - async def download_resource(url: str, return_path: bool = False) -> str: - url_sha1 = sha1(url) # url 的 hash 值 - pathed_url = Path(url) - - file_name = url_sha1 + pathed_url.suffix - file_path = CACHE_DIR.joinpath(file_name) - - if not file_path.exists(): # 若文件不存在,则下载 - async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client: - try: - response = await client.get(url) - except UnsupportedProtocol: - logger.error("链接不支持 url[%s]", url) - return "" - - if response.is_error: - logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code) - raise UrlResourcesNotFoundError(url) - - if response.status_code != 200: - logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code) - raise UrlResourcesNotFoundError(url) - - async with aiofiles.open(file_path, mode="wb") as f: - await f.write(response.content) - - logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path) - - return file_path if return_path else Path(file_path).as_uri() - - @staticmethod - def get_args(context: CallbackContext) -> List[str]: - args = context.args - match = context.match - - if args is None: - if match is not None and (command := match.groups()[0]): - temp = [] - command_parts = command.split(" ") - for command_part in command_parts: - if command_part: - temp.append(command_part) - return temp - return [] - if len(args) >= 1: - return args - return [] +class PluginFuncs(PluginFuncMethods): + """插件方法""" class ConversationFuncs: diff --git a/plugin/methods/__init__.py b/plugin/methods/__init__.py new file mode 100644 index 0000000..aa7b018 --- /dev/null +++ b/plugin/methods/__init__.py @@ -0,0 +1,17 @@ +from .application import ApplicationMethod +from .delete_message import DeleteMessage +from .download_resource import DownloadResource +from .get_args import GetArgs +from .get_chat import GetChat +from .migrate_data import MigrateData + + +class PluginFuncMethods( + ApplicationMethod, + DeleteMessage, + DownloadResource, + GetArgs, + GetChat, + MigrateData, +): + """插件方法""" diff --git a/plugin/methods/application.py b/plugin/methods/application.py new file mode 100644 index 0000000..ad363f1 --- /dev/null +++ b/plugin/methods/application.py @@ -0,0 +1,17 @@ +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from gram_core.application import Application + + +class ApplicationMethod: + _application: "Optional[Application]" = None + + def set_application(self, application: "Application") -> None: + self._application = application + + @property + def application(self) -> "Application": + if self._application is None: + raise RuntimeError("No application was set for this PluginManager.") + return self._application diff --git a/plugin/methods/delete_message.py b/plugin/methods/delete_message.py new file mode 100644 index 0000000..8aaee2a --- /dev/null +++ b/plugin/methods/delete_message.py @@ -0,0 +1,76 @@ +from typing import Optional, Union, TYPE_CHECKING + +from telegram import Chat, Message +from telegram.error import Forbidden, NetworkError +from telegram.ext import CallbackContext, Job + +from utils.log import logger + +if TYPE_CHECKING: + from . import PluginFuncMethods + + +class DeleteMessage: + async def _delete_message(self: "PluginFuncMethods", context: "CallbackContext") -> None: + job = context.job + message_id = job.data + chat_info = f"chat_id[{job.chat_id}]" + + try: + chat = await self.get_chat(job.chat_id) + full_name = chat.full_name + if full_name: + chat_info = f"{full_name}[{chat.id}]" + else: + chat_info = f"{chat.title}[{chat.id}]" + except (NetworkError, Forbidden) as exc: + logger.warning("获取 chat info 失败 %s", exc.message) + except Exception as exc: + logger.warning("获取 chat info 消息失败 %s", str(exc)) + + logger.debug("删除消息 %s message_id[%s]", chat_info, message_id) + + try: + # noinspection PyTypeChecker + await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id) + except NetworkError as exc: + logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message) + except Forbidden as exc: + logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message) + except Exception as exc: + logger.error("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc_info=exc) + + def add_delete_message_job( + self: "PluginFuncMethods", + message: Optional[Union[int, Message]] = None, + *, + delay: int = 60, + name: Optional[str] = None, + chat: Optional[Union[int, Chat]] = None, + context: Optional[CallbackContext] = None, + ) -> Job: + """延迟删除消息""" + + if isinstance(message, Message): + if chat is None: + chat = message.chat_id + message = message.id + + chat = chat.id if isinstance(chat, Chat) else chat + + job_queue = self.application.job_queue or context.job_queue + + if job_queue is None or chat is None: + raise RuntimeError + + return job_queue.run_once( + callback=self._delete_message, + when=delay, + data=message, + name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message", + chat_id=chat, + job_kwargs={ + "replace_existing": True, + "id": f"{chat}|{message}|delete_message", + }, + ) diff --git a/plugin/methods/download_resource.py b/plugin/methods/download_resource.py new file mode 100644 index 0000000..2e2ca28 --- /dev/null +++ b/plugin/methods/download_resource.py @@ -0,0 +1,47 @@ +from pathlib import Path + +import aiofiles +import httpx +from httpx import UnsupportedProtocol + +from utils.const import CACHE_DIR, REQUEST_HEADERS +from utils.error import UrlResourcesNotFoundError +from utils.helpers import sha1 +from utils.log import logger + + +class DownloadResource: + @staticmethod + async def download_resource(url: str, return_path: bool = False) -> str: + url_sha1 = sha1(url) # url 的 hash 值 + pathed_url = Path(url) + + file_name = url_sha1 + pathed_url.suffix + file_path = CACHE_DIR.joinpath(file_name) + + if not file_path.exists(): # 若文件不存在,则下载 + async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client: + try: + response = await client.get(url) + except UnsupportedProtocol: + logger.error("链接不支持 url[%s]", url) + return "" + + if response.is_error: + logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code) + raise UrlResourcesNotFoundError(url) + + if response.status_code != 200: + logger.error( + "download_resource 获取url[%s] 错误 status_code[%s]", + url, + response.status_code, + ) + raise UrlResourcesNotFoundError(url) + + async with aiofiles.open(file_path, mode="wb") as f: + await f.write(response.content) + + logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path) + + return file_path if return_path else Path(file_path).as_uri() diff --git a/plugin/methods/get_args.py b/plugin/methods/get_args.py new file mode 100644 index 0000000..5e668cc --- /dev/null +++ b/plugin/methods/get_args.py @@ -0,0 +1,24 @@ +from typing import List, TYPE_CHECKING + +if TYPE_CHECKING: + from telegram.ext import CallbackContext + + +class GetArgs: + @staticmethod + def get_args(context: "CallbackContext") -> List[str]: + args = context.args + match = context.match + + if args is None: + if match is not None and (command := match.groups()[0]): + temp = [] + command_parts = command.split(" ") + for command_part in command_parts: + if command_part: + temp.append(command_part) + return temp + return [] + if len(args) >= 1: + return args + return [] diff --git a/plugin/methods/get_chat.py b/plugin/methods/get_chat.py new file mode 100644 index 0000000..29fee88 --- /dev/null +++ b/plugin/methods/get_chat.py @@ -0,0 +1,38 @@ +from typing import Union, Optional, TYPE_CHECKING + +from telegram import Chat + +from gram_core.dependence.redisdb import RedisDB + +if TYPE_CHECKING: + from . import PluginFuncMethods + +try: + import ujson as jsonlib +except ImportError: + import json as jsonlib + + +class GetChat: + async def get_chat( + self: "PluginFuncMethods", + chat_id: Union[str, int], + redis_db: Optional[RedisDB] = None, + expire: int = 86400, + ) -> Chat: + application = self.application + redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None) + + if not redis_db: + return await application.bot.get_chat(chat_id) + + qname = f"bot:chat:{chat_id}" + + data = await redis_db.client.get(qname) + if data: + json_data = jsonlib.loads(data) + return Chat.de_json(json_data, application.telegram.bot) + + chat_info = await application.telegram.bot.get_chat(chat_id) + await redis_db.client.set(qname, chat_info.to_json(), ex=expire) + return chat_info diff --git a/plugin/methods/migrate_data.py b/plugin/methods/migrate_data.py new file mode 100644 index 0000000..9050fcc --- /dev/null +++ b/plugin/methods/migrate_data.py @@ -0,0 +1,58 @@ +from abc import ABC, abstractmethod +from typing import Optional, TypeVar, List, Any, Tuple, Type, TYPE_CHECKING + +if TYPE_CHECKING: + from gram_core.services.players.models import PlayersDataBase as Player + +T = TypeVar("T") + + +class MigrateDataException(Exception): + """迁移数据异常""" + + def __init__(self, msg: str): + self.msg = msg + + +class IMigrateData(ABC): + @abstractmethod + async def migrate_data_msg(self) -> str: + """返回迁移数据的提示信息""" + + @abstractmethod + async def migrate_data(self) -> bool: + """迁移数据""" + + @staticmethod + def get_sql_data_by_key(model: T, keys: Tuple[Any, ...]) -> tuple[Any, ...]: + """通过 key 获取数据""" + data = [] + for i in keys: + data.append(getattr(model, i.key)) + return tuple(data) + + @staticmethod + async def filter_sql_data( + model: Type[T], service_method, old_user_id: int, new_user_id: int, keys: Tuple[Any, ...] + ) -> Tuple[List[T], List[T]]: + """过滤数据库数据""" + data: List[model] = await service_method(old_user_id) + if not data: + return [], [] + new_data = await service_method(new_user_id) + new_data_index = [IMigrateData.get_sql_data_by_key(p, keys) for p in new_data] + need_migrate = [] + for d in data: + if IMigrateData.get_sql_data_by_key(d, keys) not in new_data_index: + need_migrate.append(d) + return need_migrate, new_data + + +class MigrateData: + async def get_migrate_data( + self, old_user_id: int, new_user_id: int, old_players: List["Player"] + ) -> Optional[IMigrateData]: + """获取迁移数据""" + if not (old_user_id and new_user_id and old_players): + return None + return None diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..fca6af6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,5 @@ +# 格式配置 +[tool.black] +include = '\.pyi?$' +line-length = 120 +target-version = ['py311'] diff --git a/services/players/repositories.py b/services/players/repositories.py index 2ad640c..d39980e 100644 --- a/services/players/repositories.py +++ b/services/players/repositories.py @@ -108,3 +108,10 @@ class PlayerInfoRepository(BaseService.Component): session.add(player) await session.commit() await session.refresh(player) + + async def get_all_by_user_id(self, user_id: int) -> List[PlayerInfoSQLModel]: + async with AsyncSession(self.engine) as session: + statement = select(PlayerInfoSQLModel).where(PlayerInfoSQLModel.user_id == user_id) + results = await session.exec(statement) + players = results.all() + return players diff --git a/services/task/repositories.py b/services/task/repositories.py index ce5f5fd..6fcea59 100644 --- a/services/task/repositories.py +++ b/services/task/repositories.py @@ -48,3 +48,9 @@ class TaskRepository(BaseService.Component): query = select(Task).where(Task.type == task_type) results = await session.exec(query) return results.all() + + async def get_all_by_user_id(self, user_id: int) -> List[Task]: + async with AsyncSession(self.engine) as session: + query = select(Task).where(Task.user_id == user_id) + results = await session.exec(query) + return results.all() diff --git a/services/task/services.py b/services/task/services.py index d2359bb..4b518df 100644 --- a/services/task/services.py +++ b/services/task/services.py @@ -37,6 +37,9 @@ class TaskServices: async def get_all(self): return await self._repository.get_all(self.TASK_TYPE) + async def get_all_by_user_id(self, user_id: int): + return await self._repository.get_all_by_user_id(user_id) + def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None): return Task( user_id=user_id,