Support migrate user data

This commit is contained in:
omg-xtao 2023-12-16 17:36:19 +08:00 committed by GitHub
parent c41bdedfe8
commit 4718860a87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 309 additions and 162 deletions

View File

@ -1,10 +1,10 @@
import asyncio
import sys
from importlib import import_module
from pathlib import Path
from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar
from arkowrapper import ArkoWrapper
from async_timeout import timeout
from typing_extensions import ParamSpec
from gram_core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services
@ -18,6 +18,11 @@ if TYPE_CHECKING:
from gram_core.plugin import PluginType
from gram_core.builtins.executor import Executor
if sys.version_info >= (3, 11):
from asyncio import timeout
else:
from async_timeout import timeout
__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers")
R = TypeVar("R")

View File

@ -1,27 +1,8 @@
from pathlib import Path
from typing import List, Optional, Union, TYPE_CHECKING
from telegram import ReplyKeyboardRemove, Update
from telegram.ext import ConversationHandler
import aiofiles
import httpx
from httpx import UnsupportedProtocol
from telegram import Chat, Message, ReplyKeyboardRemove, Update
from telegram.error import Forbidden, NetworkError
from telegram.ext import CallbackContext, ConversationHandler, Job
from gram_core.dependence.redisdb import RedisDB
from gram_core.plugin._handler import conversation, handler
from utils.const import CACHE_DIR, REQUEST_HEADERS
from utils.error import UrlResourcesNotFoundError
from utils.helpers import sha1
from utils.log import logger
if TYPE_CHECKING:
from gram_core.application import Application
try:
import ujson as jsonlib
except ImportError:
import json as jsonlib
from gram_core.plugin.methods import PluginFuncMethods
__all__ = (
"PluginFuncs",
@ -29,145 +10,8 @@ __all__ = (
)
class PluginFuncs:
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError("No application was set for this PluginManager.")
return self._application
async def _delete_message(self, context: CallbackContext) -> None:
job = context.job
message_id = job.data
chat_info = f"chat_id[{job.chat_id}]"
try:
chat = await self.get_chat(job.chat_id)
full_name = chat.full_name
if full_name:
chat_info = f"{full_name}[{chat.id}]"
else:
chat_info = f"{chat.title}[{chat.id}]"
except (NetworkError, Forbidden) as exc:
logger.warning("获取 chat info 失败 %s", exc.message)
except Exception as exc:
logger.warning("获取 chat info 消息失败 %s", str(exc))
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
try:
# noinspection PyTypeChecker
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
except NetworkError as exc:
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
except Forbidden as exc:
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
except Exception as exc:
logger.error("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc_info=exc)
async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, expire: int = 86400) -> Chat:
application = self.application
redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None)
if not redis_db:
return await application.bot.get_chat(chat_id)
qname = f"bot:chat:{chat_id}"
data = await redis_db.client.get(qname)
if data:
json_data = jsonlib.loads(data)
return Chat.de_json(json_data, application.telegram.bot)
chat_info = await application.telegram.bot.get_chat(chat_id)
await redis_db.client.set(qname, chat_info.to_json(), ex=expire)
return chat_info
def add_delete_message_job(
self,
message: Optional[Union[int, Message]] = None,
*,
delay: int = 60,
name: Optional[str] = None,
chat: Optional[Union[int, Chat]] = None,
context: Optional[CallbackContext] = None,
) -> Job:
"""延迟删除消息"""
if isinstance(message, Message):
if chat is None:
chat = message.chat_id
message = message.id
chat = chat.id if isinstance(chat, Chat) else chat
job_queue = self.application.job_queue or context.job_queue
if job_queue is None or chat is None:
raise RuntimeError
return job_queue.run_once(
callback=self._delete_message,
when=delay,
data=message,
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
chat_id=chat,
job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"},
)
@staticmethod
async def download_resource(url: str, return_path: bool = False) -> str:
url_sha1 = sha1(url) # url 的 hash 值
pathed_url = Path(url)
file_name = url_sha1 + pathed_url.suffix
file_path = CACHE_DIR.joinpath(file_name)
if not file_path.exists(): # 若文件不存在,则下载
async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client:
try:
response = await client.get(url)
except UnsupportedProtocol:
logger.error("链接不支持 url[%s]", url)
return ""
if response.is_error:
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
raise UrlResourcesNotFoundError(url)
if response.status_code != 200:
logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code)
raise UrlResourcesNotFoundError(url)
async with aiofiles.open(file_path, mode="wb") as f:
await f.write(response.content)
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
return file_path if return_path else Path(file_path).as_uri()
@staticmethod
def get_args(context: CallbackContext) -> List[str]:
args = context.args
match = context.match
if args is None:
if match is not None and (command := match.groups()[0]):
temp = []
command_parts = command.split(" ")
for command_part in command_parts:
if command_part:
temp.append(command_part)
return temp
return []
if len(args) >= 1:
return args
return []
class PluginFuncs(PluginFuncMethods):
"""插件方法"""
class ConversationFuncs:

