mirror of
https://github.com/PaiGramTeam/PaiGram.git
synced 2024-11-17 21:22:56 +00:00
9e7637203a
移除 `aiomysql` 依赖 添加 `SQLAlchemy` `sqlmodel` `asyncmy` 依赖
107 lines
4.6 KiB
Python
107 lines
4.6 KiB
Python
from typing import cast, List
|
|
|
|
from sqlalchemy import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from model.base import RegionEnum
|
|
from utils.error import NotFoundError, RegionNotFoundError
|
|
from utils.mysql import MySQL
|
|
from .models import HyperionCookie, HoyolabCookie, Cookies
|
|
|
|
|
|
class CookiesRepository:
    """Persistence layer for per-user cookies, stored in one table per region.

    ``RegionEnum.HYPERION`` rows are mapped by ``HyperionCookie`` and
    ``RegionEnum.HOYOLAB`` rows by ``HoyolabCookie``. Every public method
    raises ``RegionNotFoundError`` for any other region value.
    """

    def __init__(self, mysql: MySQL):
        """Keep the ``MySQL`` wrapper whose ``Session()`` opens DB sessions."""
        self.mysql = mysql

    @staticmethod
    def _get_model(region: RegionEnum):
        """Return the cookie model class that stores rows for *region*.

        Raises:
            RegionNotFoundError: if *region* is neither HYPERION nor HOYOLAB.
        """
        if region == RegionEnum.HYPERION:
            return HyperionCookie
        if region == RegionEnum.HOYOLAB:
            return HoyolabCookie
        raise RegionNotFoundError(region.name)

    async def add_cookies(self, user_id: int, cookies: str, region: RegionEnum):
        """Insert a new cookie row for *user_id* in the table of *region*.

        Raises:
            RegionNotFoundError: if *region* is not supported.
        """
        model = self._get_model(region)
        async with self.mysql.Session() as session:
            session = cast(AsyncSession, session)
            db_data = model(user_id=user_id, cookie=cookies)
            # ``add`` is synchronous (awaiting its ``None`` result raises
            # TypeError); the row is only written once the session commits.
            session.add(db_data)
            await session.commit()

    async def update_cookies(self, user_id: int, cookies: str, region: RegionEnum):
        """Replace the stored cookie string of *user_id* for *region*.

        Raises:
            RegionNotFoundError: if *region* is not supported.
            CookiesNotFoundError: if no row exists for *user_id*.
        """
        model = self._get_model(region)
        async with self.mysql.Session() as session:
            session = cast(AsyncSession, session)
            statement = select(model).where(model.user_id == user_id)
            results = await session.exec(statement)
            # ``first()`` yields None when nothing matches, whereas ``one()``
            # would raise before the not-found check below could run.
            row = results.first()
            if row is None:
                raise CookiesNotFoundError(user_id)
            # A plain ``sqlalchemy.select`` produces result rows; the mapped
            # entity is the first column (same unwrapping as ``get_cookies``).
            db_cookies = row[0]
            db_cookies.cookie = cookies
            session.add(db_cookies)
            await session.commit()
            await session.refresh(db_cookies)

    async def update_cookies_ex(self, cookies: Cookies, region: RegionEnum):
        """Persist an already-loaded *cookies* entity and refresh it.

        Raises:
            RegionNotFoundError: if *region* is not supported.
        """
        # Only validates the region; the entity itself determines its table.
        self._get_model(region)
        async with self.mysql.Session() as session:
            session = cast(AsyncSession, session)
            session.add(cookies)
            await session.commit()
            await session.refresh(cookies)

    async def get_cookies(self, user_id, region: RegionEnum) -> Cookies:
        """Return the cookie entity stored for *user_id* in *region*.

        Raises:
            RegionNotFoundError: if *region* is not supported.
            CookiesNotFoundError: if no row exists for *user_id*.
        """
        model = self._get_model(region)
        async with self.mysql.Session() as session:
            session = cast(AsyncSession, session)
            statement = select(model).where(model.user_id == user_id)
            results = await session.exec(statement)
            row = results.first()
            if row is None:
                raise CookiesNotFoundError(user_id)
            return row[0]

    async def get_all_cookies(self, region: RegionEnum) -> List[Cookies]:
        """Return every cookie entity stored for *region*.

        Raises:
            RegionNotFoundError: if *region* is not supported.
        """
        model = self._get_model(region)
        async with self.mysql.Session() as session:
            session = cast(AsyncSession, session)
            results = await session.exec(select(model))
            return [row[0] for row in results.all()]
|
|
|
|
|
|
class CookiesNotFoundError(NotFoundError):
    """Raised by ``CookiesRepository`` when a user has no stored cookies.

    The two class attributes below are consumed by the ``NotFoundError`` base
    class (defined in ``utils.error``), presumably to build the error message.
    """

    # Name of the component that reports the missing entity.
    entity_name: str = "CookiesRepository"
    # Name of the lookup key; instances are constructed with the user id.
    entity_value_name: str = "user_id"
|