From 9e7637203ac915f3ff5a6dd598688d5be4e9a016 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B4=9B=E6=B0=B4=E5=B1=85=E5=AE=A4?= Date: Thu, 4 Aug 2022 21:18:23 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=20=E9=87=8D=E6=9E=84=20`mysql`=20?= =?UTF-8?q?=E9=80=9A=E4=BF=A1=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 移除 `aiomysql` 依赖 添加 `SQLAlchemy` `sqlmodel` `asyncmy` 依赖 --- app/admin/models.py | 8 ++ app/admin/repositories.py | 45 +++++----- app/admin/service.py | 6 +- app/cookies/models.py | 25 ++++++ app/cookies/repositories.py | 161 +++++++++++++++++++++--------------- app/cookies/service.py | 15 ++-- app/quiz/repositories.py | 66 ++------------- app/user/models.py | 18 ++-- app/user/repositories.py | 26 +++--- app/user/services.py | 9 +- requirements.txt | 5 +- utils/mysql.py | 65 +++------------ 12 files changed, 210 insertions(+), 239 deletions(-) create mode 100644 app/admin/models.py create mode 100644 app/cookies/models.py diff --git a/app/admin/models.py b/app/admin/models.py new file mode 100644 index 00000000..61bd364c --- /dev/null +++ b/app/admin/models.py @@ -0,0 +1,8 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + + +class Admin(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + user_id: int = Field() \ No newline at end of file diff --git a/app/admin/repositories.py b/app/admin/repositories.py index ee968846..03de648f 100644 --- a/app/admin/repositories.py +++ b/app/admin/repositories.py @@ -1,6 +1,10 @@ -from typing import List +from typing import List, cast + +from sqlalchemy import select +from sqlmodel.ext.asyncio.session import AsyncSession from utils.mysql import MySQL +from .models import Admin class BotAdminRepository: @@ -8,30 +12,21 @@ class BotAdminRepository: self.mysql = mysql async def delete_by_user_id(self, user_id: int): - query = """ - DELETE FROM `admin` - WHERE user_id=%s; - """ - query_args = (user_id,) - await self.mysql.execute_and_fetchall(query, query_args) + async with self.mysql.Session() as session: + session = cast(AsyncSession, session) + statement = select(Admin).where(Admin.user_id == user_id) + results = await session.exec(statement) + admin = results.one() + await session.delete(admin) async def add_by_user_id(self, user_id: int): - query = """ - INSERT INTO `admin` - (user_id) - VALUES - (%s) - """ - query_args = (user_id,) - await self.mysql.execute_and_fetchall(query, query_args) + async with self.mysql.Session() as session: + admin = Admin(user_id=user_id) + await session.add(admin) - async def get_by_user_id(self) -> List[int]: - query = """ - SELECT user_id - FROM `admin` - """ - query_args = () - data = await self.mysql.execute_and_fetchall(query, query_args) - if len(data) == 0: - return [] - return [i[0] for i in data] \ No newline at end of file + async def get_all_user_id(self) -> List[int]: + async with self.mysql.Session() as session: + query = select(Admin) + results = await session.exec(query) + admins = results.all() + return [admin[0].user_id for admin in admins] diff --git a/app/admin/service.py b/app/admin/service.py index 02236d03..438a5a37 100644 --- a/app/admin/service.py +++ b/app/admin/service.py @@ -17,7 +17,7 @@ class BotAdminService: async def get_admin_list(self) -> List[int]: admin_list = await self._cache.get_list() if len(admin_list) == 0: - admin_list = await self._repository.get_by_user_id() + admin_list = await self._repository.get_all_user_id() for config_admin in config.ADMINISTRATORS: admin_list.append(config_admin["user_id"]) await self._cache.set_list(admin_list) @@ -28,7 +28,7 @@ class BotAdminService: await self._repository.add_by_user_id(user_id) except IntegrityError as error: Log.warning(f"{user_id} 已经存在数据库 \n", error) - admin_list = await self._repository.get_by_user_id() + admin_list = await self._repository.get_all_user_id() for config_admin in config.ADMINISTRATORS: admin_list.append(config_admin["user_id"]) await self._cache.set_list(admin_list) @@ -39,7 +39,7 @@ class BotAdminService: await self._repository.delete_by_user_id(user_id) except ValueError: return False - admin_list = await self._repository.get_by_user_id() + admin_list = await self._repository.get_all_user_id() for config_admin in config.ADMINISTRATORS: admin_list.append(config_admin["user_id"]) await self._cache.set_list(admin_list) diff --git a/app/cookies/models.py b/app/cookies/models.py new file mode 100644 index 00000000..0b9f637a --- /dev/null +++ b/app/cookies/models.py @@ -0,0 +1,25 @@ +import enum +from typing import Optional, Dict + +from sqlmodel import SQLModel, Field, JSON, Enum, Column + + +class CookiesStatusEnum(int, enum.Enum): + STATUS_SUCCESS = 0 + INVALID_COOKIES = 1 + TOO_MANY_REQUESTS = 2 + + +class Cookies(SQLModel): + id: Optional[int] = Field(default=None, primary_key=True) + user_id: Optional[int] = Field() + cookies: Optional[Dict[str, str]] = Field(sa_column=Column(JSON)) + status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum))) + + +class HyperionCookie(Cookies, table=True): + __tablename__ = 'mihoyo_cookies' + + +class HoyolabCookie(Cookies, table=True): + __tablename__ = 'hoyoverse_cookies' diff --git a/app/cookies/repositories.py b/app/cookies/repositories.py index dabe117f..3a474f8c 100644 --- a/app/cookies/repositories.py +++ b/app/cookies/repositories.py @@ -1,79 +1,106 @@ -import ujson +from typing import cast, List -from model.base import ServiceEnum -from utils.error import NotFoundError +from sqlalchemy import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from model.base import RegionEnum +from utils.error import NotFoundError, RegionNotFoundError from utils.mysql import MySQL +from .models import HyperionCookie, HoyolabCookie, Cookies class CookiesRepository: def __init__(self, mysql: MySQL): self.mysql = mysql - async def update_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum): - if default_service == ServiceEnum.HYPERION: - query = """ - UPDATE `mihoyo_cookie` - SET cookie=%s - WHERE user_id=%s; - """ - elif default_service == ServiceEnum.HOYOLAB: - query = """ - UPDATE `hoyoverse_cookie` - SET cookie=%s - WHERE user_id=%s; - """ - else: - raise DefaultServiceNotFoundError(default_service.name) - query_args = (cookies, user_id) - await self.mysql.execute_and_fetchall(query, query_args) + async def add_cookies(self, user_id: int, cookies: str, region: RegionEnum): + async with self.mysql.Session() as session: + session = cast(AsyncSession, session) + if region == RegionEnum.HYPERION: + db_data = HyperionCookie(user_id=user_id, cookie=cookies) + elif region == RegionEnum.HOYOLAB: + db_data = HoyolabCookie(user_id=user_id, cookie=cookies) + else: + raise RegionNotFoundError(region.name) + await session.add(db_data) - async def set_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum): - if default_service == ServiceEnum.HYPERION: - query = """ - INSERT INTO `mihoyo_cookie` - (user_id,cookie) - VALUES - (%s,%s) - ON DUPLICATE KEY UPDATE - cookie=VALUES(cookie); - """ - elif default_service == ServiceEnum.HOYOLAB: - query = """ - INSERT INTO `hoyoverse_cookie` - (user_id,cookie) - VALUES - (%s,%s) - ON DUPLICATE KEY UPDATE - cookie=VALUES(cookie); - """ - else: - raise DefaultServiceNotFoundError(default_service.name) - query_args = (user_id, cookies) - await self.mysql.execute_and_fetchall(query, query_args) + async def update_cookies(self, user_id: int, cookies: str, region: RegionEnum): + async with self.mysql.Session() as session: + session = cast(AsyncSession, session) + if region == RegionEnum.HYPERION: + statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id) + results = await session.exec(statement) + db_cookies = results.one() + if db_cookies is None: + raise CookiesNotFoundError(user_id) + db_cookies.cookie = cookies + session.add(db_cookies) + await session.commit() + await session.refresh(db_cookies) + elif region == RegionEnum.HOYOLAB: + statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id) + results = await session.add(statement) + db_cookies = results.one() + if db_cookies is None: + raise CookiesNotFoundError(user_id) + db_cookies.cookie = cookies + session.add(db_cookies) + await session.commit() + await session.refresh(db_cookies) + else: + raise RegionNotFoundError(region.name) - async def read_cookies(self, user_id, default_service: ServiceEnum) -> dict: - if default_service == ServiceEnum.HYPERION: - query = """ - SELECT cookie - FROM `mihoyo_cookie` - WHERE user_id=%s; - """ - elif default_service == ServiceEnum.HOYOLAB: - query = """ - SELECT cookie - FROM `hoyoverse_cookie` - WHERE user_id=%s;; - """ - else: - raise DefaultServiceNotFoundError(default_service.name) - query_args = (user_id,) - data = await self.mysql.execute_and_fetchall(query, query_args) - if len(data) == 0: - return {} - (cookies,) = data - return ujson.loads(cookies) + async def update_cookies_ex(self, cookies: Cookies, region: RegionEnum): + async with self.mysql.Session() as session: + session = cast(AsyncSession, session) + if region == RegionEnum.HYPERION: + session.add(cookies) + await session.commit() + await session.refresh(cookies) + elif region == RegionEnum.HOYOLAB: + await session.add(cookies) + await session.commit() + await session.refresh(cookies) + else: + raise RegionNotFoundError(region.name) + + async def get_cookies(self, user_id, region: RegionEnum) -> Cookies: + async with self.mysql.Session() as session: + session = cast(AsyncSession, session) + if region == RegionEnum.HYPERION: + statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id) + results = await session.exec(statement) + db_cookies = results.first() + if db_cookies is None: + raise CookiesNotFoundError(user_id) + return db_cookies[0] + elif region == RegionEnum.HOYOLAB: + statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id) + results = await session.exec(statement) + db_cookies = results.first() + if db_cookies is None: + raise CookiesNotFoundError(user_id) + return db_cookies[0] + else: + raise RegionNotFoundError(region.name) + + async def get_all_cookies(self, region: RegionEnum) -> List[Cookies]: + async with self.mysql.Session() as session: + session = cast(AsyncSession, session) + if region == RegionEnum.HYPERION: + statement = select(HyperionCookie) + results = await session.exec(statement) + db_cookies = results.all() + return [cookies[0] for cookies in db_cookies] + elif region == RegionEnum.HOYOLAB: + statement = select(HoyolabCookie) + results = await session.exec(statement) + db_cookies = results.all() + return [cookies[0] for cookies in db_cookies] + else: + raise RegionNotFoundError(region.name) -class DefaultServiceNotFoundError(NotFoundError): - entity_name: str = "ServiceEnum" - entity_value_name: str = "default_service" +class CookiesNotFoundError(NotFoundError): + entity_name: str = "CookiesRepository" + entity_value_name: str = "user_id" diff --git a/app/cookies/service.py b/app/cookies/service.py index 80e48917..6783d736 100644 --- a/app/cookies/service.py +++ b/app/cookies/service.py @@ -3,14 +3,17 @@ from model.base import ServiceEnum class CookiesService: - def __init__(self, user_repository: CookiesRepository) -> None: - self._repository: CookiesRepository = user_repository + def __init__(self, cookies_repository: CookiesRepository) -> None: + self._repository: CookiesRepository = cookies_repository - async def update_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum): - await self._repository.update_cookie(user_id, cookies, default_service) + async def update_cookies(self, user_id: int, cookies: str, region: RegionEnum): + await self._repository.update_cookies(user_id, cookies, region) - async def set_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum): - await self._repository.set_cookie(user_id, cookies, default_service) + async def add_cookies(self, user_id: int, cookies: str, region: RegionEnum): + await self._repository.add_cookies(user_id, cookies, region) + + async def get_cookies(self, user_id: int, region: RegionEnum): + return await self._repository.get_cookies(user_id, region) async def read_cookies(self, user_id: int, default_service: ServiceEnum): return await self._repository.read_cookies(user_id, default_service) diff --git a/app/quiz/repositories.py b/app/quiz/repositories.py index 870e708e..5b793294 100644 --- a/app/quiz/repositories.py +++ b/app/quiz/repositories.py @@ -1,6 +1,5 @@ from typing import List -from app.quiz.base import CreatQuestionFromSQLData, CreatAnswerFromSQLData from app.quiz.models import Question, Answer from utils.mysql import MySQL @@ -10,74 +9,25 @@ class QuizRepository: self.mysql = mysql async def get_question_list(self) -> List[Question]: - query = """ - SELECT id,question - FROM `question` - """ - query_args = () - data = await self.mysql.execute_and_fetchall(query, query_args) - return CreatQuestionFromSQLData(data) + pass async def get_answer_form_question_id(self, question_id: int) -> List[Answer]: - query = """ - SELECT id,question_id,is_correct,answer - FROM `answer` - WHERE question_id=%s; - """ - query_args = (question_id,) - data = await self.mysql.execute_and_fetchall(query, query_args) - return CreatAnswerFromSQLData(data) + pass async def add_question(self, question: str): - query = """ - INSERT INTO `question` - (question) - VALUES - (%s) - """ - query_args = (question,) - await self.mysql.execute_and_fetchall(query, query_args) + pass async def get_question(self, question: str) -> Question: - query = """ - SELECT id,question - FROM `question` - WHERE question=%s; - """ - query_args = (question,) - data = await self.mysql.execute_and_fetchall(query, query_args) - return CreatQuestionFromSQLData(data)[0] + pass async def add_answer(self, question_id: int, is_correct: int, answer: str): - query = """ - INSERT INTO `answer` - (question_id,is_correct,answer) - VALUES - (%s,%s,%s) - """ - query_args = (question_id, is_correct, answer) - await self.mysql.execute_and_fetchall(query, query_args) + pass async def delete_question(self, question_id: int): - query = """ - DELETE FROM `question` - WHERE id=%s; - """ - query_args = (question_id,) - await self.mysql.execute_and_fetchall(query, query_args) + pass async def delete_answer(self, answer_id: int): - query = """ - DELETE FROM `answer` - WHERE id=%s; - """ - query_args = (answer_id,) - await self.mysql.execute_and_fetchall(query, query_args) + pass async def delete_admin(self, user_id: int): - query = """ - DELETE FROM `admin` - WHERE user_id=%s; - """ - query_args = (user_id,) - await self.mysql.execute_and_fetchall(query, query_args) \ No newline at end of file + pass \ No newline at end of file diff --git a/app/user/models.py b/app/user/models.py index 2214d763..d43bffda 100644 --- a/app/user/models.py +++ b/app/user/models.py @@ -1,11 +1,11 @@ -from model.base import RegionEnum -from model.baseobject import BaseObject +from typing import Optional + +from sqlmodel import SQLModel, Field -class User(BaseObject): - def __init__(self, user_id: int = 0, yuanshen_game_uid: int = 0, genshin_game_uid: int = 0, - region: RegionEnum = RegionEnum.NULL): - self.user_id = user_id - self.yuanshen_game_uid = yuanshen_game_uid - self.genshin_game_uid = genshin_game_uid - self.region = region +class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + user_id: int = Field() + yuanshen_uid: int = Field() + genshin_uid: int = Field() + region: int = Field() diff --git a/app/user/repositories.py b/app/user/repositories.py index 95812b91..dacd79ff 100644 --- a/app/user/repositories.py +++ b/app/user/repositories.py @@ -1,7 +1,11 @@ -from app.user.models import User -from model.base import ServiceEnum +from typing import cast + +from sqlalchemy import select +from sqlmodel.ext.asyncio.session import AsyncSession + from utils.error import NotFoundError from utils.mysql import MySQL +from .models import User class UserRepository: @@ -9,18 +13,14 @@ class UserRepository: self.mysql = mysql async def get_by_user_id(self, user_id: int) -> User: - query = """ - SELECT user_id,mihoyo_game_uid,hoyoverse_game_uid,service - FROM `user` - WHERE user_id=%s;""" - query_args = (user_id,) - data = await self.mysql.execute_and_fetchall(query, query_args) - if len(data) == 0: - raise UserNotFoundError(user_id) - (user_id, yuanshen_game_uid, genshin_game_uid, default_service) = data - return User(user_id, yuanshen_game_uid, genshin_game_uid, ServiceEnum(default_service)) + async with self.mysql.Session() as session: + session = cast(AsyncSession, session) + statement = select(User).where(User.user_id == user_id) + results = await session.exec(statement) + user = results.first() + return user class UserNotFoundError(NotFoundError): entity_name: str = "User" - entity_value_name: str = "id" + entity_value_name: str = "user_id" diff --git a/app/user/services.py b/app/user/services.py index 215b9db5..f8bb200d 100644 --- a/app/user/services.py +++ b/app/user/services.py @@ -1,5 +1,5 @@ -from app.user.models import User -from app.user.repositories import UserRepository +from .models import User +from .repositories import UserRepository class UserService: @@ -10,6 +10,7 @@ class UserService: async def get_user_by_id(self, user_id: int) -> User: """从数据库获取用户信息 :param user_id:用户ID - :return: + :return: User """ - return await self._repository.get_by_user_id(user_id) \ No newline at end of file + user = await self._repository.get_by_user_id(user_id) + return user diff --git a/requirements.txt b/requirements.txt index c508318f..fd27969d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,7 @@ fakeredis>=1.8.1 aiohttp<=3.8.1 python-telegram-bot==20.0a2 pytz>=2021.3 -Pillow>=9.0.1 \ No newline at end of file +Pillow>=9.0.1 +SQLAlchemy>=1.4.39 +sqlmodel>=0.0.6 +asyncmy>=0.2.5 \ No newline at end of file diff --git a/utils/mysql.py b/utils/mysql.py index 3d56babb..ddfef2bc 100644 --- a/utils/mysql.py +++ b/utils/mysql.py @@ -1,71 +1,30 @@ -import asyncio - -import aiomysql +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlmodel.ext.asyncio.session import AsyncSession from logger import Log class MySQL: def __init__(self, host: str = "127.0.0.1", port: int = 3306, user: str = "root", - password: str = "", database: str = "", loop=None): + password: str = "", database: str = ""): self.database = database self.password = password self.user = user self.port = port self.host = host - self._loop = loop - self._sql_pool = None Log.debug(f'获取数据库配置 [host]: {self.host}') Log.debug(f'获取数据库配置 [port]: {self.port}') Log.debug(f'获取数据库配置 [user]: {self.user}') Log.debug(f'获取数据库配置 [password][len]: {len(self.password)}') Log.debug(f'获取数据库配置 [db]: {self.database}') - if self._loop is None: - self._loop = asyncio.get_event_loop() - try: - Log.info("正在创建数据库LOOP") - self._loop.run_until_complete(self.create_pool()) - Log.info("创建数据库LOOP成功") - except (KeyboardInterrupt, SystemExit): - pass - except Exception as exc: - Log.error("创建数据库LOOP发生严重错误") - raise exc + self.engine = create_async_engine(f"mysql+asyncmy://{user}:{password}@{host}:{port}/{database}") + self.Session = sessionmaker(bind=self.engine, class_=AsyncSession) + + async def get_session(self): + """获取会话""" + async with self.Session() as session: + yield session async def wait_closed(self): - if self._sql_pool is None: - return - pool = self._sql_pool - pool.close() - await pool.wait_closed() - - async def create_pool(self): - self._sql_pool = await aiomysql.create_pool( - host=self.host, port=self.port, - user=self.user, password=self.password, - db=self.database, loop=self._loop) - - async def _get_pool(self): - if self._sql_pool is None: - raise RuntimeError("mysql pool is none") - return self._sql_pool - - async def executemany(self, query, query_args): - pool = await self._get_pool() - async with pool.acquire() as conn: - sql_cur = await conn.cursor() - await sql_cur.executemany(query, query_args) - rowcount = sql_cur.rowcount - await sql_cur.close() - await conn.commit() - return rowcount - - async def execute_and_fetchall(self, query, query_args): - pool = await self._get_pool() - async with pool.acquire() as conn: - sql_cur = await conn.cursor() - await sql_cur.execute(query, query_args) - result = await sql_cur.fetchall() - await sql_cur.close() - await conn.commit() - return result + pass