mirror of
https://github.com/PaiGramTeam/GramCore.git
synced 2024-12-03 10:26:10 +00:00
✨ Support migrate user data
This commit is contained in:
parent
c41bdedfe8
commit
4718860a87
@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
from arkowrapper import ArkoWrapper
|
||||
from async_timeout import timeout
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from gram_core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services
|
||||
@ -18,6 +18,11 @@ if TYPE_CHECKING:
|
||||
from gram_core.plugin import PluginType
|
||||
from gram_core.builtins.executor import Executor
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from asyncio import timeout
|
||||
else:
|
||||
from async_timeout import timeout
|
||||
|
||||
__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers")
|
||||
|
||||
R = TypeVar("R")
|
||||
|
166
plugin/_funcs.py
166
plugin/_funcs.py
@ -1,27 +1,8 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union, TYPE_CHECKING
|
||||
from telegram import ReplyKeyboardRemove, Update
|
||||
from telegram.ext import ConversationHandler
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
from httpx import UnsupportedProtocol
|
||||
from telegram import Chat, Message, ReplyKeyboardRemove, Update
|
||||
from telegram.error import Forbidden, NetworkError
|
||||
from telegram.ext import CallbackContext, ConversationHandler, Job
|
||||
|
||||
from gram_core.dependence.redisdb import RedisDB
|
||||
from gram_core.plugin._handler import conversation, handler
|
||||
from utils.const import CACHE_DIR, REQUEST_HEADERS
|
||||
from utils.error import UrlResourcesNotFoundError
|
||||
from utils.helpers import sha1
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gram_core.application import Application
|
||||
|
||||
try:
|
||||
import ujson as jsonlib
|
||||
except ImportError:
|
||||
import json as jsonlib
|
||||
from gram_core.plugin.methods import PluginFuncMethods
|
||||
|
||||
__all__ = (
|
||||
"PluginFuncs",
|
||||
@ -29,145 +10,8 @@ __all__ = (
|
||||
)
|
||||
|
||||
|
||||
class PluginFuncs:
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError("No application was set for this PluginManager.")
|
||||
return self._application
|
||||
|
||||
async def _delete_message(self, context: CallbackContext) -> None:
|
||||
job = context.job
|
||||
message_id = job.data
|
||||
chat_info = f"chat_id[{job.chat_id}]"
|
||||
|
||||
try:
|
||||
chat = await self.get_chat(job.chat_id)
|
||||
full_name = chat.full_name
|
||||
if full_name:
|
||||
chat_info = f"{full_name}[{chat.id}]"
|
||||
else:
|
||||
chat_info = f"{chat.title}[{chat.id}]"
|
||||
except (NetworkError, Forbidden) as exc:
|
||||
logger.warning("获取 chat info 失败 %s", exc.message)
|
||||
except Exception as exc:
|
||||
logger.warning("获取 chat info 消息失败 %s", str(exc))
|
||||
|
||||
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
|
||||
|
||||
try:
|
||||
# noinspection PyTypeChecker
|
||||
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
|
||||
except NetworkError as exc:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||
except Forbidden as exc:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||
except Exception as exc:
|
||||
logger.error("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc_info=exc)
|
||||
|
||||
async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, expire: int = 86400) -> Chat:
|
||||
application = self.application
|
||||
redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None)
|
||||
|
||||
if not redis_db:
|
||||
return await application.bot.get_chat(chat_id)
|
||||
|
||||
qname = f"bot:chat:{chat_id}"
|
||||
|
||||
data = await redis_db.client.get(qname)
|
||||
if data:
|
||||
json_data = jsonlib.loads(data)
|
||||
return Chat.de_json(json_data, application.telegram.bot)
|
||||
|
||||
chat_info = await application.telegram.bot.get_chat(chat_id)
|
||||
await redis_db.client.set(qname, chat_info.to_json(), ex=expire)
|
||||
return chat_info
|
||||
|
||||
def add_delete_message_job(
|
||||
self,
|
||||
message: Optional[Union[int, Message]] = None,
|
||||
*,
|
||||
delay: int = 60,
|
||||
name: Optional[str] = None,
|
||||
chat: Optional[Union[int, Chat]] = None,
|
||||
context: Optional[CallbackContext] = None,
|
||||
) -> Job:
|
||||
"""延迟删除消息"""
|
||||
|
||||
if isinstance(message, Message):
|
||||
if chat is None:
|
||||
chat = message.chat_id
|
||||
message = message.id
|
||||
|
||||
chat = chat.id if isinstance(chat, Chat) else chat
|
||||
|
||||
job_queue = self.application.job_queue or context.job_queue
|
||||
|
||||
if job_queue is None or chat is None:
|
||||
raise RuntimeError
|
||||
|
||||
return job_queue.run_once(
|
||||
callback=self._delete_message,
|
||||
when=delay,
|
||||
data=message,
|
||||
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
|
||||
chat_id=chat,
|
||||
job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def download_resource(url: str, return_path: bool = False) -> str:
|
||||
url_sha1 = sha1(url) # url 的 hash 值
|
||||
pathed_url = Path(url)
|
||||
|
||||
file_name = url_sha1 + pathed_url.suffix
|
||||
file_path = CACHE_DIR.joinpath(file_name)
|
||||
|
||||
if not file_path.exists(): # 若文件不存在,则下载
|
||||
async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client:
|
||||
try:
|
||||
response = await client.get(url)
|
||||
except UnsupportedProtocol:
|
||||
logger.error("链接不支持 url[%s]", url)
|
||||
return ""
|
||||
|
||||
if response.is_error:
|
||||
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
|
||||
raise UrlResourcesNotFoundError(url)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code)
|
||||
raise UrlResourcesNotFoundError(url)
|
||||
|
||||
async with aiofiles.open(file_path, mode="wb") as f:
|
||||
await f.write(response.content)
|
||||
|
||||
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
|
||||
|
||||
return file_path if return_path else Path(file_path).as_uri()
|
||||
|
||||
@staticmethod
|
||||
def get_args(context: CallbackContext) -> List[str]:
|
||||
args = context.args
|
||||
match = context.match
|
||||
|
||||
if args is None:
|
||||
if match is not None and (command := match.groups()[0]):
|
||||
temp = []
|
||||
command_parts = command.split(" ")
|
||||
for command_part in command_parts:
|
||||
if command_part:
|
||||
temp.append(command_part)
|
||||
return temp
|
||||
return []
|
||||
if len(args) >= 1:
|
||||
return args
|
||||
return []
|
||||
class PluginFuncs(PluginFuncMethods):
|
||||
"""插件方法"""
|
||||
|
||||
|
||||
class ConversationFuncs:
|
||||
|
17
plugin/methods/__init__.py
Normal file
17
plugin/methods/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from .application import ApplicationMethod
|
||||
from .delete_message import DeleteMessage
|
||||
from .download_resource import DownloadResource
|
||||
from .get_args import GetArgs
|
||||
from .get_chat import GetChat
|
||||
from .migrate_data import MigrateData
|
||||
|
||||
|
||||
class PluginFuncMethods(
|
||||
ApplicationMethod,
|
||||
DeleteMessage,
|
||||
DownloadResource,
|
||||
GetArgs,
|
||||
GetChat,
|
||||
MigrateData,
|
||||
):
|
||||
"""插件方法"""
|
17
plugin/methods/application.py
Normal file
17
plugin/methods/application.py
Normal file
@ -0,0 +1,17 @@
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gram_core.application import Application
|
||||
|
||||
|
||||
class ApplicationMethod:
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError("No application was set for this PluginManager.")
|
||||
return self._application
|
76
plugin/methods/delete_message.py
Normal file
76
plugin/methods/delete_message.py
Normal file
@ -0,0 +1,76 @@
|
||||
from typing import Optional, Union, TYPE_CHECKING
|
||||
|
||||
from telegram import Chat, Message
|
||||
from telegram.error import Forbidden, NetworkError
|
||||
from telegram.ext import CallbackContext, Job
|
||||
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import PluginFuncMethods
|
||||
|
||||
|
||||
class DeleteMessage:
|
||||
async def _delete_message(self: "PluginFuncMethods", context: "CallbackContext") -> None:
|
||||
job = context.job
|
||||
message_id = job.data
|
||||
chat_info = f"chat_id[{job.chat_id}]"
|
||||
|
||||
try:
|
||||
chat = await self.get_chat(job.chat_id)
|
||||
full_name = chat.full_name
|
||||
if full_name:
|
||||
chat_info = f"{full_name}[{chat.id}]"
|
||||
else:
|
||||
chat_info = f"{chat.title}[{chat.id}]"
|
||||
except (NetworkError, Forbidden) as exc:
|
||||
logger.warning("获取 chat info 失败 %s", exc.message)
|
||||
except Exception as exc:
|
||||
logger.warning("获取 chat info 消息失败 %s", str(exc))
|
||||
|
||||
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
|
||||
|
||||
try:
|
||||
# noinspection PyTypeChecker
|
||||
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
|
||||
except NetworkError as exc:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||
except Forbidden as exc:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||
except Exception as exc:
|
||||
logger.error("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc_info=exc)
|
||||
|
||||
def add_delete_message_job(
|
||||
self: "PluginFuncMethods",
|
||||
message: Optional[Union[int, Message]] = None,
|
||||
*,
|
||||
delay: int = 60,
|
||||
name: Optional[str] = None,
|
||||
chat: Optional[Union[int, Chat]] = None,
|
||||
context: Optional[CallbackContext] = None,
|
||||
) -> Job:
|
||||
"""延迟删除消息"""
|
||||
|
||||
if isinstance(message, Message):
|
||||
if chat is None:
|
||||
chat = message.chat_id
|
||||
message = message.id
|
||||
|
||||
chat = chat.id if isinstance(chat, Chat) else chat
|
||||
|
||||
job_queue = self.application.job_queue or context.job_queue
|
||||
|
||||
if job_queue is None or chat is None:
|
||||
raise RuntimeError
|
||||
|
||||
return job_queue.run_once(
|
||||
callback=self._delete_message,
|
||||
when=delay,
|
||||
data=message,
|
||||
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
|
||||
chat_id=chat,
|
||||
job_kwargs={
|
||||
"replace_existing": True,
|
||||
"id": f"{chat}|{message}|delete_message",
|
||||
},
|
||||
)
|
47
plugin/methods/download_resource.py
Normal file
47
plugin/methods/download_resource.py
Normal file
@ -0,0 +1,47 @@
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
from httpx import UnsupportedProtocol
|
||||
|
||||
from utils.const import CACHE_DIR, REQUEST_HEADERS
|
||||
from utils.error import UrlResourcesNotFoundError
|
||||
from utils.helpers import sha1
|
||||
from utils.log import logger
|
||||
|
||||
|
||||
class DownloadResource:
|
||||
@staticmethod
|
||||
async def download_resource(url: str, return_path: bool = False) -> str:
|
||||
url_sha1 = sha1(url) # url 的 hash 值
|
||||
pathed_url = Path(url)
|
||||
|
||||
file_name = url_sha1 + pathed_url.suffix
|
||||
file_path = CACHE_DIR.joinpath(file_name)
|
||||
|
||||
if not file_path.exists(): # 若文件不存在,则下载
|
||||
async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client:
|
||||
try:
|
||||
response = await client.get(url)
|
||||
except UnsupportedProtocol:
|
||||
logger.error("链接不支持 url[%s]", url)
|
||||
return ""
|
||||
|
||||
if response.is_error:
|
||||
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
|
||||
raise UrlResourcesNotFoundError(url)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
"download_resource 获取url[%s] 错误 status_code[%s]",
|
||||
url,
|
||||
response.status_code,
|
||||
)
|
||||
raise UrlResourcesNotFoundError(url)
|
||||
|
||||
async with aiofiles.open(file_path, mode="wb") as f:
|
||||
await f.write(response.content)
|
||||
|
||||
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
|
||||
|
||||
return file_path if return_path else Path(file_path).as_uri()
|
24
plugin/methods/get_args.py
Normal file
24
plugin/methods/get_args.py
Normal file
@ -0,0 +1,24 @@
|
||||
from typing import List, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from telegram.ext import CallbackContext
|
||||
|
||||
|
||||
class GetArgs:
|
||||
@staticmethod
|
||||
def get_args(context: "CallbackContext") -> List[str]:
|
||||
args = context.args
|
||||
match = context.match
|
||||
|
||||
if args is None:
|
||||
if match is not None and (command := match.groups()[0]):
|
||||
temp = []
|
||||
command_parts = command.split(" ")
|
||||
for command_part in command_parts:
|
||||
if command_part:
|
||||
temp.append(command_part)
|
||||
return temp
|
||||
return []
|
||||
if len(args) >= 1:
|
||||
return args
|
||||
return []
|
38
plugin/methods/get_chat.py
Normal file
38
plugin/methods/get_chat.py
Normal file
@ -0,0 +1,38 @@
|
||||
from typing import Union, Optional, TYPE_CHECKING
|
||||
|
||||
from telegram import Chat
|
||||
|
||||
from gram_core.dependence.redisdb import RedisDB
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import PluginFuncMethods
|
||||
|
||||
try:
|
||||
import ujson as jsonlib
|
||||
except ImportError:
|
||||
import json as jsonlib
|
||||
|
||||
|
||||
class GetChat:
|
||||
async def get_chat(
|
||||
self: "PluginFuncMethods",
|
||||
chat_id: Union[str, int],
|
||||
redis_db: Optional[RedisDB] = None,
|
||||
expire: int = 86400,
|
||||
) -> Chat:
|
||||
application = self.application
|
||||
redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None)
|
||||
|
||||
if not redis_db:
|
||||
return await application.bot.get_chat(chat_id)
|
||||
|
||||
qname = f"bot:chat:{chat_id}"
|
||||
|
||||
data = await redis_db.client.get(qname)
|
||||
if data:
|
||||
json_data = jsonlib.loads(data)
|
||||
return Chat.de_json(json_data, application.telegram.bot)
|
||||
|
||||
chat_info = await application.telegram.bot.get_chat(chat_id)
|
||||
await redis_db.client.set(qname, chat_info.to_json(), ex=expire)
|
||||
return chat_info
|
58
plugin/methods/migrate_data.py
Normal file
58
plugin/methods/migrate_data.py
Normal file
@ -0,0 +1,58 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, TypeVar, List, Any, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gram_core.services.players.models import PlayersDataBase as Player
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MigrateDataException(Exception):
|
||||
"""迁移数据异常"""
|
||||
|
||||
def __init__(self, msg: str):
|
||||
self.msg = msg
|
||||
|
||||
|
||||
class IMigrateData(ABC):
|
||||
@abstractmethod
|
||||
async def migrate_data_msg(self) -> str:
|
||||
"""返回迁移数据的提示信息"""
|
||||
|
||||
@abstractmethod
|
||||
async def migrate_data(self) -> bool:
|
||||
"""迁移数据"""
|
||||
|
||||
@staticmethod
|
||||
def get_sql_data_by_key(model: T, keys: Tuple[Any, ...]) -> tuple[Any, ...]:
|
||||
"""通过 key 获取数据"""
|
||||
data = []
|
||||
for i in keys:
|
||||
data.append(getattr(model, i.key))
|
||||
return tuple(data)
|
||||
|
||||
@staticmethod
|
||||
async def filter_sql_data(
|
||||
model: Type[T], service_method, old_user_id: int, new_user_id: int, keys: Tuple[Any, ...]
|
||||
) -> Tuple[List[T], List[T]]:
|
||||
"""过滤数据库数据"""
|
||||
data: List[model] = await service_method(old_user_id)
|
||||
if not data:
|
||||
return [], []
|
||||
new_data = await service_method(new_user_id)
|
||||
new_data_index = [IMigrateData.get_sql_data_by_key(p, keys) for p in new_data]
|
||||
need_migrate = []
|
||||
for d in data:
|
||||
if IMigrateData.get_sql_data_by_key(d, keys) not in new_data_index:
|
||||
need_migrate.append(d)
|
||||
return need_migrate, new_data
|
||||
|
||||
|
||||
class MigrateData:
|
||||
async def get_migrate_data(
|
||||
self, old_user_id: int, new_user_id: int, old_players: List["Player"]
|
||||
) -> Optional[IMigrateData]:
|
||||
"""获取迁移数据"""
|
||||
if not (old_user_id and new_user_id and old_players):
|
||||
return None
|
||||
return None
|
5
pyproject.toml
Normal file
5
pyproject.toml
Normal file
@ -0,0 +1,5 @@
|
||||
# 格式配置
|
||||
[tool.black]
|
||||
include = '\.pyi?$'
|
||||
line-length = 120
|
||||
target-version = ['py311']
|
@ -108,3 +108,10 @@ class PlayerInfoRepository(BaseService.Component):
|
||||
session.add(player)
|
||||
await session.commit()
|
||||
await session.refresh(player)
|
||||
|
||||
async def get_all_by_user_id(self, user_id: int) -> List[PlayerInfoSQLModel]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(PlayerInfoSQLModel).where(PlayerInfoSQLModel.user_id == user_id)
|
||||
results = await session.exec(statement)
|
||||
players = results.all()
|
||||
return players
|
||||
|
@ -48,3 +48,9 @@ class TaskRepository(BaseService.Component):
|
||||
query = select(Task).where(Task.type == task_type)
|
||||
results = await session.exec(query)
|
||||
return results.all()
|
||||
|
||||
async def get_all_by_user_id(self, user_id: int) -> List[Task]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
query = select(Task).where(Task.user_id == user_id)
|
||||
results = await session.exec(query)
|
||||
return results.all()
|
||||
|
@ -37,6 +37,9 @@ class TaskServices:
|
||||
async def get_all(self):
|
||||
return await self._repository.get_all(self.TASK_TYPE)
|
||||
|
||||
async def get_all_by_user_id(self, user_id: int):
|
||||
return await self._repository.get_all_by_user_id(user_id)
|
||||
|
||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
||||
return Task(
|
||||
user_id=user_id,
|
||||
|
Loading…
Reference in New Issue
Block a user