♻ 重构 mysql 通信模块

移除 `aiomysql` 依赖
添加 `SQLAlchemy` `sqlmodel` `asyncmy` 依赖
This commit is contained in:
洛水居室 2022-08-04 21:18:23 +08:00
parent 7c90b27934
commit 9e7637203a
No known key found for this signature in database
GPG Key ID: C9DE87DA724B88FC
12 changed files with 210 additions and 239 deletions

8
app/admin/models.py Normal file
View File

@ -0,0 +1,8 @@
from typing import Optional
from sqlmodel import SQLModel, Field
class Admin(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
user_id: int = Field()

View File

@ -1,6 +1,10 @@
from typing import List from typing import List, cast
from sqlalchemy import select
from sqlmodel.ext.asyncio.session import AsyncSession
from utils.mysql import MySQL from utils.mysql import MySQL
from .models import Admin
class BotAdminRepository: class BotAdminRepository:
@ -8,30 +12,21 @@ class BotAdminRepository:
self.mysql = mysql self.mysql = mysql
async def delete_by_user_id(self, user_id: int): async def delete_by_user_id(self, user_id: int):
query = """ async with self.mysql.Session() as session:
DELETE FROM `admin` session = cast(AsyncSession, session)
WHERE user_id=%s; statement = select(Admin).where(Admin.user_id == user_id)
""" results = await session.exec(statement)
query_args = (user_id,) admin = results.one()
await self.mysql.execute_and_fetchall(query, query_args) await session.delete(admin)
async def add_by_user_id(self, user_id: int): async def add_by_user_id(self, user_id: int):
query = """ async with self.mysql.Session() as session:
INSERT INTO `admin` admin = Admin(user_id=user_id)
(user_id) await session.add(admin)
VALUES
(%s)
"""
query_args = (user_id,)
await self.mysql.execute_and_fetchall(query, query_args)
async def get_by_user_id(self) -> List[int]: async def get_all_user_id(self) -> List[int]:
query = """ async with self.mysql.Session() as session:
SELECT user_id query = select(Admin)
FROM `admin` results = await session.exec(query)
""" admins = results.all()
query_args = () return [admin[0].user_id for admin in admins]
data = await self.mysql.execute_and_fetchall(query, query_args)
if len(data) == 0:
return []
return [i[0] for i in data]

View File

@ -17,7 +17,7 @@ class BotAdminService:
async def get_admin_list(self) -> List[int]: async def get_admin_list(self) -> List[int]:
admin_list = await self._cache.get_list() admin_list = await self._cache.get_list()
if len(admin_list) == 0: if len(admin_list) == 0:
admin_list = await self._repository.get_by_user_id() admin_list = await self._repository.get_all_user_id()
for config_admin in config.ADMINISTRATORS: for config_admin in config.ADMINISTRATORS:
admin_list.append(config_admin["user_id"]) admin_list.append(config_admin["user_id"])
await self._cache.set_list(admin_list) await self._cache.set_list(admin_list)
@ -28,7 +28,7 @@ class BotAdminService:
await self._repository.add_by_user_id(user_id) await self._repository.add_by_user_id(user_id)
except IntegrityError as error: except IntegrityError as error:
Log.warning(f"{user_id} 已经存在数据库 \n", error) Log.warning(f"{user_id} 已经存在数据库 \n", error)
admin_list = await self._repository.get_by_user_id() admin_list = await self._repository.get_all_user_id()
for config_admin in config.ADMINISTRATORS: for config_admin in config.ADMINISTRATORS:
admin_list.append(config_admin["user_id"]) admin_list.append(config_admin["user_id"])
await self._cache.set_list(admin_list) await self._cache.set_list(admin_list)
@ -39,7 +39,7 @@ class BotAdminService:
await self._repository.delete_by_user_id(user_id) await self._repository.delete_by_user_id(user_id)
except ValueError: except ValueError:
return False return False
admin_list = await self._repository.get_by_user_id() admin_list = await self._repository.get_all_user_id()
for config_admin in config.ADMINISTRATORS: for config_admin in config.ADMINISTRATORS:
admin_list.append(config_admin["user_id"]) admin_list.append(config_admin["user_id"])
await self._cache.set_list(admin_list) await self._cache.set_list(admin_list)

25
app/cookies/models.py Normal file
View File

@ -0,0 +1,25 @@
import enum
from typing import Optional, Dict
from sqlmodel import SQLModel, Field, JSON, Enum, Column
class CookiesStatusEnum(int, enum.Enum):
STATUS_SUCCESS = 0
INVALID_COOKIES = 1
TOO_MANY_REQUESTS = 2
class Cookies(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)
user_id: Optional[int] = Field()
cookies: Optional[Dict[str, str]] = Field(sa_column=Column(JSON))
status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum)))
class HyperionCookie(Cookies, table=True):
__tablename__ = 'mihoyo_cookies'
class HoyolabCookie(Cookies, table=True):
__tablename__ = 'hoyoverse_cookies'

