mirror of
https://github.com/PaiGramTeam/GramCore.git
synced 2024-12-04 18:51:26 +00:00
✨ Support migrate user data
This commit is contained in:
parent
c41bdedfe8
commit
4718860a87
@ -1,10 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import sys
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar
|
from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
from arkowrapper import ArkoWrapper
|
from arkowrapper import ArkoWrapper
|
||||||
from async_timeout import timeout
|
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from gram_core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services
|
from gram_core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services
|
||||||
@ -18,6 +18,11 @@ if TYPE_CHECKING:
|
|||||||
from gram_core.plugin import PluginType
|
from gram_core.plugin import PluginType
|
||||||
from gram_core.builtins.executor import Executor
|
from gram_core.builtins.executor import Executor
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from asyncio import timeout
|
||||||
|
else:
|
||||||
|
from async_timeout import timeout
|
||||||
|
|
||||||
__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers")
|
__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers")
|
||||||
|
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
166
plugin/_funcs.py
166
plugin/_funcs.py
@ -1,27 +1,8 @@
|
|||||||
from pathlib import Path
|
from telegram import ReplyKeyboardRemove, Update
|
||||||
from typing import List, Optional, Union, TYPE_CHECKING
|
from telegram.ext import ConversationHandler
|
||||||
|
|
||||||
import aiofiles
|
|
||||||
import httpx
|
|
||||||
from httpx import UnsupportedProtocol
|
|
||||||
from telegram import Chat, Message, ReplyKeyboardRemove, Update
|
|
||||||
from telegram.error import Forbidden, NetworkError
|
|
||||||
from telegram.ext import CallbackContext, ConversationHandler, Job
|
|
||||||
|
|
||||||
from gram_core.dependence.redisdb import RedisDB
|
|
||||||
from gram_core.plugin._handler import conversation, handler
|
from gram_core.plugin._handler import conversation, handler
|
||||||
from utils.const import CACHE_DIR, REQUEST_HEADERS
|
from gram_core.plugin.methods import PluginFuncMethods
|
||||||
from utils.error import UrlResourcesNotFoundError
|
|
||||||
from utils.helpers import sha1
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from gram_core.application import Application
|
|
||||||
|
|
||||||
try:
|
|
||||||
import ujson as jsonlib
|
|
||||||
except ImportError:
|
|
||||||
import json as jsonlib
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"PluginFuncs",
|
"PluginFuncs",
|
||||||
@ -29,145 +10,8 @@ __all__ = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PluginFuncs:
|
class PluginFuncs(PluginFuncMethods):
|
||||||
_application: "Optional[Application]" = None
|
"""插件方法"""
|
||||||
|
|
||||||
def set_application(self, application: "Application") -> None:
|
|
||||||
self._application = application
|
|
||||||
|
|
||||||
@property
|
|
||||||
def application(self) -> "Application":
|
|
||||||
if self._application is None:
|
|
||||||
raise RuntimeError("No application was set for this PluginManager.")
|
|
||||||
return self._application
|
|
||||||
|
|
||||||
async def _delete_message(self, context: CallbackContext) -> None:
|
|
||||||
job = context.job
|
|
||||||
message_id = job.data
|
|
||||||
chat_info = f"chat_id[{job.chat_id}]"
|
|
||||||
|
|
||||||
try:
|
|
||||||
chat = await self.get_chat(job.chat_id)
|
|
||||||
full_name = chat.full_name
|
|
||||||
if full_name:
|
|
||||||
chat_info = f"{full_name}[{chat.id}]"
|
|
||||||
else:
|
|
||||||
chat_info = f"{chat.title}[{chat.id}]"
|
|
||||||
except (NetworkError, Forbidden) as exc:
|
|
||||||
logger.warning("获取 chat info 失败 %s", exc.message)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("获取 chat info 消息失败 %s", str(exc))
|
|
||||||
|
|
||||||
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
|
|
||||||
except NetworkError as exc:
|
|
||||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
|
||||||
except Forbidden as exc:
|
|
||||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc_info=exc)
|
|
||||||
|
|
||||||
async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, expire: int = 86400) -> Chat:
|
|
||||||
application = self.application
|
|
||||||
redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None)
|
|
||||||
|
|
||||||
if not redis_db:
|
|
||||||
return await application.bot.get_chat(chat_id)
|
|
||||||
|
|
||||||
qname = f"bot:chat:{chat_id}"
|
|
||||||
|
|
||||||
data = await redis_db.client.get(qname)
|
|
||||||
if data:
|
|
||||||
json_data = jsonlib.loads(data)
|
|
||||||
return Chat.de_json(json_data, application.telegram.bot)
|
|
||||||
|
|
||||||
chat_info = await application.telegram.bot.get_chat(chat_id)
|
|
||||||
await redis_db.client.set(qname, chat_info.to_json(), ex=expire)
|
|
||||||
return chat_info
|
|
||||||
|
|
||||||
def add_delete_message_job(
|
|
||||||
self,
|
|
||||||
message: Optional[Union[int, Message]] = None,
|
|
||||||
*,
|
|
||||||
delay: int = 60,
|
|
||||||
name: Optional[str] = None,
|
|
||||||
chat: Optional[Union[int, Chat]] = None,
|
|
||||||
context: Optional[CallbackContext] = None,
|
|
||||||
) -> Job:
|
|
||||||
"""延迟删除消息"""
|
|
||||||
|
|
||||||
if isinstance(message, Message):
|
|
||||||
if chat is None:
|
|
||||||
chat = message.chat_id
|
|
||||||
message = message.id
|
|
||||||
|
|
||||||
chat = chat.id if isinstance(chat, Chat) else chat
|
|
||||||
|
|
||||||
job_queue = self.application.job_queue or context.job_queue
|
|
||||||
|
|
||||||
if job_queue is None or chat is None:
|
|
||||||
raise RuntimeError
|
|
||||||
|
|
||||||
return job_queue.run_once(
|
|
||||||
callback=self._delete_message,
|
|
||||||
when=delay,
|
|
||||||
data=message,
|
|
||||||
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
|
|
||||||
chat_id=chat,
|
|
||||||
job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def download_resource(url: str, return_path: bool = False) -> str:
|
|
||||||
url_sha1 = sha1(url) # url 的 hash 值
|
|
||||||
pathed_url = Path(url)
|
|
||||||
|
|
||||||
file_name = url_sha1 + pathed_url.suffix
|
|
||||||
file_path = CACHE_DIR.joinpath(file_name)
|
|
||||||
|
|
||||||
if not file_path.exists(): # 若文件不存在,则下载
|
|
||||||
async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client:
|
|
||||||
try:
|
|
||||||
response = await client.get(url)
|
|
||||||
except UnsupportedProtocol:
|
|
||||||
logger.error("链接不支持 url[%s]", url)
|
|
||||||
return ""
|
|
||||||
|
|
||||||
if response.is_error:
|
|
||||||
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
|
|
||||||
raise UrlResourcesNotFoundError(url)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code)
|
|
||||||
raise UrlResourcesNotFoundError(url)
|
|
||||||
|
|
||||||
async with aiofiles.open(file_path, mode="wb") as f:
|
|
||||||
await f.write(response.content)
|
|
||||||
|
|
||||||
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
|
|
||||||
|
|
||||||
return file_path if return_path else Path(file_path).as_uri()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_args(context: CallbackContext) -> List[str]:
|
|
||||||
args = context.args
|
|
||||||
match = context.match
|
|
||||||
|
|
||||||
if args is None:
|
|
||||||
if match is not None and (command := match.groups()[0]):
|
|
||||||
temp = []
|
|
||||||
command_parts = command.split(" ")
|
|
||||||
for command_part in command_parts:
|
|
||||||
if command_part:
|
|
||||||
temp.append(command_part)
|
|
||||||
return temp
|
|
||||||
return []
|
|
||||||
if len(args) >= 1:
|
|
||||||
return args
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationFuncs:
|
class ConversationFuncs:
|
||||||
|
17
plugin/methods/__init__.py
Normal file
17
plugin/methods/__init__.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from .application import ApplicationMethod
|
||||||
|
from .delete_message import DeleteMessage
|
||||||
|
from .download_resource import DownloadResource
|
||||||
|
from .get_args import GetArgs
|
||||||
|
from .get_chat import GetChat
|
||||||
|
from .migrate_data import MigrateData
|
||||||
|
|
||||||
|
|
||||||
|
class PluginFuncMethods(
|
||||||
|
ApplicationMethod,
|
||||||
|
DeleteMessage,
|
||||||
|
DownloadResource,
|
||||||
|
GetArgs,
|
||||||
|
GetChat,
|
||||||
|
MigrateData,
|
||||||
|
):
|
||||||
|
"""插件方法"""
|
17
plugin/methods/application.py
Normal file
17
plugin/methods/application.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from gram_core.application import Application
|
||||||
|
|
||||||
|
|
||||||
|
class ApplicationMethod:
|
||||||
|
_application: "Optional[Application]" = None
|
||||||
|
|
||||||
|
def set_application(self, application: "Application") -> None:
|
||||||
|
self._application = application
|
||||||
|
|
||||||
|
@property
|
||||||
|
def application(self) -> "Application":
|
||||||
|
if self._application is None:
|
||||||
|
raise RuntimeError("No application was set for this PluginManager.")
|
||||||
|
return self._application
|
76
plugin/methods/delete_message.py
Normal file
76
plugin/methods/delete_message.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
from typing import Optional, Union, TYPE_CHECKING
|
||||||
|
|
||||||
|
from telegram import Chat, Message
|
||||||
|
from telegram.error import Forbidden, NetworkError
|
||||||
|
from telegram.ext import CallbackContext, Job
|
||||||
|
|
||||||
|
from utils.log import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import PluginFuncMethods
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteMessage:
|
||||||
|
async def _delete_message(self: "PluginFuncMethods", context: "CallbackContext") -> None:
|
||||||
|
job = context.job
|
||||||
|
message_id = job.data
|
||||||
|
chat_info = f"chat_id[{job.chat_id}]"
|
||||||
|
|
||||||
|
try:
|
||||||
|
chat = await self.get_chat(job.chat_id)
|
||||||
|
full_name = chat.full_name
|
||||||
|
if full_name:
|
||||||
|
chat_info = f"{full_name}[{chat.id}]"
|
||||||
|
else:
|
||||||
|
chat_info = f"{chat.title}[{chat.id}]"
|
||||||
|
except (NetworkError, Forbidden) as exc:
|
||||||
|
logger.warning("获取 chat info 失败 %s", exc.message)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("获取 chat info 消息失败 %s", str(exc))
|
||||||
|
|
||||||
|
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
|
||||||
|
except NetworkError as exc:
|
||||||
|
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||||
|
except Forbidden as exc:
|
||||||
|
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc_info=exc)
|
||||||
|
|
||||||
|
def add_delete_message_job(
|
||||||
|
self: "PluginFuncMethods",
|
||||||
|
message: Optional[Union[int, Message]] = None,
|
||||||
|
*,
|
||||||
|
delay: int = 60,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
chat: Optional[Union[int, Chat]] = None,
|
||||||
|
context: Optional[CallbackContext] = None,
|
||||||
|
) -> Job:
|
||||||
|
"""延迟删除消息"""
|
||||||
|
|
||||||
|
if isinstance(message, Message):
|
||||||
|
if chat is None:
|
||||||
|
chat = message.chat_id
|
||||||
|
message = message.id
|
||||||
|
|
||||||
|
chat = chat.id if isinstance(chat, Chat) else chat
|
||||||
|
|
||||||
|
job_queue = self.application.job_queue or context.job_queue
|
||||||
|
|
||||||
|
if job_queue is None or chat is None:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
return job_queue.run_once(
|
||||||
|
callback=self._delete_message,
|
||||||
|
when=delay,
|
||||||
|
data=message,
|
||||||
|
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
|
||||||
|
chat_id=chat,
|
||||||
|
job_kwargs={
|
||||||
|
"replace_existing": True,
|
||||||
|
"id": f"{chat}|{message}|delete_message",
|
||||||
|
},
|
||||||
|
)
|
47
plugin/methods/download_resource.py
Normal file
47
plugin/methods/download_resource.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
import httpx
|
||||||
|
from httpx import UnsupportedProtocol
|
||||||
|
|
||||||
|
from utils.const import CACHE_DIR, REQUEST_HEADERS
|
||||||
|
from utils.error import UrlResourcesNotFoundError
|
||||||
|
from utils.helpers import sha1
|
||||||
|
from utils.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadResource:
|
||||||
|
@staticmethod
|
||||||
|
async def download_resource(url: str, return_path: bool = False) -> str:
|
||||||
|
url_sha1 = sha1(url) # url 的 hash 值
|
||||||
|
pathed_url = Path(url)
|
||||||
|
|
||||||
|
file_name = url_sha1 + pathed_url.suffix
|
||||||
|
file_path = CACHE_DIR.joinpath(file_name)
|
||||||
|
|
||||||
|
if not file_path.exists(): # 若文件不存在,则下载
|
||||||
|
async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=10) as client:
|
||||||
|
try:
|
||||||
|
response = await client.get(url)
|
||||||
|
except UnsupportedProtocol:
|
||||||
|
logger.error("链接不支持 url[%s]", url)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if response.is_error:
|
||||||
|
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
|
||||||
|
raise UrlResourcesNotFoundError(url)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(
|
||||||
|
"download_resource 获取url[%s] 错误 status_code[%s]",
|
||||||
|
url,
|
||||||
|
response.status_code,
|
||||||
|
)
|
||||||
|
raise UrlResourcesNotFoundError(url)
|
||||||
|
|
||||||
|
async with aiofiles.open(file_path, mode="wb") as f:
|
||||||
|
await f.write(response.content)
|
||||||
|
|
||||||
|
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
|
||||||
|
|
||||||
|
return file_path if return_path else Path(file_path).as_uri()
|
24
plugin/methods/get_args.py
Normal file
24
plugin/methods/get_args.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from typing import List, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from telegram.ext import CallbackContext
|
||||||
|
|
||||||
|
|
||||||
|
class GetArgs:
|
||||||
|
@staticmethod
|
||||||
|
def get_args(context: "CallbackContext") -> List[str]:
|
||||||
|
args = context.args
|
||||||
|
match = context.match
|
||||||
|
|
||||||
|
if args is None:
|
||||||
|
if match is not None and (command := match.groups()[0]):
|
||||||
|
temp = []
|
||||||
|
command_parts = command.split(" ")
|
||||||
|
for command_part in command_parts:
|
||||||
|
if command_part:
|
||||||
|
temp.append(command_part)
|
||||||
|
return temp
|
||||||
|
return []
|
||||||
|
if len(args) >= 1:
|
||||||
|
return args
|
||||||
|
return []
|
38
plugin/methods/get_chat.py
Normal file
38
plugin/methods/get_chat.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from typing import Union, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from telegram import Chat
|
||||||
|
|
||||||
|
from gram_core.dependence.redisdb import RedisDB
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import PluginFuncMethods
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ujson as jsonlib
|
||||||
|
except ImportError:
|
||||||
|
import json as jsonlib
|
||||||
|
|
||||||
|
|
||||||
|
class GetChat:
|
||||||
|
async def get_chat(
|
||||||
|
self: "PluginFuncMethods",
|
||||||
|
chat_id: Union[str, int],
|
||||||
|
redis_db: Optional[RedisDB] = None,
|
||||||
|
expire: int = 86400,
|
||||||
|
) -> Chat:
|
||||||
|
application = self.application
|
||||||
|
redis_db: RedisDB = redis_db or self.application.managers.dependency_map.get(RedisDB, None)
|
||||||
|
|
||||||
|
if not redis_db:
|
||||||
|
return await application.bot.get_chat(chat_id)
|
||||||
|
|
||||||
|
qname = f"bot:chat:{chat_id}"
|
||||||
|
|
||||||
|
data = await redis_db.client.get(qname)
|
||||||
|
if data:
|
||||||
|
json_data = jsonlib.loads(data)
|
||||||
|
return Chat.de_json(json_data, application.telegram.bot)
|
||||||
|
|
||||||
|
chat_info = await application.telegram.bot.get_chat(chat_id)
|
||||||
|
await redis_db.client.set(qname, chat_info.to_json(), ex=expire)
|
||||||
|
return chat_info
|
58
plugin/methods/migrate_data.py
Normal file
58
plugin/methods/migrate_data.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional, TypeVar, List, Any, Tuple, Type, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from gram_core.services.players.models import PlayersDataBase as Player
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class MigrateDataException(Exception):
|
||||||
|
"""迁移数据异常"""
|
||||||
|
|
||||||
|
def __init__(self, msg: str):
|
||||||
|
self.msg = msg
|
||||||
|
|
||||||
|
|
||||||
|
class IMigrateData(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def migrate_data_msg(self) -> str:
|
||||||
|
"""返回迁移数据的提示信息"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def migrate_data(self) -> bool:
|
||||||
|
"""迁移数据"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_sql_data_by_key(model: T, keys: Tuple[Any, ...]) -> tuple[Any, ...]:
|
||||||
|
"""通过 key 获取数据"""
|
||||||
|
data = []
|
||||||
|
for i in keys:
|
||||||
|
data.append(getattr(model, i.key))
|
||||||
|
return tuple(data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def filter_sql_data(
|
||||||
|
model: Type[T], service_method, old_user_id: int, new_user_id: int, keys: Tuple[Any, ...]
|
||||||
|
) -> Tuple[List[T], List[T]]:
|
||||||
|
"""过滤数据库数据"""
|
||||||
|
data: List[model] = await service_method(old_user_id)
|
||||||
|
if not data:
|
||||||
|
return [], []
|
||||||
|
new_data = await service_method(new_user_id)
|
||||||
|
new_data_index = [IMigrateData.get_sql_data_by_key(p, keys) for p in new_data]
|
||||||
|
need_migrate = []
|
||||||
|
for d in data:
|
||||||
|
if IMigrateData.get_sql_data_by_key(d, keys) not in new_data_index:
|
||||||
|
need_migrate.append(d)
|
||||||
|
return need_migrate, new_data
|
||||||
|
|
||||||
|
|
||||||
|
class MigrateData:
|
||||||
|
async def get_migrate_data(
|
||||||
|
self, old_user_id: int, new_user_id: int, old_players: List["Player"]
|
||||||
|
) -> Optional[IMigrateData]:
|
||||||
|
"""获取迁移数据"""
|
||||||
|
if not (old_user_id and new_user_id and old_players):
|
||||||
|
return None
|
||||||
|
return None
|
5
pyproject.toml
Normal file
5
pyproject.toml
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# 格式配置
|
||||||
|
[tool.black]
|
||||||
|
include = '\.pyi?$'
|
||||||
|
line-length = 120
|
||||||
|
target-version = ['py311']
|
@ -108,3 +108,10 @@ class PlayerInfoRepository(BaseService.Component):
|
|||||||
session.add(player)
|
session.add(player)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(player)
|
await session.refresh(player)
|
||||||
|
|
||||||
|
async def get_all_by_user_id(self, user_id: int) -> List[PlayerInfoSQLModel]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = select(PlayerInfoSQLModel).where(PlayerInfoSQLModel.user_id == user_id)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
players = results.all()
|
||||||
|
return players
|
||||||
|
@ -48,3 +48,9 @@ class TaskRepository(BaseService.Component):
|
|||||||
query = select(Task).where(Task.type == task_type)
|
query = select(Task).where(Task.type == task_type)
|
||||||
results = await session.exec(query)
|
results = await session.exec(query)
|
||||||
return results.all()
|
return results.all()
|
||||||
|
|
||||||
|
async def get_all_by_user_id(self, user_id: int) -> List[Task]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
query = select(Task).where(Task.user_id == user_id)
|
||||||
|
results = await session.exec(query)
|
||||||
|
return results.all()
|
||||||
|
@ -37,6 +37,9 @@ class TaskServices:
|
|||||||
async def get_all(self):
|
async def get_all(self):
|
||||||
return await self._repository.get_all(self.TASK_TYPE)
|
return await self._repository.get_all(self.TASK_TYPE)
|
||||||
|
|
||||||
|
async def get_all_by_user_id(self, user_id: int):
|
||||||
|
return await self._repository.get_all_by_user_id(user_id)
|
||||||
|
|
||||||
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
|
||||||
return Task(
|
return Task(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
Loading…
Reference in New Issue
Block a user