♻ 重构 mysql 通信模块

移除 `aiomysql` 依赖
添加 `SQLAlchemy` `sqlmodel` `asyncmy` 依赖
This commit is contained in:
洛水居室 2022-08-04 21:18:23 +08:00
parent 7c90b27934
commit 9e7637203a
No known key found for this signature in database
GPG Key ID: C9DE87DA724B88FC
12 changed files with 210 additions and 239 deletions

8
app/admin/models.py Normal file
View File

@ -0,0 +1,8 @@
from typing import Optional
from sqlmodel import SQLModel, Field
class Admin(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
user_id: int = Field()

View File

@ -1,6 +1,10 @@
from typing import List
from typing import List, cast
from sqlalchemy import select
from sqlmodel.ext.asyncio.session import AsyncSession
from utils.mysql import MySQL
from .models import Admin
class BotAdminRepository:
@ -8,30 +12,21 @@ class BotAdminRepository:
self.mysql = mysql
async def delete_by_user_id(self, user_id: int):
query = """
DELETE FROM `admin`
WHERE user_id=%s;
"""
query_args = (user_id,)
await self.mysql.execute_and_fetchall(query, query_args)
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
statement = select(Admin).where(Admin.user_id == user_id)
results = await session.exec(statement)
admin = results.one()
await session.delete(admin)
async def add_by_user_id(self, user_id: int):
query = """
INSERT INTO `admin`
(user_id)
VALUES
(%s)
"""
query_args = (user_id,)
await self.mysql.execute_and_fetchall(query, query_args)
async with self.mysql.Session() as session:
admin = Admin(user_id=user_id)
await session.add(admin)
async def get_by_user_id(self) -> List[int]:
query = """
SELECT user_id
FROM `admin`
"""
query_args = ()
data = await self.mysql.execute_and_fetchall(query, query_args)
if len(data) == 0:
return []
return [i[0] for i in data]
async def get_all_user_id(self) -> List[int]:
async with self.mysql.Session() as session:
query = select(Admin)
results = await session.exec(query)
admins = results.all()
return [admin[0].user_id for admin in admins]

View File

@ -17,7 +17,7 @@ class BotAdminService:
async def get_admin_list(self) -> List[int]:
admin_list = await self._cache.get_list()
if len(admin_list) == 0:
admin_list = await self._repository.get_by_user_id()
admin_list = await self._repository.get_all_user_id()
for config_admin in config.ADMINISTRATORS:
admin_list.append(config_admin["user_id"])
await self._cache.set_list(admin_list)
@ -28,7 +28,7 @@ class BotAdminService:
await self._repository.add_by_user_id(user_id)
except IntegrityError as error:
Log.warning(f"{user_id} 已经存在数据库 \n", error)
admin_list = await self._repository.get_by_user_id()
admin_list = await self._repository.get_all_user_id()
for config_admin in config.ADMINISTRATORS:
admin_list.append(config_admin["user_id"])
await self._cache.set_list(admin_list)
@ -39,7 +39,7 @@ class BotAdminService:
await self._repository.delete_by_user_id(user_id)
except ValueError:
return False
admin_list = await self._repository.get_by_user_id()
admin_list = await self._repository.get_all_user_id()
for config_admin in config.ADMINISTRATORS:
admin_list.append(config_admin["user_id"])
await self._cache.set_list(admin_list)

25
app/cookies/models.py Normal file
View File

@ -0,0 +1,25 @@
import enum
from typing import Optional, Dict
from sqlmodel import SQLModel, Field, JSON, Enum, Column
class CookiesStatusEnum(int, enum.Enum):
STATUS_SUCCESS = 0
INVALID_COOKIES = 1
TOO_MANY_REQUESTS = 2
class Cookies(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)
user_id: Optional[int] = Field()
cookies: Optional[Dict[str, str]] = Field(sa_column=Column(JSON))
status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum)))
class HyperionCookie(Cookies, table=True):
__tablename__ = 'mihoyo_cookies'
class HoyolabCookie(Cookies, table=True):
__tablename__ = 'hoyoverse_cookies'

View File