View File

@ -1,79 +1,106 @@
import ujson from typing import cast, List
from model.base import ServiceEnum from sqlalchemy import select
from utils.error import NotFoundError from sqlmodel.ext.asyncio.session import AsyncSession
from model.base import RegionEnum
from utils.error import NotFoundError, RegionNotFoundError
from utils.mysql import MySQL from utils.mysql import MySQL
from .models import HyperionCookie, HoyolabCookie, Cookies
class CookiesRepository: class CookiesRepository:
def __init__(self, mysql: MySQL): def __init__(self, mysql: MySQL):
self.mysql = mysql self.mysql = mysql
async def update_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum): async def add_cookies(self, user_id: int, cookies: str, region: RegionEnum):
if default_service == ServiceEnum.HYPERION: async with self.mysql.Session() as session:
query = """ session = cast(AsyncSession, session)
UPDATE `mihoyo_cookie` if region == RegionEnum.HYPERION:
SET cookie=%s db_data = HyperionCookie(user_id=user_id, cookie=cookies)
WHERE user_id=%s; elif region == RegionEnum.HOYOLAB:
""" db_data = HoyolabCookie(user_id=user_id, cookie=cookies)
elif default_service == ServiceEnum.HOYOLAB: else:
query = """ raise RegionNotFoundError(region.name)
UPDATE `hoyoverse_cookie` await session.add(db_data)
SET cookie=%s
WHERE user_id=%s;
"""
else:
raise DefaultServiceNotFoundError(default_service.name)
query_args = (cookies, user_id)
await self.mysql.execute_and_fetchall(query, query_args)
async def set_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum): async def update_cookies(self, user_id: int, cookies: str, region: RegionEnum):
if default_service == ServiceEnum.HYPERION: async with self.mysql.Session() as session:
query = """ session = cast(AsyncSession, session)
INSERT INTO `mihoyo_cookie` if region == RegionEnum.HYPERION:
(user_id,cookie) statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
VALUES results = await session.exec(statement)
(%s,%s) db_cookies = results.one()
ON DUPLICATE KEY UPDATE if db_cookies is None:
cookie=VALUES(cookie); raise CookiesNotFoundError(user_id)
""" db_cookies.cookie = cookies
elif default_service == ServiceEnum.HOYOLAB: session.add(db_cookies)
query = """ await session.commit()
INSERT INTO `hoyoverse_cookie` await session.refresh(db_cookies)
(user_id,cookie) elif region == RegionEnum.HOYOLAB:
VALUES statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
(%s,%s) results = await session.add(statement)
ON DUPLICATE KEY UPDATE db_cookies = results.one()
cookie=VALUES(cookie); if db_cookies is None:
""" raise CookiesNotFoundError(user_id)
else: db_cookies.cookie = cookies
raise DefaultServiceNotFoundError(default_service.name) session.add(db_cookies)
query_args = (user_id, cookies) await session.commit()
await self.mysql.execute_and_fetchall(query, query_args) await session.refresh(db_cookies)
else:
raise RegionNotFoundError(region.name)
async def read_cookies(self, user_id, default_service: ServiceEnum) -> dict: async def update_cookies_ex(self, cookies: Cookies, region: RegionEnum):
if default_service == ServiceEnum.HYPERION: async with self.mysql.Session() as session:
query = """ session = cast(AsyncSession, session)
SELECT cookie if region == RegionEnum.HYPERION:
FROM `mihoyo_cookie` session.add(cookies)
WHERE user_id=%s; await session.commit()
""" await session.refresh(cookies)
elif default_service == ServiceEnum.HOYOLAB: elif region == RegionEnum.HOYOLAB:
query = """ await session.add(cookies)
SELECT cookie await session.commit()
FROM `hoyoverse_cookie` await session.refresh(cookies)
WHERE user_id=%s;; else:
""" raise RegionNotFoundError(region.name)
else:
raise DefaultServiceNotFoundError(default_service.name) async def get_cookies(self, user_id, region: RegionEnum) -> Cookies:
query_args = (user_id,) async with self.mysql.Session() as session:
data = await self.mysql.execute_and_fetchall(query, query_args) session = cast(AsyncSession, session)
if len(data) == 0: if region == RegionEnum.HYPERION:
return {} statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
(cookies,) = data results = await session.exec(statement)
return ujson.loads(cookies) db_cookies = results.first()
if db_cookies is None:
raise CookiesNotFoundError(user_id)
return db_cookies[0]
elif region == RegionEnum.HOYOLAB:
statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id)
results = await session.exec(statement)
db_cookies = results.first()
if db_cookies is None:
raise CookiesNotFoundError(user_id)
return db_cookies[0]
else:
raise RegionNotFoundError(region.name)
async def get_all_cookies(self, region: RegionEnum) -> List[Cookies]:
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region == RegionEnum.HYPERION:
statement = select(HyperionCookie)
results = await session.exec(statement)
db_cookies = results.all()
return [cookies[0] for cookies in db_cookies]
elif region == RegionEnum.HOYOLAB:
statement = select(HoyolabCookie)
results = await session.exec(statement)
db_cookies = results.all()
return [cookies[0] for cookies in db_cookies]
else:
raise RegionNotFoundError(region.name)
class DefaultServiceNotFoundError(NotFoundError): class CookiesNotFoundError(NotFoundError):
entity_name: str = "ServiceEnum" entity_name: str = "CookiesRepository"
entity_value_name: str = "default_service" entity_value_name: str = "user_id"

