添加 utils 模块

添加 utils 模块
移除 job_queue 模块
修改获取配置文件信息函数
重构部分类以及函数
This commit is contained in:
luoshuijs 2022-06-09 20:04:38 +08:00 committed by 洛水居室
parent db96c70180
commit 63d71262fe
No known key found for this signature in database
GPG Key ID: C9DE87DA724B88FC
12 changed files with 240 additions and 243 deletions

View File

@ -18,6 +18,7 @@ class Config:
self.DEBUG = False self.DEBUG = False
self.ADMINISTRATORS = self.get_config("administrators") self.ADMINISTRATORS = self.get_config("administrators")
self.MYSQL = self.get_config("mysql") self.MYSQL = self.get_config("mysql")
self.REDIS = self.get_config("redis")
self.TELEGRAM = self.get_config("telegram") self.TELEGRAM = self.get_config("telegram")
self.FUNCTION = self.get_config("function") self.FUNCTION = self.get_config("function")

View File

@ -11,7 +11,6 @@ from plugins.errorhandler import error_handler
from plugins.gacha import Gacha from plugins.gacha import Gacha
from plugins.help import Help from plugins.help import Help
from plugins.inline import Inline from plugins.inline import Inline
from plugins.job_queue import JobQueue
from plugins.post import Post from plugins.post import Post
from plugins.quiz import Quiz from plugins.quiz import Quiz
from plugins.sign import Sign from plugins.sign import Sign
@ -77,7 +76,5 @@ def register_handlers(application, service: BaseService = None):
application.add_handler(post_handler) application.add_handler(post_handler)
inline = Inline(service) inline = Inline(service)
application.add_handler(InlineQueryHandler(inline.inline_query, block=False)) application.add_handler(InlineQueryHandler(inline.inline_query, block=False))
job_queue = JobQueue(service)
application.job_queue.run_once(job_queue.start_job, when=3, name="start_job")
application.add_handler(MessageHandler(filters.COMMAND & filters.ChatType.PRIVATE, unknown_command)) application.add_handler(MessageHandler(filters.COMMAND & filters.ChatType.PRIVATE, unknown_command))
application.add_error_handler(error_handler, block=False) application.add_error_handler(error_handler, block=False)

27
main.py
View File

@ -8,8 +8,9 @@ from config import config
from handler import register_handlers from handler import register_handlers
from logger import Log from logger import Log
from service import StartService from service import StartService
from service.cache import RedisCache from utils.aiobrowser import AioBrowser
from service.repository import AsyncRepository from utils.mysql import MySQL
from utils.redisdb import RedisDB
# 无视相关警告 # 无视相关警告
# 该警告说明在官方GITHUB的WIKI中Frequently Asked Questions里的What do the per_* settings in ConversationHandler do? # 该警告说明在官方GITHUB的WIKI中Frequently Asked Questions里的What do the per_* settings in ConversationHandler do?
@ -21,20 +22,20 @@ def main() -> None:
# 初始化数据库 # 初始化数据库
Log.info("初始化数据库") Log.info("初始化数据库")
repository = AsyncRepository(mysql_host=config.MYSQL["host"], mysql = MySQL(host=config.MYSQL["host"], user=config.MYSQL["user"], password=config.MYSQL["password"],
mysql_user=config.MYSQL["user"], port=config.MYSQL["port"], database=config.MYSQL["database"])
mysql_password=config.MYSQL["password"],
mysql_port=config.MYSQL["port"],
mysql_database=config.MYSQL["database"]
)
# 初始化Redis缓存 # 初始化Redis缓存
Log.info("初始化Redis缓存") Log.info("初始化Redis缓存")
cache = RedisCache(db=6) redis = RedisDB(host=config.REDIS["host"], port=config.REDIS["port"], db=config.REDIS["database"])
# 初始化Playwright
Log.info("初始化Playwright")
browser = AioBrowser()
# 传入服务并启动 # 传入服务并启动
Log.info("传入服务并启动") Log.info("传入服务并启动")
service = StartService(repository, cache) service = StartService(mysql, redis, browser)
# 构建BOT # 构建BOT
application = Application.builder().token(config.TELEGRAM["token"]).build() application = Application.builder().token(config.TELEGRAM["token"]).build()
@ -58,13 +59,13 @@ def main() -> None:
try: try:
# 需要关闭数据库连接 # 需要关闭数据库连接
Log.info("正在关闭数据库连接") Log.info("正在关闭数据库连接")
loop.run_until_complete(repository.wait_closed()) loop.run_until_complete(mysql.wait_closed())
# 关闭Redis连接 # 关闭Redis连接
Log.info("正在关闭Redis连接") Log.info("正在关闭Redis连接")
loop.run_until_complete(cache.close()) loop.run_until_complete(redis.close())
# 关闭playwright # 关闭playwright
Log.info("正在关闭Playwright") Log.info("正在关闭Playwright")
loop.run_until_complete(service.template.close()) loop.run_until_complete(browser.close())
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit):
pass pass
except Exception as exc: except Exception as exc:

View File

@ -1,31 +0,0 @@
from typing import List
from telegram.ext import CallbackContext
from logger import Log
from plugins.base import BasePlugins
from service import BaseService
class JobQueue(BasePlugins):
def __init__(self, service: BaseService):
super().__init__(service)
self.new_post_id_list_cache: List[int] = []
async def start_job(self, _: CallbackContext) -> None:
Log.info("正在检查化必要模块是否正常工作")
Log.info("正在检查Playwright")
try:
# 尝试获取 browser 如果获取失败尝试初始化
await self.service.template.get_browser()
except TimeoutError as err:
Log.error("初始化Playwright超时请检查日记查看错误 \n", err)
except AttributeError as err:
Log.error("初始化Playwright时变量为空请检查日记查看错误 \n", err)
else:
Log.info("检查Playwright成功")
Log.info("检查完成")
async def check_cookie(self, _: CallbackContext):
pass

View File

@ -1,6 +1,7 @@
# service 目录说明 # service 目录说明
## 文件说明 ## 文件说明
| FileName | Contribution | | FileName | Contribution |
|:----------:|--------------| |:----------:|--------------|
| init | 服务初始化 | | init | 服务初始化 |

View File

@ -6,20 +6,23 @@ from service.quiz import QuizService
from service.repository import AsyncRepository from service.repository import AsyncRepository
from service.template import TemplateService from service.template import TemplateService
from service.user import UserInfoFormDB from service.user import UserInfoFormDB
from utils.aiobrowser import AioBrowser
from utils.mysql import MySQL
from utils.redisdb import RedisDB
class BaseService: class BaseService:
def __init__(self, async_repository: AsyncRepository, async_cache: RedisCache): def __init__(self, mysql: MySQL, redis: RedisDB, browser: AioBrowser):
self.repository = async_repository self.repository = AsyncRepository(mysql)
self.cache = async_cache self.cache = RedisCache(redis)
self.user_service_db = UserInfoFormDB(self.repository) self.user_service_db = UserInfoFormDB(self.repository)
self.quiz_service = QuizService(self.repository, self.cache) self.quiz_service = QuizService(self.repository, self.cache)
self.get_game_info = GetGameInfo(self.repository, self.cache) self.get_game_info = GetGameInfo(self.repository, self.cache)
self.gacha = GachaService(self.repository, self.cache) self.gacha = GachaService(self.repository, self.cache)
self.admin = AdminService(self.repository, self.cache) self.admin = AdminService(self.repository, self.cache)
self.template = TemplateService() self.template = TemplateService(browser)
class StartService(BaseService): class StartService(BaseService):
def __init__(self, async_repository: AsyncRepository, async_cache: RedisCache): def __init__(self, mysql: MySQL, redis: RedisDB, browser: AioBrowser):
super().__init__(async_repository, async_cache) super().__init__(mysql, redis, browser)

View File