@ -1,79 +1,106 @@
import ujson
from typing import cast, List
from model.base import ServiceEnum
from utils.error import NotFoundError
from sqlalchemy import select
from sqlmodel.ext.asyncio.session import AsyncSession
from model.base import RegionEnum
from utils.error import NotFoundError, RegionNotFoundError
from utils.mysql import MySQL
from .models import HyperionCookie, HoyolabCookie, Cookies
class CookiesRepository:
def __init__(self, mysql: MySQL):
self.mysql = mysql
async def update_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum):
if default_service == ServiceEnum.HYPERION:
query = """
UPDATE `mihoyo_cookie`
SET cookie=%s
WHERE user_id=%s;
"""
elif default_service == ServiceEnum.HOYOLAB:
query = """
UPDATE `hoyoverse_cookie`
SET cookie=%s
WHERE user_id=%s;
"""
else:
raise DefaultServiceNotFoundError(default_service.name)
query_args = (cookies, user_id)
await self.mysql.execute_and_fetchall(query, query_args)
async def add_cookies(self, user_id: int, cookies: str, region: RegionEnum):
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region == RegionEnum.HYPERION:
db_data = HyperionCookie(user_id=user_id, cookie=cookies)
elif region == RegionEnum.HOYOLAB:
db_data = HoyolabCookie(user_id=user_id, cookie=cookies)
else:
raise RegionNotFoundError(region.name)
await session.add(db_data)
async def set_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum):
if default_service == ServiceEnum.HYPERION:
query = """
INSERT INTO `mihoyo_cookie`
(user_id,cookie)
VALUES
(%s,%s)
ON DUPLICATE KEY UPDATE
cookie=VALUES(cookie);
"""
elif default_service == ServiceEnum.HOYOLAB:
query = """
INSERT INTO `hoyoverse_cookie`
(user_id,cookie)
VALUES
(%s,%s)
ON DUPLICATE KEY UPDATE
cookie=VALUES(cookie);
"""
else:
raise DefaultServiceNotFoundError(default_service.name)
query_args = (user_id, cookies)
await self.mysql.execute_and_fetchall(query, query_args)
async def update_cookies(self, user_id: int, cookies: str, region: RegionEnum):
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region == RegionEnum.HYPERION:
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
results = await session.exec(statement)
db_cookies = results.one()
if db_cookies is None:
raise CookiesNotFoundError(user_id)
db_cookies.cookie = cookies
session.add(db_cookies)
await session.commit()
await session.refresh(db_cookies)
elif region == RegionEnum.HOYOLAB:
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
results = await session.add(statement)
db_cookies = results.one()
if db_cookies is None:
raise CookiesNotFoundError(user_id)
db_cookies.cookie = cookies
session.add(db_cookies)
await session.commit()
await session.refresh(db_cookies)
else:
raise RegionNotFoundError(region.name)
async def read_cookies(self, user_id, default_service: ServiceEnum) -> dict:
if default_service == ServiceEnum.HYPERION:
query = """
SELECT cookie
FROM `mihoyo_cookie`
WHERE user_id=%s;
"""
elif default_service == ServiceEnum.HOYOLAB:
query = """
SELECT cookie
FROM `hoyoverse_cookie`
WHERE user_id=%s;;
"""
else:
raise DefaultServiceNotFoundError(default_service.name)
query_args = (user_id,)
data = await self.mysql.execute_and_fetchall(query, query_args)
if len(data) == 0:
return {}
(cookies,) = data
return ujson.loads(cookies)
async def update_cookies_ex(self, cookies: Cookies, region: RegionEnum):
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region == RegionEnum.HYPERION:
session.add(cookies)
await session.commit()
await session.refresh(cookies)
elif region == RegionEnum.HOYOLAB:
await session.add(cookies)
await session.commit()
await session.refresh(cookies)
else:
raise RegionNotFoundError(region.name)
async def get_cookies(self, user_id, region: RegionEnum) -> Cookies:
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region == RegionEnum.HYPERION:
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
results = await session.exec(statement)
db_cookies = results.first()
if db_cookies is None:
raise CookiesNotFoundError(user_id)
return db_cookies[0]
elif region == RegionEnum.HOYOLAB:
statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id)
results = await session.exec(statement)
db_cookies = results.first()
if db_cookies is None:
raise CookiesNotFoundError(user_id)
return db_cookies[0]
else:
raise RegionNotFoundError(region.name)
async def get_all_cookies(self, region: RegionEnum) -> List[Cookies]:
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region == RegionEnum.HYPERION:
statement = select(HyperionCookie)
results = await session.exec(statement)
db_cookies = results.all()
return [cookies[0] for cookies in db_cookies]
elif region == RegionEnum.HOYOLAB:
statement = select(HoyolabCookie)
results = await session.exec(statement)
db_cookies = results.all()
return [cookies[0] for cookies in db_cookies]
else:
raise RegionNotFoundError(region.name)
class DefaultServiceNotFoundError(NotFoundError):
entity_name: str = "ServiceEnum"
entity_value_name: str = "default_service"
class CookiesNotFoundError(NotFoundError):
entity_name: str = "CookiesRepository"
entity_value_name: str = "user_id"