View File

@ -3,14 +3,17 @@ from model.base import ServiceEnum
class CookiesService: class CookiesService:
def __init__(self, user_repository: CookiesRepository) -> None: def __init__(self, cookies_repository: CookiesRepository) -> None:
self._repository: CookiesRepository = user_repository self._repository: CookiesRepository = cookies_repository
async def update_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum): async def update_cookies(self, user_id: int, cookies: str, region: RegionEnum):
await self._repository.update_cookie(user_id, cookies, default_service) await self._repository.update_cookies(user_id, cookies, region)
async def set_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum): async def add_cookies(self, user_id: int, cookies: str, region: RegionEnum):
await self._repository.set_cookie(user_id, cookies, default_service) await self._repository.add_cookies(user_id, cookies, region)
async def get_cookies(self, user_id: int, region: RegionEnum):
return await self._repository.get_cookies(user_id, region)
async def read_cookies(self, user_id: int, default_service: ServiceEnum): async def read_cookies(self, user_id: int, default_service: ServiceEnum):
return await self._repository.read_cookies(user_id, default_service) return await self._repository.read_cookies(user_id, default_service)

View File

@ -1,6 +1,5 @@
from typing import List from typing import List
from app.quiz.base import CreatQuestionFromSQLData, CreatAnswerFromSQLData
from app.quiz.models import Question, Answer from app.quiz.models import Question, Answer
from utils.mysql import MySQL from utils.mysql import MySQL
@ -10,74 +9,25 @@ class QuizRepository:
self.mysql = mysql self.mysql = mysql
async def get_question_list(self) -> List[Question]: async def get_question_list(self) -> List[Question]:
query = """ pass
SELECT id,question
FROM `question`
"""
query_args = ()
data = await self.mysql.execute_and_fetchall(query, query_args)
return CreatQuestionFromSQLData(data)
async def get_answer_form_question_id(self, question_id: int) -> List[Answer]: async def get_answer_form_question_id(self, question_id: int) -> List[Answer]:
query = """ pass
SELECT id,question_id,is_correct,answer
FROM `answer`
WHERE question_id=%s;
"""
query_args = (question_id,)
data = await self.mysql.execute_and_fetchall(query, query_args)
return CreatAnswerFromSQLData(data)
async def add_question(self, question: str): async def add_question(self, question: str):
query = """ pass
INSERT INTO `question`
(question)
VALUES
(%s)
"""
query_args = (question,)
await self.mysql.execute_and_fetchall(query, query_args)
async def get_question(self, question: str) -> Question: async def get_question(self, question: str) -> Question:
query = """ pass
SELECT id,question
FROM `question`
WHERE question=%s;
"""
query_args = (question,)
data = await self.mysql.execute_and_fetchall(query, query_args)
return CreatQuestionFromSQLData(data)[0]
async def add_answer(self, question_id: int, is_correct: int, answer: str): async def add_answer(self, question_id: int, is_correct: int, answer: str):
query = """ pass
INSERT INTO `answer`
(question_id,is_correct,answer)
VALUES
(%s,%s,%s)
"""
query_args = (question_id, is_correct, answer)
await self.mysql.execute_and_fetchall(query, query_args)
async def delete_question(self, question_id: int): async def delete_question(self, question_id: int):
query = """ pass
DELETE FROM `question`
WHERE id=%s;
"""
query_args = (question_id,)
await self.mysql.execute_and_fetchall(query, query_args)
async def delete_answer(self, answer_id: int): async def delete_answer(self, answer_id: int):
query = """ pass
DELETE FROM `answer`
WHERE id=%s;
"""
query_args = (answer_id,)
await self.mysql.execute_and_fetchall(query, query_args)
async def delete_admin(self, user_id: int): async def delete_admin(self, user_id: int):
query = """ pass
DELETE FROM `admin`
WHERE user_id=%s;
"""
query_args = (user_id,)
await self.mysql.execute_and_fetchall(query, query_args)