@ -1,74 +1,45 @@
import asyncio
from typing import List from typing import List
import ujson import ujson
from redis import asyncio as aioredis
from logger import Log
from service.base import QuestionData, AnswerData from service.base import QuestionData, AnswerData
from utils.redisdb import RedisDB
class RedisCache: class RedisCache:
def __init__(self, host="127.0.0.1", port=6379, db=0, loop=None): def __init__(self, redis: RedisDB):
self._loop = asyncio.get_event_loop() self.client = redis.client
# Redis 官方文档显示 默认创建POOL连接池
Log.debug(f'获取Redis配置 [host]: {host}')
Log.debug(f'获取Redis配置 [host]: {port}')
Log.debug(f'获取Redis配置 [host]: {db}')
self.rdb = aioredis.Redis(host=host, port=port, db=db)
self.ttl = 600
self.key_prefix = "paimon_bot"
self._loop = loop
if self._loop is None:
self._loop = asyncio.get_event_loop()
try:
Log.info("正在尝试建立与Redis连接")
self._loop.run_until_complete(self.ping())
except (KeyboardInterrupt, SystemExit):
pass
except Exception as exc:
Log.error("尝试连接Redis失败 \n")
raise exc
async def ping(self):
if await self.rdb.ping():
Log.info("连接Redis成功")
else:
Log.info("连接Redis失败")
raise RuntimeError("连接Redis失败")
async def close(self):
await self.rdb.close()
async def get_chat_admin(self, char_id: int): async def get_chat_admin(self, char_id: int):
qname = f"group:admin_list:{char_id}" qname = f"group:admin_list:{char_id}"
return [int(str_id) for str_id in await self.rdb.lrange(qname, 0, -1)] return [int(str_id) for str_id in await self.client.lrange(qname, 0, -1)]
async def set_chat_admin(self, char_id: int, admin_list: List[int]): async def set_chat_admin(self, char_id: int, admin_list: List[int]):
qname = f"group:admin_list:{char_id}" qname = f"group:admin_list:{char_id}"
await self.rdb.ltrim(qname, 1, 0) await self.client.ltrim(qname, 1, 0)
await self.rdb.lpush(qname, *admin_list) await self.client.lpush(qname, *admin_list)
await self.rdb.expire(qname, 60) await self.client.expire(qname, 60)
count = await self.rdb.llen(qname) count = await self.client.llen(qname)
return count return count
async def get_all_question(self) -> List[str]: async def get_all_question(self) -> List[str]:
qname = "quiz:question" qname = "quiz:question"
data_list = [qname + f":{question_id}" for question_id in await self.rdb.lrange(qname + "id_list", 0, -1)] data_list = [qname + f":{question_id}" for question_id in
return await self.rdb.mget(data_list) await self.client.lrange(qname + "id_list", 0, -1)]
return await self.client.mget(data_list)
async def get_all_question_id_list(self) -> List[str]: async def get_all_question_id_list(self) -> List[str]:
qname = "quiz:question:id_list" qname = "quiz:question:id_list"
return await self.rdb.lrange(qname, 0, -1) return await self.client.lrange(qname, 0, -1)
async def get_one_question(self, question_id: int) -> str: async def get_one_question(self, question_id: int) -> str:
qname = f"quiz:question:{question_id}" qname = f"quiz:question:{question_id}"
return await self.rdb.get(qname) return await self.client.get(qname)
async def get_one_answer(self, answer_id: int) -> str: async def get_one_answer(self, answer_id: int) -> str:
qname = f"quiz:answer:{answer_id}" qname = f"quiz:answer:{answer_id}"
return await self.rdb.get(qname) return await self.client.get(qname)
async def set_question(self, question_list: List[QuestionData] = None): async def set_question(self, question_list: List[QuestionData] = None):
qname = "quiz:question" qname = "quiz:question"
@ -82,25 +53,25 @@ class RedisCache:
return ujson.dumps(data) return ujson.dumps(data)
for question in question_list: for question in question_list:
await self.rdb.set(qname + f":{question.question_id}", json_dumps(question)) await self.client.set(qname + f":{question.question_id}", json_dumps(question))
question_id_list = [question.question_id for question in question_list] question_id_list = [question.question_id for question in question_list]
await self.rdb.lpush(qname + ":id_list", *question_id_list) await self.client.lpush(qname + ":id_list", *question_id_list)
return await self.rdb.llen(qname + ":id_list") return await self.client.llen(qname + ":id_list")
async def del_all_question(self, answer_list: List[AnswerData] = None): async def del_all_question(self, answer_list: List[AnswerData] = None):
qname = "quiz:question" qname = "quiz:question"
keys = await self.rdb.keys(qname + "*") keys = await self.client.keys(qname + "*")
if keys is not None: if keys is not None:
for key in keys: for key in keys:
await self.rdb.delete(key) await self.client.delete(key)
async def del_all_answer(self, answer_list: List[AnswerData] = None): async def del_all_answer(self, answer_list: List[AnswerData] = None):
qname = "quiz:answer" qname = "quiz:answer"
keys = await self.rdb.keys(qname + "*") keys = await self.client.keys(qname + "*")
if keys is not None: if keys is not None:
for key in keys: for key in keys:
await self.rdb.delete(key) await self.client.delete(key)
async def set_answer(self, answer_list: List[AnswerData] = None): async def set_answer(self, answer_list: List[AnswerData] = None):
qname = "quiz:answer" qname = "quiz:answer"
@ -109,30 +80,30 @@ class RedisCache:
return ujson.dumps(obj=_answer.__dict__) return ujson.dumps(obj=_answer.__dict__)
for answer in answer_list: for answer in answer_list:
await self.rdb.set(qname + f":{answer.answer_id}", json_dumps(answer)) await self.client.set(qname + f":{answer.answer_id}", json_dumps(answer))
answer_id_list = [answer.answer_id for answer in answer_list] answer_id_list = [answer.answer_id for answer in answer_list]
await self.rdb.lpush(qname + ":id_list", *answer_id_list) await self.client.lpush(qname + ":id_list", *answer_id_list)
return await self.rdb.llen(qname + ":id_list") return await self.client.llen(qname + ":id_list")
async def get_str_list(self, qname: str): async def get_str_list(self, qname: str):
return [str(str_data, encoding="utf-8") for str_data in await self.rdb.lrange(qname, 0, -1)] return [str(str_data, encoding="utf-8") for str_data in await self.client.lrange(qname, 0, -1)]
async def set_str_list(self, qname: str, str_list: List[str], ttl: int = 60): async def set_str_list(self, qname: str, str_list: List[str], ttl: int = 60):
await self.rdb.ltrim(qname, 1, 0) await self.client.ltrim(qname, 1, 0)
await self.rdb.lpush(qname, *str_list) await self.client.lpush(qname, *str_list)
if ttl != -1: if ttl != -1:
await self.rdb.expire(qname, ttl) await self.client.expire(qname, ttl)
count = await self.rdb.llen(qname) count = await self.client.llen(qname)
return count return count
async def get_int_list(self, qname: str): async def get_int_list(self, qname: str):
return [int(str_data) for str_data in await self.rdb.lrange(qname, 0, -1)] return [int(str_data) for str_data in await self.client.lrange(qname, 0, -1)]
async def set_int_list(self, qname: str, str_list: List[int], ttl: int = 60): async def set_int_list(self, qname: str, str_list: List[int], ttl: int = 60):
await self.rdb.ltrim(qname, 1, 0) await self.client.ltrim(qname, 1, 0)
await self.rdb.lpush(qname, *str_list) await self.client.lpush(qname, *str_list)
if ttl != -1: if ttl != -1:
await self.rdb.expire(qname, ttl) await self.client.expire(qname, ttl)
count = await self.rdb.llen(qname) count = await self.client.llen(qname)
return count return count

