From a714fa3ebdcbbba94b77359dcf6711048162f701 Mon Sep 17 00:00:00 2001 From: xtaodada Date: Fri, 5 Aug 2022 22:45:54 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20=E6=89=B9=E9=87=8F=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E6=8F=92=E4=BB=B6=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/cookies/repositories.py | 33 +++++++++++++-------------------- apps/user/repositories.py | 6 ++++-- plugins/abyss.py | 2 +- plugins/adduser.py | 14 +++++++++----- plugins/daily_note.py | 5 +++-- plugins/ledger.py | 1 + utils/helpers.py | 2 +- 7 files changed, 32 insertions(+), 31 deletions(-) diff --git a/apps/cookies/repositories.py b/apps/cookies/repositories.py index c26dbcb..4505892 100644 --- a/apps/cookies/repositories.py +++ b/apps/cookies/repositories.py @@ -17,12 +17,12 @@ class CookiesRepository: async with self.mysql.Session() as session: session = cast(AsyncSession, session) if region == RegionEnum.HYPERION: - db_data = HyperionCookie(user_id=user_id, cookie=cookies) + db_data = HyperionCookie(user_id=user_id, cookies=cookies) elif region == RegionEnum.HOYOLAB: - db_data = HoyolabCookie(user_id=user_id, cookie=cookies) + db_data = HoyolabCookie(user_id=user_id, cookies=cookies) else: raise RegionNotFoundError(region.name) - await session.add(db_data) + session.add(db_data) await session.commit() async def update_cookies(self, user_id: int, cookies: dict, region: RegionEnum): @@ -30,26 +30,19 @@ class CookiesRepository: session = cast(AsyncSession, session) if region == RegionEnum.HYPERION: statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id) - results = await session.exec(statement) - db_cookies = results.one()[0] - if db_cookies is None: - raise CookiesNotFoundError(user_id) - db_cookies.cookies = cookies - session.add(db_cookies) - await session.commit() - await session.refresh(db_cookies) elif region == RegionEnum.HOYOLAB: - statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id) - results = await session.add(statement) - db_cookies = results.one()[0] - if db_cookies is None: - raise CookiesNotFoundError(user_id) - db_cookies.cookie = cookies - session.add(db_cookies) - await session.commit() - await session.refresh(db_cookies) + statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id) else: raise RegionNotFoundError(region.name) + results = await session.exec(statement) + db_cookies = results.one() + if db_cookies is None: + raise CookiesNotFoundError(user_id) + db_cookies = db_cookies[0] + db_cookies.cookies = cookies + session.add(db_cookies) + await session.commit() + await session.refresh(db_cookies) async def update_cookies_ex(self, cookies: Cookies, region: RegionEnum): async with self.mysql.Session() as session: diff --git a/apps/user/repositories.py b/apps/user/repositories.py index 3bfc880..728ca55 100644 --- a/apps/user/repositories.py +++ b/apps/user/repositories.py @@ -17,8 +17,10 @@ class UserRepository: session = cast(AsyncSession, session) statement = select(User).where(User.user_id == user_id) results = await session.exec(statement) - user = results.first() - return user[0] + if user := results.first(): + return user[0] + else: + raise UserNotFoundError(user_id) async def update_user(self, user: User): async with self.mysql.Session() as session: diff --git a/plugins/abyss.py b/plugins/abyss.py index 0d9f3c7..73bda50 100644 --- a/plugins/abyss.py +++ b/plugins/abyss.py @@ -90,7 +90,7 @@ class Abyss(BasePlugins): abyss_data["most_played_list"].append(temp) return abyss_data - @restricts + @restricts() @error_callable async def command_start(self, update: Update, context: CallbackContext) -> None: user = update.effective_user diff --git a/plugins/adduser.py b/plugins/adduser.py index 4988b69..a176534 100644 --- a/plugins/adduser.py +++ b/plugins/adduser.py @@ -9,6 +9,7 @@ from telegram.helpers import escape_markdown from apps.cookies.services import CookiesService from apps.user.models import User +from apps.user.repositories import UserNotFoundError from apps.user.services import UserService from logger import Log from models.base import RegionEnum @@ -77,7 +78,10 @@ class AddUser(BasePlugins): async def check_server(self, update: Update, context: CallbackContext) -> int: user = update.effective_user add_user_command_data: AddUserCommandData = context.chat_data.get("add_user_command_data") - user_info = await self.user_service.get_user_by_id(user.id) + try: + user_info = await self.user_service.get_user_by_id(user.id) + except UserNotFoundError: + user_info = None add_user_command_data.user = user_info if update.message.text == "退出": await update.message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) @@ -188,11 +192,11 @@ class AddUser(BasePlugins): elif update.message.text == "确认": if add_user_command_data.user is None: if add_user_command_data.region == RegionEnum.HYPERION: - user_db = User(user_id=user.id, yuanshen_id=add_user_command_data.game_uid, - region=add_user_command_data) + user_db = User(user_id=user.id, yuanshen_uid=add_user_command_data.game_uid, + region=add_user_command_data.region) elif add_user_command_data.region == RegionEnum.HOYOLAB: - user_db = User(user_id=user.id, genshin_id=add_user_command_data.game_uid, - region=add_user_command_data) + user_db = User(user_id=user.id, genshin_uid=add_user_command_data.game_uid, + region=add_user_command_data.region) else: await update.message.reply_text("数据错误") return ConversationHandler.END diff --git a/plugins/daily_note.py b/plugins/daily_note.py index 0f377e7..d781ce3 100644 --- a/plugins/daily_note.py +++ b/plugins/daily_note.py @@ -1,5 +1,6 @@ import datetime import os +from typing import Optional from genshin import DataNotPublic from telegram import Update @@ -88,9 +89,9 @@ class DailyNote(BasePlugins): {"width": 600, "height": 548}, full_page=False) return png_data - @restricts + @restricts() @error_callable - async def command_start(self, update: Update, context: CallbackContext) -> None: + async def command_start(self, update: Update, context: CallbackContext) -> Optional[int]: user = update.effective_user message = update.message Log.info(f"用户 {user.full_name}[{user.id}] 查询游戏状态命令请求") diff --git a/plugins/ledger.py b/plugins/ledger.py index b451368..330da2d 100644 --- a/plugins/ledger.py +++ b/plugins/ledger.py @@ -27,6 +27,7 @@ def check_ledger_month(context: CallbackContext) -> int: args = get_all_args(context) if len(args) >= 1: month = args[0] + elif isinstance(month, int): pass elif re_data := re.findall(r"\d+", month): month = int(re_data[0]) diff --git a/utils/helpers.py b/utils/helpers.py index 2f401e2..77a243f 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -74,7 +74,7 @@ async def get_genshin_client(user_id: int, user_service: UserService, cookies_se client = genshin.Client(cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn", uid=uid) else: - raise TypeError(f"region is not RegionEnum.NULL") + raise TypeError("region is not RegionEnum.NULL") return client