View File

@ -0,0 +1,17 @@
from .application import ApplicationMethod
from .delete_message import DeleteMessage
from .download_resource import DownloadResource
from .get_args import GetArgs
from .get_chat import GetChat
from .migrate_data import MigrateData
class PluginFuncMethods(
ApplicationMethod,
DeleteMessage,
DownloadResource,
GetArgs,
GetChat,
MigrateData,
):
"""插件方法"""

View File

@ -0,0 +1,17 @@
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from gram_core.application import Application
class ApplicationMethod:
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError("No application was set for this PluginManager.")
return self._application

View File

@ -0,0 +1,76 @@
from typing import Optional, Union, TYPE_CHECKING
from telegram import Chat, Message
from telegram.error import Forbidden, NetworkError
from telegram.ext import CallbackContext, Job
from utils.log import logger
if TYPE_CHECKING:
from . import PluginFuncMethods
class DeleteMessage:
async def _delete_message(self: "PluginFuncMethods", context: "CallbackContext") -> None:
job = context.job
message_id = job.data
chat_info = f"chat_id[{job.chat_id}]"
try:
chat = await self.get_chat(job.chat_id)
full_name = chat.full_name
if full_name:
chat_info = f"{full_name}[{chat.id}]"
else:
chat_info = f"{chat.title}[{chat.id}]"
except (NetworkError, Forbidden) as exc:
logger.warning("获取 chat info 失败 %s", exc.message)
except Exception as exc:
logger.warning("获取 chat info 消息失败 %s", str(exc))
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
try:
# noinspection PyTypeChecker
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
except NetworkError as exc:
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
except Forbidden as exc:
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
except Exception as exc:
logger.error("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc_info=exc)
def add_delete_message_job(
self: "PluginFuncMethods",
message: Optional[Union[int, Message]] = None,
*,
delay: int = 60,
name: Optional[str] = None,
chat: Optional[Union[int, Chat]] = None,
context: Optional[CallbackContext] = None,
) -> Job:
"""延迟删除消息"""
if isinstance(message, Message):
if chat is None:
chat = message.chat_id
message = message.id
chat = chat.id if isinstance(chat, Chat) else chat
job_queue = self.application.job_queue or context.job_queue
if job_queue is None or chat is None:
raise RuntimeError
return job_queue.run_once(
callback=self._delete_message,
when=delay,
data=message,
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
chat_id=chat,
job_kwargs={
"replace_existing": True,
"id": f"{chat}|{message}|delete_message",
},
)

View File

@ -0,0 +1,47 @@
from pathlib import Path
import aiofiles
import httpx
from httpx import UnsupportedProtocol
from utils.const import CACHE_DIR, REQUEST_HEADERS
from utils.error import UrlResourcesNotFoundError
from utils.helpers import sha1
from utils.log import logger
class DownloadResource:
@staticmethod
async def download_resource(url: str, return_path: bool = False) -> str:
url_sha1 = sha1(url) # url 的 hash 值
pathed_url = Path(url)
file_name = url_sha1 + pathed_url.suffix
file_path = CACHE_DIR.joinpath(file_name)
if not file_path.exists(): # 若文件不存在,则下载
async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client:
try:
response = await client.get(url)
except UnsupportedProtocol:
logger.error("链接不支持 url[%s]", url)
return ""
if response.is_error:
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
raise UrlResourcesNotFoundError(url)
if response.status_code != 200:
logger.error(
"download_resource 获取url[%s] 错误 status_code[%s]",
url,
response.status_code,
)
raise UrlResourcesNotFoundError(url)
async with aiofiles.open(file_path, mode="wb") as f:
await f.write(response.content)
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
return file_path if return_path else Path(file_path).as_uri()

View File

@ -0,0 +1,24 @@
from typing import List, TYPE_CHECKING
if TYPE_CHECKING:
from telegram.ext import CallbackContext
class GetArgs:
@staticmethod
def get_args(context: "CallbackContext") -> List[str]:
args = context.args
match = context.match
if args is None:
if match is not None and (command := match.groups()[0]):
temp = []
command_parts = command.split(" ")
for command_part in command_parts:
if command_part:
temp.append(command_part)
return temp
return []
if len(args) >= 1:
return args
return []

View File

@ -0,0 +1,38 @@
from typing import Union, Optional, TYPE_CHECKING
from telegram import Chat
from gram_core.dependence.redisdb import RedisDB
if TYPE_CHECKING:
from . import PluginFuncMethods
try:
import ujson as jsonlib
except ImportError:
import json as jsonlib
class GetChat:
async def get_chat(
self: "PluginFuncMethods",
chat_id: Union[str, int],
redis_db: Optional[RedisDB] = None,
expire: int = 86400,
) -> Chat:
application = self.application
redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None)
if not redis_db:
return await application.bot.get_chat(chat_id)
qname = f"bot:chat:{chat_id}"
data = await redis_db.client.get(qname)
if data:
json_data = jsonlib.loads(data)
return Chat.de_json(json_data, application.telegram.bot)
chat_info = await application.telegram.bot.get_chat(chat_id)
await redis_db.client.set(qname, chat_info.to_json(), ex=expire)
return chat_info