View File

@ -1,78 +1,14 @@
import asyncio
from typing import List from typing import List
import aiomysql
from logger import Log
from model.base import ServiceEnum from model.base import ServiceEnum
from service.base import CreateUserInfoDBDataFromSQLData, UserInfoData, CreatCookieDictFromSQLData, \ from service.base import CreateUserInfoDBDataFromSQLData, UserInfoData, CreatCookieDictFromSQLData, \
CreatQuestionFromSQLData, QuestionData, AnswerData, CreatAnswerFromSQLData CreatQuestionFromSQLData, QuestionData, AnswerData, CreatAnswerFromSQLData
from utils.mysql import MySQL
class AsyncRepository: class AsyncRepository:
def __init__(self, mysql_host: str = "127.0.0.1", mysql_port: int = 3306, mysql_user: str = "root", def __init__(self, mysql: MySQL):
mysql_password: str = "", mysql_database: str = "", loop=None): self.mysql = mysql
self._mysql_database = mysql_database
self._mysql_password = mysql_password
self._mysql_user = mysql_user
self._mysql_port = mysql_port
self._mysql_host = mysql_host
self._loop = loop
self._sql_pool = None
Log.debug(f'获取数据库配置 [host]: {self._mysql_host}')
Log.debug(f'获取数据库配置 [port]: {self._mysql_port}')
Log.debug(f'获取数据库配置 [user]: {self._mysql_user}')
Log.debug(f'获取数据库配置 [password][len]: {len(self._mysql_password)}')
Log.debug(f'获取数据库配置 [db]: {self._mysql_database}')
if self._loop is None:
self._loop = asyncio.get_event_loop()
try:
Log.info("正在创建数据库LOOP")
self._loop.run_until_complete(self.create_pool())
Log.info("创建数据库LOOP成功")
except (KeyboardInterrupt, SystemExit):
pass
except Exception as exc:
Log.error("创建数据库LOOP发生严重错误")
raise exc
async def create_pool(self):
self._sql_pool = await aiomysql.create_pool(
host=self._mysql_host, port=self._mysql_port,
user=self._mysql_user, password=self._mysql_password,
db=self._mysql_database, loop=self._loop)
async def wait_closed(self):
if self._sql_pool is None:
return
pool = self._sql_pool
pool.close()
await pool.wait_closed()
async def _get_pool(self):
if self._sql_pool is None:
raise RuntimeError("mysql pool is none")
return self._sql_pool
async def _executemany(self, query, query_args):
pool = await self._get_pool()
async with pool.acquire() as conn:
sql_cur = await conn.cursor()
await sql_cur.executemany(query, query_args)
rowcount = sql_cur.rowcount
await sql_cur.close()
await conn.commit()
return rowcount
async def _execute_and_fetchall(self, query, query_args):
pool = await self._get_pool()
async with pool.acquire() as conn:
sql_cur = await conn.cursor()
await sql_cur.execute(query, query_args)
result = await sql_cur.fetchall()
await sql_cur.close()
await conn.commit()
return result
async def update_cookie(self, user_id: int, cookie: str, service: ServiceEnum): async def update_cookie(self, user_id: int, cookie: str, service: ServiceEnum):
if service == ServiceEnum.MIHOYOBBS: if service == ServiceEnum.MIHOYOBBS:
@ -90,7 +26,7 @@ class AsyncRepository:
else: else:
query = "" query = ""
query_args = (cookie, user_id) query_args = (cookie, user_id)
await self._execute_and_fetchall(query, query_args) await self.mysql.execute_and_fetchall(query, query_args)
async def set_cookie(self, user_id: int, cookie: str, service: ServiceEnum): async def set_cookie(self, user_id: int, cookie: str, service: ServiceEnum):
if service == ServiceEnum.MIHOYOBBS: if service == ServiceEnum.MIHOYOBBS:
@ -114,7 +50,7 @@ class AsyncRepository:
else: else:
raise ValueError() raise ValueError()
query_args = (user_id, cookie) query_args = (user_id, cookie)
await self._execute_and_fetchall(query, query_args) await self.mysql.execute_and_fetchall(query, query_args)
async def set_user_info(self, user_id: int, mihoyo_game_uid: int, hoyoverse_game_uid: int, service: int): async def set_user_info(self, user_id: int, mihoyo_game_uid: int, hoyoverse_game_uid: int, service: int):
query = """ query = """
@ -128,7 +64,7 @@ class AsyncRepository:
service=VALUES(service); service=VALUES(service);
""" """
query_args = (user_id, mihoyo_game_uid, hoyoverse_game_uid, service) query_args = (user_id, mihoyo_game_uid, hoyoverse_game_uid, service)
await self._execute_and_fetchall(query, query_args) await self.mysql.execute_and_fetchall(query, query_args)
async def read_mihoyo_cookie(self, user_id) -> dict: async def read_mihoyo_cookie(self, user_id) -> dict:
query = """ query = """
@ -137,7 +73,7 @@ class AsyncRepository:
WHERE user_id=%s; WHERE user_id=%s;
""" """
query_args = (user_id,) query_args = (user_id,)
data = await self._execute_and_fetchall(query, query_args) data = await self.mysql.execute_and_fetchall(query, query_args)
if len(data) == 0: if len(data) == 0:
return {} return {}
return CreatCookieDictFromSQLData(data[0]) return CreatCookieDictFromSQLData(data[0])
@ -149,7 +85,7 @@ class AsyncRepository:
WHERE user_id=%s; WHERE user_id=%s;
""" """
query_args = (user_id,) query_args = (user_id,)
data = await self._execute_and_fetchall(query, query_args) data = await self.mysql.execute_and_fetchall(query, query_args)
if len(data) == 0: if len(data) == 0:
return {} return {}
return CreatCookieDictFromSQLData(data[0]) return CreatCookieDictFromSQLData(data[0])
@ -161,7 +97,7 @@ class AsyncRepository:
WHERE user_id=%s; WHERE user_id=%s;
""" """
query_args = (user_id,) query_args = (user_id,)
data = await self._execute_and_fetchall(query, query_args) data = await self.mysql.execute_and_fetchall(query, query_args)
if len(data) == 0: if len(data) == 0:
return UserInfoData() return UserInfoData()
return CreateUserInfoDBDataFromSQLData(data[0]) return CreateUserInfoDBDataFromSQLData(data[0])
@ -172,7 +108,7 @@ class AsyncRepository:
FROM `question` FROM `question`
""" """
query_args = () query_args = ()
data = await self._execute_and_fetchall(query, query_args) data = await self.mysql.execute_and_fetchall(query, query_args)
return CreatQuestionFromSQLData(data) return CreatQuestionFromSQLData(data)
async def read_answer_form_question_id(self, question_id: int) -> List[AnswerData]: async def read_answer_form_question_id(self, question_id: int) -> List[AnswerData]:
@ -182,7 +118,7 @@ class AsyncRepository:
WHERE question_id=%s; WHERE question_id=%s;
""" """
query_args = (question_id,) query_args = (question_id,)
data = await self._execute_and_fetchall(query, query_args) data = await self.mysql.execute_and_fetchall(query, query_args)
return CreatAnswerFromSQLData(data) return CreatAnswerFromSQLData(data)
async def save_question(self, question: str): async def save_question(self, question: str):
@ -193,7 +129,7 @@ class AsyncRepository:
(%s) (%s)
""" """
query_args = (question,) query_args = (question,)
await self._execute_and_fetchall(query, query_args) await self.mysql.execute_and_fetchall(query, query_args)
async def read_question(self, question: str) -> QuestionData: async def read_question(self, question: str) -> QuestionData:
query = """ query = """
@ -202,7 +138,7 @@ class AsyncRepository:
WHERE question=%s; WHERE question=%s;
""" """
query_args = (question,) query_args = (question,)
data = await self._execute_and_fetchall(query, query_args) data = await self.mysql.execute_and_fetchall(query, query_args)
return CreatQuestionFromSQLData(data)[0] return CreatQuestionFromSQLData(data)[0]
async def save_answer(self, question_id: int, is_correct: int, answer: str): async def save_answer(self, question_id: int, is_correct: int, answer: str):
@ -213,7 +149,7 @@ class AsyncRepository:
(%s,%s,%s) (%s,%s,%s)
""" """
query_args = (question_id, is_correct, answer) query_args = (question_id, is_correct, answer)
await self._execute_and_fetchall(query, query_args) await self.mysql.execute_and_fetchall(query, query_args)
async def delete_question(self, question_id: int): async def delete_question(self, question_id: int):
query = """ query = """
@ -221,7 +157,7 @@ class AsyncRepository:
WHERE id=%s; WHERE id=%s;
""" """
query_args = (question_id,) query_args = (question_id,)
await self._execute_and_fetchall(query, query_args) await self.mysql.execute_and_fetchall(query, query_args)
async def delete_answer(self, answer_id: int): async def delete_answer(self, answer_id: int):
query = """ query = """
@ -229,7 +165,7 @@ class AsyncRepository:
WHERE id=%s; WHERE id=%s;
""" """
query_args = (answer_id,) query_args = (answer_id,)
await self._execute_and_fetchall(query, query_args) await self.mysql.execute_and_fetchall(query, query_args)
async def delete_admin(self, user_id: int): async def delete_admin(self, user_id: int):
query = """ query = """
@ -237,7 +173,7 @@ class AsyncRepository:
WHERE user_id=%s; WHERE user_id=%s;
""" """
query_args = (user_id,) query_args = (user_id,)
await self._execute_and_fetchall(query, query_args) await self.mysql.execute_and_fetchall(query, query_args)
async def add_admin(self, user_id: int): async def add_admin(self, user_id: int):
query = """ query = """
@ -247,7 +183,7 @@ class AsyncRepository:
(%s) (%s)
""" """
query_args = (user_id,) query_args = (user_id,)
await self._execute_and_fetchall(query, query_args) await self.mysql.execute_and_fetchall(query, query_args)
async def get_admin(self) -> List[int]: async def get_admin(self) -> List[int]:
query = """ query = """
@ -255,7 +191,7 @@ class AsyncRepository:
FROM `admin` FROM `admin`
""" """
query_args = () query_args = ()
data = await self._execute_and_fetchall(query, query_args) data = await self.mysql.execute_and_fetchall(query, query_args)
if len(data) == 0: if len(data) == 0:
return [] return []
return list(data[0]) return list(data[0])