View File

@ -1,11 +1,11 @@
from model.base import RegionEnum from typing import Optional
from model.baseobject import BaseObject
from sqlmodel import SQLModel, Field
class User(BaseObject): class User(SQLModel, table=True):
def __init__(self, user_id: int = 0, yuanshen_game_uid: int = 0, genshin_game_uid: int = 0, id: Optional[int] = Field(default=None, primary_key=True)
region: RegionEnum = RegionEnum.NULL): user_id: int = Field()
self.user_id = user_id yuanshen_uid: int = Field()
self.yuanshen_game_uid = yuanshen_game_uid genshin_uid: int = Field()
self.genshin_game_uid = genshin_game_uid region: int = Field()
self.region = region

View File

@ -1,7 +1,11 @@
from app.user.models import User from typing import cast
from model.base import ServiceEnum
from sqlalchemy import select
from sqlmodel.ext.asyncio.session import AsyncSession
from utils.error import NotFoundError from utils.error import NotFoundError
from utils.mysql import MySQL from utils.mysql import MySQL
from .models import User
class UserRepository: class UserRepository:
@ -9,18 +13,14 @@ class UserRepository:
self.mysql = mysql self.mysql = mysql
async def get_by_user_id(self, user_id: int) -> User: async def get_by_user_id(self, user_id: int) -> User:
query = """ async with self.mysql.Session() as session:
SELECT user_id,mihoyo_game_uid,hoyoverse_game_uid,service session = cast(AsyncSession, session)
FROM `user` statement = select(User).where(User.user_id == user_id)
WHERE user_id=%s;""" results = await session.exec(statement)
query_args = (user_id,) user = results.first()
data = await self.mysql.execute_and_fetchall(query, query_args) return user
if len(data) == 0:
raise UserNotFoundError(user_id)
(user_id, yuanshen_game_uid, genshin_game_uid, default_service) = data
return User(user_id, yuanshen_game_uid, genshin_game_uid, ServiceEnum(default_service))
class UserNotFoundError(NotFoundError): class UserNotFoundError(NotFoundError):
entity_name: str = "User" entity_name: str = "User"
entity_value_name: str = "id" entity_value_name: str = "user_id"