View File

@ -0,0 +1,58 @@
from abc import ABC, abstractmethod
from typing import Optional, TypeVar, List, Any, Tuple, Type, TYPE_CHECKING
if TYPE_CHECKING:
from gram_core.services.players.models import PlayersDataBase as Player
T = TypeVar("T")
class MigrateDataException(Exception):
"""迁移数据异常"""
def __init__(self, msg: str):
self.msg = msg
class IMigrateData(ABC):
@abstractmethod
async def migrate_data_msg(self) -> str:
"""返回迁移数据的提示信息"""
@abstractmethod
async def migrate_data(self) -> bool:
"""迁移数据"""
@staticmethod
def get_sql_data_by_key(model: T, keys: Tuple[Any, ...]) -> tuple[Any, ...]:
"""通过 key 获取数据"""
data = []
for i in keys:
data.append(getattr(model, i.key))
return tuple(data)
@staticmethod
async def filter_sql_data(
model: Type[T], service_method, old_user_id: int, new_user_id: int, keys: Tuple[Any, ...]
) -> Tuple[List[T], List[T]]:
"""过滤数据库数据"""
data: List[model] = await service_method(old_user_id)
if not data:
return [], []
new_data = await service_method(new_user_id)
new_data_index = [IMigrateData.get_sql_data_by_key(p, keys) for p in new_data]
need_migrate = []
for d in data:
if IMigrateData.get_sql_data_by_key(d, keys) not in new_data_index:
need_migrate.append(d)
return need_migrate, new_data
class MigrateData:
async def get_migrate_data(
self, old_user_id: int, new_user_id: int, old_players: List["Player"]
) -> Optional[IMigrateData]:
"""获取迁移数据"""
if not (old_user_id and new_user_id and old_players):
return None
return None

5
pyproject.toml Normal file
View File

@ -0,0 +1,5 @@
# 格式配置
[tool.black]
include = '\.pyi?$'
line-length = 120
target-version = ['py311']

View File

@ -108,3 +108,10 @@ class PlayerInfoRepository(BaseService.Component):
session.add(player)
await session.commit()
await session.refresh(player)
async def get_all_by_user_id(self, user_id: int) -> List[PlayerInfoSQLModel]:
async with AsyncSession(self.engine) as session:
statement = select(PlayerInfoSQLModel).where(PlayerInfoSQLModel.user_id == user_id)
results = await session.exec(statement)
players = results.all()
return players

View File

@ -48,3 +48,9 @@ class TaskRepository(BaseService.Component):
query = select(Task).where(Task.type == task_type)
results = await session.exec(query)
return results.all()
async def get_all_by_user_id(self, user_id: int) -> List[Task]:
async with AsyncSession(self.engine) as session:
query = select(Task).where(Task.user_id == user_id)
results = await session.exec(query)
return results.all()

View File

@ -37,6 +37,9 @@ class TaskServices:
async def get_all(self):
return await self._repository.get_all(self.TASK_TYPE)
async def get_all_by_user_id(self, user_id: int):
return await self._repository.get_all_by_user_id(user_id)
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
return Task(
user_id=user_id,