View File

@ -1,64 +1,24 @@
import asyncio
import os import os
import time import time
from typing import Optional
from jinja2 import PackageLoader, Environment, Template from jinja2 import PackageLoader, Environment, Template
from playwright.async_api import async_playwright, Browser, ViewportSize, Playwright from playwright.async_api import ViewportSize
from config import config from config import config
from logger import Log from logger import Log
from utils.aiobrowser import AioBrowser
class TemplateService: class TemplateService:
def __init__(self, template_package_name: str = "resources", cache_dir_name: str = "cache", loop=None): def __init__(self, browser: AioBrowser, template_package_name: str = "resources", cache_dir_name: str = "cache"):
self.browser = browser
self._template_package_name = template_package_name self._template_package_name = template_package_name
self._browser: Optional[Browser] = None
self._playwright: Optional[Playwright] = None
self._current_dir = os.getcwd() self._current_dir = os.getcwd()
self._output_dir = os.path.join(self._current_dir, cache_dir_name) self._output_dir = os.path.join(self._current_dir, cache_dir_name)
if not os.path.exists(self._output_dir): if not os.path.exists(self._output_dir):
os.mkdir(self._output_dir) os.mkdir(self._output_dir)
self._jinja2_env = {} self._jinja2_env = {}
self._jinja2_template = {} self._jinja2_template = {}
self._loop = loop
if self._loop is None:
self._loop = asyncio.get_event_loop()
try:
Log.info("正在尝试启动Playwright")
self._loop.run_until_complete(self._browser_init())
Log.info("启动Playwright成功")
except (KeyboardInterrupt, SystemExit):
pass
except Exception as exc:
Log.error("启动浏览器失败 \n")
raise exc
async def _browser_init(self) -> Browser:
if self._playwright is None:
self._playwright = await async_playwright().start()
try:
self._browser = await self._playwright.chromium.launch(timeout=5000)
except TimeoutError as err:
raise err
else:
if self._browser is None:
try:
self._browser = await self._playwright.chromium.launch(timeout=10000)
except TimeoutError as err:
raise err
return self._browser
async def close(self):
if self._browser is not None:
await self._browser.close()
if self._playwright is not None:
await self._playwright.stop()
async def get_browser(self) -> Browser:
if self._browser is None:
return await self._browser_init()
return self._browser
def get_template(self, package_path: str, template_name: str) -> Template: def get_template(self, package_path: str, template_name: str) -> Template:
if config.DEBUG: if config.DEBUG:

51
utils/aiobrowser.py Normal file
View File

@ -0,0 +1,51 @@
import asyncio
import os
from typing import Optional
from playwright.async_api import async_playwright, Browser, Playwright
from logger import Log
class AioBrowser:
def __init__(self, loop=None):
self.browser: Optional[Browser] = None
self._playwright: Optional[Playwright] = None
self._loop = loop
if self._loop is None:
self._loop = asyncio.get_event_loop()
try:
Log.info("正在尝试启动Playwright")
self._loop.run_until_complete(self._browser_init())
Log.info("启动Playwright成功")
except (KeyboardInterrupt, SystemExit):
pass
except Exception as exc:
Log.error("启动浏览器失败 \n")
raise exc
async def _browser_init(self) -> Browser:
if self._playwright is None:
self._playwright = await async_playwright().start()
try:
self.browser = await self._playwright.chromium.launch(timeout=5000)
except TimeoutError as err:
raise err
else:
if self.browser is None:
try:
self.browser = await self._playwright.chromium.launch(timeout=10000)
except TimeoutError as err:
raise err
return self.browser
async def close(self):
if self.browser is not None:
await self.browser.close()
if self._playwright is not None:
await self._playwright.stop()
async def get_browser(self) -> Browser:
if self.browser is None:
raise RuntimeError("browser is not None")
return self.browser