View File

@ -1,5 +1,5 @@
from app.user.models import User from .models import User
from app.user.repositories import UserRepository from .repositories import UserRepository
class UserService: class UserService:
@ -10,6 +10,7 @@ class UserService:
async def get_user_by_id(self, user_id: int) -> User: async def get_user_by_id(self, user_id: int) -> User:
"""从数据库获取用户信息 """从数据库获取用户信息
:param user_id:用户ID :param user_id:用户ID
:return: :return: User
""" """
return await self._repository.get_by_user_id(user_id) user = await self._repository.get_by_user_id(user_id)
return user

View File

@ -17,4 +17,7 @@ fakeredis>=1.8.1
aiohttp<=3.8.1 aiohttp<=3.8.1
python-telegram-bot==20.0a2 python-telegram-bot==20.0a2
pytz>=2021.3 pytz>=2021.3
Pillow>=9.0.1 Pillow>=9.0.1
SQLAlchemy>=1.4.39
sqlmodel>=0.0.6
asyncmy>=0.2.5

View File

@ -1,71 +1,30 @@
import asyncio from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
import aiomysql from sqlmodel.ext.asyncio.session import AsyncSession
from logger import Log from logger import Log
class MySQL: class MySQL:
def __init__(self, host: str = "127.0.0.1", port: int = 3306, user: str = "root", def __init__(self, host: str = "127.0.0.1", port: int = 3306, user: str = "root",
password: str = "", database: str = "", loop=None): password: str = "", database: str = ""):
self.database = database self.database = database
self.password = password self.password = password
self.user = user self.user = user
self.port = port self.port = port
self.host = host self.host = host
self._loop = loop
self._sql_pool = None
Log.debug(f'获取数据库配置 [host]: {self.host}') Log.debug(f'获取数据库配置 [host]: {self.host}')
Log.debug(f'获取数据库配置 [port]: {self.port}') Log.debug(f'获取数据库配置 [port]: {self.port}')
Log.debug(f'获取数据库配置 [user]: {self.user}') Log.debug(f'获取数据库配置 [user]: {self.user}')
Log.debug(f'获取数据库配置 [password][len]: {len(self.password)}') Log.debug(f'获取数据库配置 [password][len]: {len(self.password)}')
Log.debug(f'获取数据库配置 [db]: {self.database}') Log.debug(f'获取数据库配置 [db]: {self.database}')
if self._loop is None: self.engine = create_async_engine(f"mysql+asyncmy://{user}:{password}@{host}:{port}/{database}")
self._loop = asyncio.get_event_loop() self.Session = sessionmaker(bind=self.engine, class_=AsyncSession)
try:
Log.info("正在创建数据库LOOP") async def get_session(self):
self._loop.run_until_complete(self.create_pool()) """获取会话"""
Log.info("创建数据库LOOP成功") async with self.Session() as session:
except (KeyboardInterrupt, SystemExit): yield session
pass
except Exception as exc:
Log.error("创建数据库LOOP发生严重错误")
raise exc
async def wait_closed(self): async def wait_closed(self):
if self._sql_pool is None: pass
return
pool = self._sql_pool
pool.close()
await pool.wait_closed()
async def create_pool(self):
self._sql_pool = await aiomysql.create_pool(
host=self.host, port=self.port,
user=self.user, password=self.password,
db=self.database, loop=self._loop)
async def _get_pool(self):
if self._sql_pool is None:
raise RuntimeError("mysql pool is none")
return self._sql_pool
async def executemany(self, query, query_args):
pool = await self._get_pool()
async with pool.acquire() as conn:
sql_cur = await conn.cursor()
await sql_cur.executemany(query, query_args)
rowcount = sql_cur.rowcount
await sql_cur.close()
await conn.commit()
return rowcount
async def execute_and_fetchall(self, query, query_args):
pool = await self._get_pool()
async with pool.acquire() as conn:
sql_cur = await conn.cursor()
await sql_cur.execute(query, query_args)
result = await sql_cur.fetchall()
await sql_cur.close()
await conn.commit()
return result