View File

@ -3,14 +3,17 @@ from model.base import ServiceEnum
class CookiesService:
def __init__(self, user_repository: CookiesRepository) -> None:
self._repository: CookiesRepository = user_repository
def __init__(self, cookies_repository: CookiesRepository) -> None:
self._repository: CookiesRepository = cookies_repository
async def update_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum):
await self._repository.update_cookie(user_id, cookies, default_service)
async def update_cookies(self, user_id: int, cookies: str, region: RegionEnum):
await self._repository.update_cookies(user_id, cookies, region)
async def set_cookie(self, user_id: int, cookies: str, default_service: ServiceEnum):
await self._repository.set_cookie(user_id, cookies, default_service)
async def add_cookies(self, user_id: int, cookies: str, region: RegionEnum):
await self._repository.add_cookies(user_id, cookies, region)
async def get_cookies(self, user_id: int, region: RegionEnum):
return await self._repository.get_cookies(user_id, region)
async def read_cookies(self, user_id: int, default_service: ServiceEnum):
return await self._repository.read_cookies(user_id, default_service)

View File

@ -1,6 +1,5 @@
from typing import List
from app.quiz.base import CreatQuestionFromSQLData, CreatAnswerFromSQLData
from app.quiz.models import Question, Answer
from utils.mysql import MySQL
@ -10,74 +9,25 @@ class QuizRepository:
self.mysql = mysql
async def get_question_list(self) -> List[Question]:
query = """
SELECT id,question
FROM `question`
"""
query_args = ()
data = await self.mysql.execute_and_fetchall(query, query_args)
return CreatQuestionFromSQLData(data)
pass
async def get_answer_form_question_id(self, question_id: int) -> List[Answer]:
query = """
SELECT id,question_id,is_correct,answer
FROM `answer`
WHERE question_id=%s;
"""
query_args = (question_id,)
data = await self.mysql.execute_and_fetchall(query, query_args)
return CreatAnswerFromSQLData(data)
pass
async def add_question(self, question: str):
query = """
INSERT INTO `question`
(question)
VALUES
(%s)
"""
query_args = (question,)
await self.mysql.execute_and_fetchall(query, query_args)
pass
async def get_question(self, question: str) -> Question:
query = """
SELECT id,question
FROM `question`
WHERE question=%s;
"""
query_args = (question,)
data = await self.mysql.execute_and_fetchall(query, query_args)
return CreatQuestionFromSQLData(data)[0]
pass
async def add_answer(self, question_id: int, is_correct: int, answer: str):
query = """
INSERT INTO `answer`
(question_id,is_correct,answer)
VALUES
(%s,%s,%s)
"""
query_args = (question_id, is_correct, answer)
await self.mysql.execute_and_fetchall(query, query_args)
pass
async def delete_question(self, question_id: int):
query = """
DELETE FROM `question`
WHERE id=%s;
"""
query_args = (question_id,)
await self.mysql.execute_and_fetchall(query, query_args)
pass
async def delete_answer(self, answer_id: int):
query = """
DELETE FROM `answer`
WHERE id=%s;
"""
query_args = (answer_id,)
await self.mysql.execute_and_fetchall(query, query_args)
pass
async def delete_admin(self, user_id: int):
query = """
DELETE FROM `admin`
WHERE user_id=%s;
"""
query_args = (user_id,)
await self.mysql.execute_and_fetchall(query, query_args)
pass

View File