71
utils/mysql.py Normal file
View File

@ -0,0 +1,71 @@
import asyncio
import aiomysql
from logger import Log
class MySQL:
def __init__(self, host: str = "127.0.0.1", port: int = 3306, user: str = "root",
password: str = "", database: str = "", loop=None):
self.database = database
self.password = password
self.user = user
self.port = port
self.host = host
self._loop = loop
self._sql_pool = None
Log.debug(f'获取数据库配置 [host]: {self.host}')
Log.debug(f'获取数据库配置 [port]: {self.port}')
Log.debug(f'获取数据库配置 [user]: {self.user}')
Log.debug(f'获取数据库配置 [password][len]: {len(self.password)}')
Log.debug(f'获取数据库配置 [db]: {self.database}')
if self._loop is None:
self._loop = asyncio.get_event_loop()
try:
Log.info("正在创建数据库LOOP")
self._loop.run_until_complete(self.create_pool())
Log.info("创建数据库LOOP成功")
except (KeyboardInterrupt, SystemExit):
pass
except Exception as exc:
Log.error("创建数据库LOOP发生严重错误")
raise exc
async def wait_closed(self):
if self._sql_pool is None:
return
pool = self._sql_pool
pool.close()
await pool.wait_closed()
async def create_pool(self):
self._sql_pool = await aiomysql.create_pool(
host=self.host, port=self.port,
user=self.user, password=self.password,
db=self.database, loop=self._loop)
async def _get_pool(self):
if self._sql_pool is None:
raise RuntimeError("mysql pool is none")
return self._sql_pool
async def executemany(self, query, query_args):
pool = await self._get_pool()
async with pool.acquire() as conn:
sql_cur = await conn.cursor()
await sql_cur.executemany(query, query_args)
rowcount = sql_cur.rowcount
await sql_cur.close()
await conn.commit()
return rowcount
async def execute_and_fetchall(self, query, query_args):
pool = await self._get_pool()
async with pool.acquire() as conn:
sql_cur = await conn.cursor()
await sql_cur.execute(query, query_args)
result = await sql_cur.fetchall()
await sql_cur.close()
await conn.commit()
return result

36
utils/redisdb.py Normal file
View File

@ -0,0 +1,36 @@
import asyncio
from logger import Log
from redis import asyncio as aioredis
class RedisDB:
def __init__(self, host="127.0.0.1", port=6379, db=0, loop=None):
Log.debug(f'获取Redis配置 [host]: {host}')
Log.debug(f'获取Redis配置 [host]: {port}')
Log.debug(f'获取Redis配置 [host]: {db}')
self.client = aioredis.Redis(host=host, port=port, db=db)
self.ttl = 600
self.key_prefix = "paimon_bot"
self._loop = loop
if self._loop is None:
self._loop = asyncio.get_event_loop()
try:
Log.info("正在尝试建立与Redis连接")
self._loop.run_until_complete(self.ping())
except (KeyboardInterrupt, SystemExit):
pass
except Exception as exc:
Log.error("尝试连接Redis失败 \n")
raise exc
async def ping(self):
if await self.client.ping():
Log.info("连接Redis成功")
else:
Log.info("连接Redis失败")
raise RuntimeError("连接Redis失败")
async def close(self):
await self.client.close()