@ -1,11 +1,11 @@
from model.base import RegionEnum
from model.baseobject import BaseObject
from typing import Optional
from sqlmodel import SQLModel, Field
class User(BaseObject):
def __init__(self, user_id: int = 0, yuanshen_game_uid: int = 0, genshin_game_uid: int = 0,
region: RegionEnum = RegionEnum.NULL):
self.user_id = user_id
self.yuanshen_game_uid = yuanshen_game_uid
self.genshin_game_uid = genshin_game_uid
self.region = region
class User(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
user_id: int = Field()
yuanshen_uid: int = Field()
genshin_uid: int = Field()
region: int = Field()

View File

@ -1,7 +1,11 @@
from app.user.models import User
from model.base import ServiceEnum
from typing import cast
from sqlalchemy import select
from sqlmodel.ext.asyncio.session import AsyncSession
from utils.error import NotFoundError
from utils.mysql import MySQL
from .models import User
class UserRepository:
@ -9,18 +13,14 @@ class UserRepository:
self.mysql = mysql
async def get_by_user_id(self, user_id: int) -> User:
query = """
SELECT user_id,mihoyo_game_uid,hoyoverse_game_uid,service
FROM `user`
WHERE user_id=%s;"""
query_args = (user_id,)
data = await self.mysql.execute_and_fetchall(query, query_args)
if len(data) == 0:
raise UserNotFoundError(user_id)
(user_id, yuanshen_game_uid, genshin_game_uid, default_service) = data
return User(user_id, yuanshen_game_uid, genshin_game_uid, ServiceEnum(default_service))
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
statement = select(User).where(User.user_id == user_id)
results = await session.exec(statement)
user = results.first()
return user
class UserNotFoundError(NotFoundError):
entity_name: str = "User"
entity_value_name: str = "id"
entity_value_name: str = "user_id"

View File

@ -1,5 +1,5 @@
from app.user.models import User
from app.user.repositories import UserRepository
from .models import User
from .repositories import UserRepository
class UserService:
@ -10,6 +10,7 @@ class UserService:
async def get_user_by_id(self, user_id: int) -> User:
"""从数据库获取用户信息
:param user_id:用户ID
:return:
:return: User
"""
return await self._repository.get_by_user_id(user_id)
user = await self._repository.get_by_user_id(user_id)
return user

View File

@ -17,4 +17,7 @@ fakeredis>=1.8.1
aiohttp<=3.8.1
python-telegram-bot==20.0a2
pytz>=2021.3
Pillow>=9.0.1
Pillow>=9.0.1
SQLAlchemy>=1.4.39
sqlmodel>=0.0.6
asyncmy>=0.2.5

View File

@ -1,71 +1,30 @@
import asyncio
import aiomysql
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel.ext.asyncio.session import AsyncSession
from logger import Log
class MySQL:
def __init__(self, host: str = "127.0.0.1", port: int = 3306, user: str = "root",
password: str = "", database: str = "", loop=None):
password: str = "", database: str = ""):
self.database = database
self.password = password
self.user = user
self.port = port
self.host = host
self._loop = loop
self._sql_pool = None
Log.debug(f'获取数据库配置 [host]: {self.host}')
Log.debug(f'获取数据库配置 [port]: {self.port}')
Log.debug(f'获取数据库配置 [user]: {self.user}')
Log.debug(f'获取数据库配置 [password][len]: {len(self.password)}')
Log.debug(f'获取数据库配置 [db]: {self.database}')
if self._loop is None:
self._loop = asyncio.get_event_loop()
try:
Log.info("正在创建数据库LOOP")
self._loop.run_until_complete(self.create_pool())
Log.info("创建数据库LOOP成功")
except (KeyboardInterrupt, SystemExit):
pass
except Exception as exc:
Log.error("创建数据库LOOP发生严重错误")
raise exc
self.engine = create_async_engine(f"mysql+asyncmy://{user}:{password}@{host}:{port}/{database}")
self.Session = sessionmaker(bind=self.engine, class_=AsyncSession)
async def get_session(self):
"""获取会话"""
async with self.Session() as session:
yield session
async def wait_closed(self):
if self._sql_pool is None:
return
pool = self._sql_pool
pool.close()
await pool.wait_closed()
async def create_pool(self):
self._sql_pool = await aiomysql.create_pool(
host=self.host, port=self.port,
user=self.user, password=self.password,
db=self.database, loop=self._loop)
async def _get_pool(self):
if self._sql_pool is None:
raise RuntimeError("mysql pool is none")
return self._sql_pool
async def executemany(self, query, query_args):
pool = await self._get_pool()
async with pool.acquire() as conn:
sql_cur = await conn.cursor()
await sql_cur.executemany(query, query_args)
rowcount = sql_cur.rowcount
await sql_cur.close()
await conn.commit()
return rowcount
async def execute_and_fetchall(self, query, query_args):
pool = await self._get_pool()
async with pool.acquire() as conn:
sql_cur = await conn.cursor()
await sql_cur.execute(query, query_args)
result = await sql_cur.fetchall()
await sql_cur.close()
await conn.commit()
return result
pass