mirror of
https://github.com/PaiGramTeam/PaiGram.git
synced 2024-11-16 04:35:49 +00:00
添加 utils
模块
添加 utils 模块 移除 job_queue 模块 修改获取配置文件信息函数 重构部分类以及函数
This commit is contained in:
parent
db96c70180
commit
63d71262fe
@ -18,6 +18,7 @@ class Config:
|
||||
self.DEBUG = False
|
||||
self.ADMINISTRATORS = self.get_config("administrators")
|
||||
self.MYSQL = self.get_config("mysql")
|
||||
self.REDIS = self.get_config("redis")
|
||||
self.TELEGRAM = self.get_config("telegram")
|
||||
self.FUNCTION = self.get_config("function")
|
||||
|
||||
|
@ -11,7 +11,6 @@ from plugins.errorhandler import error_handler
|
||||
from plugins.gacha import Gacha
|
||||
from plugins.help import Help
|
||||
from plugins.inline import Inline
|
||||
from plugins.job_queue import JobQueue
|
||||
from plugins.post import Post
|
||||
from plugins.quiz import Quiz
|
||||
from plugins.sign import Sign
|
||||
@ -77,7 +76,5 @@ def register_handlers(application, service: BaseService = None):
|
||||
application.add_handler(post_handler)
|
||||
inline = Inline(service)
|
||||
application.add_handler(InlineQueryHandler(inline.inline_query, block=False))
|
||||
job_queue = JobQueue(service)
|
||||
application.job_queue.run_once(job_queue.start_job, when=3, name="start_job")
|
||||
application.add_handler(MessageHandler(filters.COMMAND & filters.ChatType.PRIVATE, unknown_command))
|
||||
application.add_error_handler(error_handler, block=False)
|
||||
|
27
main.py
27
main.py
@ -8,8 +8,9 @@ from config import config
|
||||
from handler import register_handlers
|
||||
from logger import Log
|
||||
from service import StartService
|
||||
from service.cache import RedisCache
|
||||
from service.repository import AsyncRepository
|
||||
from utils.aiobrowser import AioBrowser
|
||||
from utils.mysql import MySQL
|
||||
from utils.redisdb import RedisDB
|
||||
|
||||
# 无视相关警告
|
||||
# 该警告说明在官方GITHUB的WIKI中Frequently Asked Questions里的What do the per_* settings in ConversationHandler do?
|
||||
@ -21,20 +22,20 @@ def main() -> None:
|
||||
|
||||
# 初始化数据库
|
||||
Log.info("初始化数据库")
|
||||
repository = AsyncRepository(mysql_host=config.MYSQL["host"],
|
||||
mysql_user=config.MYSQL["user"],
|
||||
mysql_password=config.MYSQL["password"],
|
||||
mysql_port=config.MYSQL["port"],
|
||||
mysql_database=config.MYSQL["database"]
|
||||
)
|
||||
mysql = MySQL(host=config.MYSQL["host"], user=config.MYSQL["user"], password=config.MYSQL["password"],
|
||||
port=config.MYSQL["port"], database=config.MYSQL["database"])
|
||||
|
||||
# 初始化Redis缓存
|
||||
Log.info("初始化Redis缓存")
|
||||
cache = RedisCache(db=6)
|
||||
redis = RedisDB(host=config.REDIS["host"], port=config.REDIS["port"], db=config.REDIS["database"])
|
||||
|
||||
# 初始化Playwright
|
||||
Log.info("初始化Playwright")
|
||||
browser = AioBrowser()
|
||||
|
||||
# 传入服务并启动
|
||||
Log.info("传入服务并启动")
|
||||
service = StartService(repository, cache)
|
||||
service = StartService(mysql, redis, browser)
|
||||
|
||||
# 构建BOT
|
||||
application = Application.builder().token(config.TELEGRAM["token"]).build()
|
||||
@ -58,13 +59,13 @@ def main() -> None:
|
||||
try:
|
||||
# 需要关闭数据库连接
|
||||
Log.info("正在关闭数据库连接")
|
||||
loop.run_until_complete(repository.wait_closed())
|
||||
loop.run_until_complete(mysql.wait_closed())
|
||||
# 关闭Redis连接
|
||||
Log.info("正在关闭Redis连接")
|
||||
loop.run_until_complete(cache.close())
|
||||
loop.run_until_complete(redis.close())
|
||||
# 关闭playwright
|
||||
Log.info("正在关闭Playwright")
|
||||
loop.run_until_complete(service.template.close())
|
||||
loop.run_until_complete(browser.close())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
pass
|
||||
except Exception as exc:
|
||||
|
@ -1,31 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from telegram.ext import CallbackContext
|
||||
|
||||
from logger import Log
|
||||
from plugins.base import BasePlugins
|
||||
from service import BaseService
|
||||
|
||||
|
||||
class JobQueue(BasePlugins):
|
||||
|
||||
def __init__(self, service: BaseService):
|
||||
super().__init__(service)
|
||||
self.new_post_id_list_cache: List[int] = []
|
||||
|
||||
async def start_job(self, _: CallbackContext) -> None:
|
||||
Log.info("正在检查化必要模块是否正常工作")
|
||||
Log.info("正在检查Playwright")
|
||||
try:
|
||||
# 尝试获取 browser 如果获取失败尝试初始化
|
||||
await self.service.template.get_browser()
|
||||
except TimeoutError as err:
|
||||
Log.error("初始化Playwright超时,请检查日记查看错误 \n", err)
|
||||
except AttributeError as err:
|
||||
Log.error("初始化Playwright时变量为空,请检查日记查看错误 \n", err)
|
||||
else:
|
||||
Log.info("检查Playwright成功")
|
||||
Log.info("检查完成")
|
||||
|
||||
async def check_cookie(self, _: CallbackContext):
|
||||
pass
|
@ -1,6 +1,7 @@
|
||||
# service 目录说明
|
||||
|
||||
## 文件说明
|
||||
|
||||
| FileName | Contribution |
|
||||
|:----------:|--------------|
|
||||
| init | 服务初始化 |
|
||||
|
@ -6,20 +6,23 @@ from service.quiz import QuizService
|
||||
from service.repository import AsyncRepository
|
||||
from service.template import TemplateService
|
||||
from service.user import UserInfoFormDB
|
||||
from utils.aiobrowser import AioBrowser
|
||||
from utils.mysql import MySQL
|
||||
from utils.redisdb import RedisDB
|
||||
|
||||
|
||||
class BaseService:
|
||||
def __init__(self, async_repository: AsyncRepository, async_cache: RedisCache):
|
||||
self.repository = async_repository
|
||||
self.cache = async_cache
|
||||
def __init__(self, mysql: MySQL, redis: RedisDB, browser: AioBrowser):
|
||||
self.repository = AsyncRepository(mysql)
|
||||
self.cache = RedisCache(redis)
|
||||
self.user_service_db = UserInfoFormDB(self.repository)
|
||||
self.quiz_service = QuizService(self.repository, self.cache)
|
||||
self.get_game_info = GetGameInfo(self.repository, self.cache)
|
||||
self.gacha = GachaService(self.repository, self.cache)
|
||||
self.admin = AdminService(self.repository, self.cache)
|
||||
self.template = TemplateService()
|
||||
self.template = TemplateService(browser)
|
||||
|
||||
|
||||
class StartService(BaseService):
|
||||
def __init__(self, async_repository: AsyncRepository, async_cache: RedisCache):
|
||||
super().__init__(async_repository, async_cache)
|
||||
def __init__(self, mysql: MySQL, redis: RedisDB, browser: AioBrowser):
|
||||
super().__init__(mysql, redis, browser)
|
||||
|
@ -1,74 +1,45 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
import ujson
|
||||
from redis import asyncio as aioredis
|
||||
|
||||
from logger import Log
|
||||
from service.base import QuestionData, AnswerData
|
||||
from utils.redisdb import RedisDB
|
||||
|
||||
|
||||
class RedisCache:
|
||||
|
||||
def __init__(self, host="127.0.0.1", port=6379, db=0, loop=None):
|
||||
self._loop = asyncio.get_event_loop()
|
||||
# Redis 官方文档显示 默认创建POOL连接池
|
||||
Log.debug(f'获取Redis配置 [host]: {host}')
|
||||
Log.debug(f'获取Redis配置 [host]: {port}')
|
||||
Log.debug(f'获取Redis配置 [host]: {db}')
|
||||
self.rdb = aioredis.Redis(host=host, port=port, db=db)
|
||||
self.ttl = 600
|
||||
self.key_prefix = "paimon_bot"
|
||||
self._loop = loop
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
try:
|
||||
Log.info("正在尝试建立与Redis连接")
|
||||
self._loop.run_until_complete(self.ping())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
pass
|
||||
except Exception as exc:
|
||||
Log.error("尝试连接Redis失败 \n")
|
||||
raise exc
|
||||
|
||||
async def ping(self):
|
||||
if await self.rdb.ping():
|
||||
Log.info("连接Redis成功")
|
||||
else:
|
||||
Log.info("连接Redis失败")
|
||||
raise RuntimeError("连接Redis失败")
|
||||
|
||||
async def close(self):
|
||||
await self.rdb.close()
|
||||
def __init__(self, redis: RedisDB):
|
||||
self.client = redis.client
|
||||
|
||||
async def get_chat_admin(self, char_id: int):
|
||||
qname = f"group:admin_list:{char_id}"
|
||||
return [int(str_id) for str_id in await self.rdb.lrange(qname, 0, -1)]
|
||||
return [int(str_id) for str_id in await self.client.lrange(qname, 0, -1)]
|
||||
|
||||
async def set_chat_admin(self, char_id: int, admin_list: List[int]):
|
||||
qname = f"group:admin_list:{char_id}"
|
||||
await self.rdb.ltrim(qname, 1, 0)
|
||||
await self.rdb.lpush(qname, *admin_list)
|
||||
await self.rdb.expire(qname, 60)
|
||||
count = await self.rdb.llen(qname)
|
||||
await self.client.ltrim(qname, 1, 0)
|
||||
await self.client.lpush(qname, *admin_list)
|
||||
await self.client.expire(qname, 60)
|
||||
count = await self.client.llen(qname)
|
||||
return count
|
||||
|
||||
async def get_all_question(self) -> List[str]:
|
||||
qname = "quiz:question"
|
||||
data_list = [qname + f":{question_id}" for question_id in await self.rdb.lrange(qname + "id_list", 0, -1)]
|
||||
return await self.rdb.mget(data_list)
|
||||
data_list = [qname + f":{question_id}" for question_id in
|
||||
await self.client.lrange(qname + "id_list", 0, -1)]
|
||||
return await self.client.mget(data_list)
|
||||
|
||||
async def get_all_question_id_list(self) -> List[str]:
|
||||
qname = "quiz:question:id_list"
|
||||
return await self.rdb.lrange(qname, 0, -1)
|
||||
return await self.client.lrange(qname, 0, -1)
|
||||
|
||||
async def get_one_question(self, question_id: int) -> str:
|
||||
qname = f"quiz:question:{question_id}"
|
||||
return await self.rdb.get(qname)
|
||||
return await self.client.get(qname)
|
||||
|
||||
async def get_one_answer(self, answer_id: int) -> str:
|
||||
qname = f"quiz:answer:{answer_id}"
|
||||
return await self.rdb.get(qname)
|
||||
return await self.client.get(qname)
|
||||
|
||||
async def set_question(self, question_list: List[QuestionData] = None):
|
||||
qname = "quiz:question"
|
||||
@ -82,25 +53,25 @@ class RedisCache:
|
||||
return ujson.dumps(data)
|
||||
|
||||
for question in question_list:
|
||||
await self.rdb.set(qname + f":{question.question_id}", json_dumps(question))
|
||||
await self.client.set(qname + f":{question.question_id}", json_dumps(question))
|
||||
|
||||
question_id_list = [question.question_id for question in question_list]
|
||||
await self.rdb.lpush(qname + ":id_list", *question_id_list)
|
||||
return await self.rdb.llen(qname + ":id_list")
|
||||
await self.client.lpush(qname + ":id_list", *question_id_list)
|
||||
return await self.client.llen(qname + ":id_list")
|
||||
|
||||
async def del_all_question(self, answer_list: List[AnswerData] = None):
|
||||
qname = "quiz:question"
|
||||
keys = await self.rdb.keys(qname + "*")
|
||||
keys = await self.client.keys(qname + "*")
|
||||
if keys is not None:
|
||||
for key in keys:
|
||||
await self.rdb.delete(key)
|
||||
await self.client.delete(key)
|
||||
|
||||
async def del_all_answer(self, answer_list: List[AnswerData] = None):
|
||||
qname = "quiz:answer"
|
||||
keys = await self.rdb.keys(qname + "*")
|
||||
keys = await self.client.keys(qname + "*")
|
||||
if keys is not None:
|
||||
for key in keys:
|
||||
await self.rdb.delete(key)
|
||||
await self.client.delete(key)
|
||||
|
||||
async def set_answer(self, answer_list: List[AnswerData] = None):
|
||||
qname = "quiz:answer"
|
||||
@ -109,30 +80,30 @@ class RedisCache:
|
||||
return ujson.dumps(obj=_answer.__dict__)
|
||||
|
||||
for answer in answer_list:
|
||||
await self.rdb.set(qname + f":{answer.answer_id}", json_dumps(answer))
|
||||
await self.client.set(qname + f":{answer.answer_id}", json_dumps(answer))
|
||||
|
||||
answer_id_list = [answer.answer_id for answer in answer_list]
|
||||
await self.rdb.lpush(qname + ":id_list", *answer_id_list)
|
||||
return await self.rdb.llen(qname + ":id_list")
|
||||
await self.client.lpush(qname + ":id_list", *answer_id_list)
|
||||
return await self.client.llen(qname + ":id_list")
|
||||
|
||||
async def get_str_list(self, qname: str):
|
||||
return [str(str_data, encoding="utf-8") for str_data in await self.rdb.lrange(qname, 0, -1)]
|
||||
return [str(str_data, encoding="utf-8") for str_data in await self.client.lrange(qname, 0, -1)]
|
||||
|
||||
async def set_str_list(self, qname: str, str_list: List[str], ttl: int = 60):
|
||||
await self.rdb.ltrim(qname, 1, 0)
|
||||
await self.rdb.lpush(qname, *str_list)
|
||||
await self.client.ltrim(qname, 1, 0)
|
||||
await self.client.lpush(qname, *str_list)
|
||||
if ttl != -1:
|
||||
await self.rdb.expire(qname, ttl)
|
||||
count = await self.rdb.llen(qname)
|
||||
await self.client.expire(qname, ttl)
|
||||
count = await self.client.llen(qname)
|
||||
return count
|
||||
|
||||
async def get_int_list(self, qname: str):
|
||||
return [int(str_data) for str_data in await self.rdb.lrange(qname, 0, -1)]
|
||||
return [int(str_data) for str_data in await self.client.lrange(qname, 0, -1)]
|
||||
|
||||
async def set_int_list(self, qname: str, str_list: List[int], ttl: int = 60):
|
||||
await self.rdb.ltrim(qname, 1, 0)
|
||||
await self.rdb.lpush(qname, *str_list)
|
||||
await self.client.ltrim(qname, 1, 0)
|
||||
await self.client.lpush(qname, *str_list)
|
||||
if ttl != -1:
|
||||
await self.rdb.expire(qname, ttl)
|
||||
count = await self.rdb.llen(qname)
|
||||
await self.client.expire(qname, ttl)
|
||||
count = await self.client.llen(qname)
|
||||
return count
|
||||
|
@ -1,78 +1,14 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
import aiomysql
|
||||
|
||||
from logger import Log
|
||||
from model.base import ServiceEnum
|
||||
from service.base import CreateUserInfoDBDataFromSQLData, UserInfoData, CreatCookieDictFromSQLData, \
|
||||
CreatQuestionFromSQLData, QuestionData, AnswerData, CreatAnswerFromSQLData
|
||||
from utils.mysql import MySQL
|
||||
|
||||
|
||||
class AsyncRepository:
|
||||
def __init__(self, mysql_host: str = "127.0.0.1", mysql_port: int = 3306, mysql_user: str = "root",
|
||||
mysql_password: str = "", mysql_database: str = "", loop=None):
|
||||
self._mysql_database = mysql_database
|
||||
self._mysql_password = mysql_password
|
||||
self._mysql_user = mysql_user
|
||||
self._mysql_port = mysql_port
|
||||
self._mysql_host = mysql_host
|
||||
self._loop = loop
|
||||
self._sql_pool = None
|
||||
Log.debug(f'获取数据库配置 [host]: {self._mysql_host}')
|
||||
Log.debug(f'获取数据库配置 [port]: {self._mysql_port}')
|
||||
Log.debug(f'获取数据库配置 [user]: {self._mysql_user}')
|
||||
Log.debug(f'获取数据库配置 [password][len]: {len(self._mysql_password)}')
|
||||
Log.debug(f'获取数据库配置 [db]: {self._mysql_database}')
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
try:
|
||||
Log.info("正在创建数据库LOOP")
|
||||
self._loop.run_until_complete(self.create_pool())
|
||||
Log.info("创建数据库LOOP成功")
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
pass
|
||||
except Exception as exc:
|
||||
Log.error("创建数据库LOOP发生严重错误")
|
||||
raise exc
|
||||
|
||||
async def create_pool(self):
|
||||
self._sql_pool = await aiomysql.create_pool(
|
||||
host=self._mysql_host, port=self._mysql_port,
|
||||
user=self._mysql_user, password=self._mysql_password,
|
||||
db=self._mysql_database, loop=self._loop)
|
||||
|
||||
async def wait_closed(self):
|
||||
if self._sql_pool is None:
|
||||
return
|
||||
pool = self._sql_pool
|
||||
pool.close()
|
||||
await pool.wait_closed()
|
||||
|
||||
async def _get_pool(self):
|
||||
if self._sql_pool is None:
|
||||
raise RuntimeError("mysql pool is none")
|
||||
return self._sql_pool
|
||||
|
||||
async def _executemany(self, query, query_args):
|
||||
pool = await self._get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
sql_cur = await conn.cursor()
|
||||
await sql_cur.executemany(query, query_args)
|
||||
rowcount = sql_cur.rowcount
|
||||
await sql_cur.close()
|
||||
await conn.commit()
|
||||
return rowcount
|
||||
|
||||
async def _execute_and_fetchall(self, query, query_args):
|
||||
pool = await self._get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
sql_cur = await conn.cursor()
|
||||
await sql_cur.execute(query, query_args)
|
||||
result = await sql_cur.fetchall()
|
||||
await sql_cur.close()
|
||||
await conn.commit()
|
||||
return result
|
||||
def __init__(self, mysql: MySQL):
|
||||
self.mysql = mysql
|
||||
|
||||
async def update_cookie(self, user_id: int, cookie: str, service: ServiceEnum):
|
||||
if service == ServiceEnum.MIHOYOBBS:
|
||||
@ -90,7 +26,7 @@ class AsyncRepository:
|
||||
else:
|
||||
query = ""
|
||||
query_args = (cookie, user_id)
|
||||
await self._execute_and_fetchall(query, query_args)
|
||||
await self.mysql.execute_and_fetchall(query, query_args)
|
||||
|
||||
async def set_cookie(self, user_id: int, cookie: str, service: ServiceEnum):
|
||||
if service == ServiceEnum.MIHOYOBBS:
|
||||
@ -114,7 +50,7 @@ class AsyncRepository:
|
||||
else:
|
||||
raise ValueError()
|
||||
query_args = (user_id, cookie)
|
||||
await self._execute_and_fetchall(query, query_args)
|
||||
await self.mysql.execute_and_fetchall(query, query_args)
|
||||
|
||||
async def set_user_info(self, user_id: int, mihoyo_game_uid: int, hoyoverse_game_uid: int, service: int):
|
||||
query = """
|
||||
@ -128,7 +64,7 @@ class AsyncRepository:
|
||||
service=VALUES(service);
|
||||
"""
|
||||
query_args = (user_id, mihoyo_game_uid, hoyoverse_game_uid, service)
|
||||
await self._execute_and_fetchall(query, query_args)
|
||||
await self.mysql.execute_and_fetchall(query, query_args)
|
||||
|
||||
async def read_mihoyo_cookie(self, user_id) -> dict:
|
||||
query = """
|
||||
@ -137,7 +73,7 @@ class AsyncRepository:
|
||||
WHERE user_id=%s;
|
||||
"""
|
||||
query_args = (user_id,)
|
||||
data = await self._execute_and_fetchall(query, query_args)
|
||||
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||
if len(data) == 0:
|
||||
return {}
|
||||
return CreatCookieDictFromSQLData(data[0])
|
||||
@ -149,7 +85,7 @@ class AsyncRepository:
|
||||
WHERE user_id=%s;
|
||||
"""
|
||||
query_args = (user_id,)
|
||||
data = await self._execute_and_fetchall(query, query_args)
|
||||
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||
if len(data) == 0:
|
||||
return {}
|
||||
return CreatCookieDictFromSQLData(data[0])
|
||||
@ -161,7 +97,7 @@ class AsyncRepository:
|
||||
WHERE user_id=%s;
|
||||
"""
|
||||
query_args = (user_id,)
|
||||
data = await self._execute_and_fetchall(query, query_args)
|
||||
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||
if len(data) == 0:
|
||||
return UserInfoData()
|
||||
return CreateUserInfoDBDataFromSQLData(data[0])
|
||||
@ -172,7 +108,7 @@ class AsyncRepository:
|
||||
FROM `question`
|
||||
"""
|
||||
query_args = ()
|
||||
data = await self._execute_and_fetchall(query, query_args)
|
||||
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||
return CreatQuestionFromSQLData(data)
|
||||
|
||||
async def read_answer_form_question_id(self, question_id: int) -> List[AnswerData]:
|
||||
@ -182,7 +118,7 @@ class AsyncRepository:
|
||||
WHERE question_id=%s;
|
||||
"""
|
||||
query_args = (question_id,)
|
||||
data = await self._execute_and_fetchall(query, query_args)
|
||||
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||
return CreatAnswerFromSQLData(data)
|
||||
|
||||
async def save_question(self, question: str):
|
||||
@ -193,7 +129,7 @@ class AsyncRepository:
|
||||
(%s)
|
||||
"""
|
||||
query_args = (question,)
|
||||
await self._execute_and_fetchall(query, query_args)
|
||||
await self.mysql.execute_and_fetchall(query, query_args)
|
||||
|
||||
async def read_question(self, question: str) -> QuestionData:
|
||||
query = """
|
||||
@ -202,7 +138,7 @@ class AsyncRepository:
|
||||
WHERE question=%s;
|
||||
"""
|
||||
query_args = (question,)
|
||||
data = await self._execute_and_fetchall(query, query_args)
|
||||
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||
return CreatQuestionFromSQLData(data)[0]
|
||||
|
||||
async def save_answer(self, question_id: int, is_correct: int, answer: str):
|
||||
@ -213,7 +149,7 @@ class AsyncRepository:
|
||||
(%s,%s,%s)
|
||||
"""
|
||||
query_args = (question_id, is_correct, answer)
|
||||
await self._execute_and_fetchall(query, query_args)
|
||||
await self.mysql.execute_and_fetchall(query, query_args)
|
||||
|
||||
async def delete_question(self, question_id: int):
|
||||
query = """
|
||||
@ -221,7 +157,7 @@ class AsyncRepository:
|
||||
WHERE id=%s;
|
||||
"""
|
||||
query_args = (question_id,)
|
||||
await self._execute_and_fetchall(query, query_args)
|
||||
await self.mysql.execute_and_fetchall(query, query_args)
|
||||
|
||||
async def delete_answer(self, answer_id: int):
|
||||
query = """
|
||||
@ -229,7 +165,7 @@ class AsyncRepository:
|
||||
WHERE id=%s;
|
||||
"""
|
||||
query_args = (answer_id,)
|
||||
await self._execute_and_fetchall(query, query_args)
|
||||
await self.mysql.execute_and_fetchall(query, query_args)
|
||||
|
||||
async def delete_admin(self, user_id: int):
|
||||
query = """
|
||||
@ -237,7 +173,7 @@ class AsyncRepository:
|
||||
WHERE user_id=%s;
|
||||
"""
|
||||
query_args = (user_id,)
|
||||
await self._execute_and_fetchall(query, query_args)
|
||||
await self.mysql.execute_and_fetchall(query, query_args)
|
||||
|
||||
async def add_admin(self, user_id: int):
|
||||
query = """
|
||||
@ -247,7 +183,7 @@ class AsyncRepository:
|
||||
(%s)
|
||||
"""
|
||||
query_args = (user_id,)
|
||||
await self._execute_and_fetchall(query, query_args)
|
||||
await self.mysql.execute_and_fetchall(query, query_args)
|
||||
|
||||
async def get_admin(self) -> List[int]:
|
||||
query = """
|
||||
@ -255,7 +191,7 @@ class AsyncRepository:
|
||||
FROM `admin`
|
||||
"""
|
||||
query_args = ()
|
||||
data = await self._execute_and_fetchall(query, query_args)
|
||||
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||
if len(data) == 0:
|
||||
return []
|
||||
return list(data[0])
|
||||
|
@ -1,64 +1,24 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from jinja2 import PackageLoader, Environment, Template
|
||||
from playwright.async_api import async_playwright, Browser, ViewportSize, Playwright
|
||||
from playwright.async_api import ViewportSize
|
||||
|
||||
from config import config
|
||||
from logger import Log
|
||||
from utils.aiobrowser import AioBrowser
|
||||
|
||||
|
||||
class TemplateService:
|
||||
def __init__(self, template_package_name: str = "resources", cache_dir_name: str = "cache", loop=None):
|
||||
def __init__(self, browser: AioBrowser, template_package_name: str = "resources", cache_dir_name: str = "cache"):
|
||||
self.browser = browser
|
||||
self._template_package_name = template_package_name
|
||||
self._browser: Optional[Browser] = None
|
||||
self._playwright: Optional[Playwright] = None
|
||||
self._current_dir = os.getcwd()
|
||||
self._output_dir = os.path.join(self._current_dir, cache_dir_name)
|
||||
if not os.path.exists(self._output_dir):
|
||||
os.mkdir(self._output_dir)
|
||||
self._jinja2_env = {}
|
||||
self._jinja2_template = {}
|
||||
self._loop = loop
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
try:
|
||||
Log.info("正在尝试启动Playwright")
|
||||
self._loop.run_until_complete(self._browser_init())
|
||||
Log.info("启动Playwright成功")
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
pass
|
||||
except Exception as exc:
|
||||
Log.error("启动浏览器失败 \n")
|
||||
raise exc
|
||||
|
||||
async def _browser_init(self) -> Browser:
|
||||
if self._playwright is None:
|
||||
self._playwright = await async_playwright().start()
|
||||
try:
|
||||
self._browser = await self._playwright.chromium.launch(timeout=5000)
|
||||
except TimeoutError as err:
|
||||
raise err
|
||||
else:
|
||||
if self._browser is None:
|
||||
try:
|
||||
self._browser = await self._playwright.chromium.launch(timeout=10000)
|
||||
except TimeoutError as err:
|
||||
raise err
|
||||
return self._browser
|
||||
|
||||
async def close(self):
|
||||
if self._browser is not None:
|
||||
await self._browser.close()
|
||||
if self._playwright is not None:
|
||||
await self._playwright.stop()
|
||||
|
||||
async def get_browser(self) -> Browser:
|
||||
if self._browser is None:
|
||||
return await self._browser_init()
|
||||
return self._browser
|
||||
|
||||
def get_template(self, package_path: str, template_name: str) -> Template:
|
||||
if config.DEBUG:
|
||||
|
51
utils/aiobrowser.py
Normal file
51
utils/aiobrowser.py
Normal file
@ -0,0 +1,51 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from playwright.async_api import async_playwright, Browser, Playwright
|
||||
|
||||
from logger import Log
|
||||
|
||||
|
||||
class AioBrowser:
|
||||
def __init__(self, loop=None):
|
||||
self.browser: Optional[Browser] = None
|
||||
self._playwright: Optional[Playwright] = None
|
||||
self._loop = loop
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
try:
|
||||
Log.info("正在尝试启动Playwright")
|
||||
self._loop.run_until_complete(self._browser_init())
|
||||
Log.info("启动Playwright成功")
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
pass
|
||||
except Exception as exc:
|
||||
Log.error("启动浏览器失败 \n")
|
||||
raise exc
|
||||
|
||||
async def _browser_init(self) -> Browser:
|
||||
if self._playwright is None:
|
||||
self._playwright = await async_playwright().start()
|
||||
try:
|
||||
self.browser = await self._playwright.chromium.launch(timeout=5000)
|
||||
except TimeoutError as err:
|
||||
raise err
|
||||
else:
|
||||
if self.browser is None:
|
||||
try:
|
||||
self.browser = await self._playwright.chromium.launch(timeout=10000)
|
||||
except TimeoutError as err:
|
||||
raise err
|
||||
return self.browser
|
||||
|
||||
async def close(self):
|
||||
if self.browser is not None:
|
||||
await self.browser.close()
|
||||
if self._playwright is not None:
|
||||
await self._playwright.stop()
|
||||
|
||||
async def get_browser(self) -> Browser:
|
||||
if self.browser is None:
|
||||
raise RuntimeError("browser is not None")
|
||||
return self.browser
|
71
utils/mysql.py
Normal file
71
utils/mysql.py
Normal file
@ -0,0 +1,71 @@
|
||||
import asyncio
|
||||
|
||||
import aiomysql
|
||||
|
||||
from logger import Log
|
||||
|
||||
|
||||
class MySQL:
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 3306, user: str = "root",
|
||||
password: str = "", database: str = "", loop=None):
|
||||
self.database = database
|
||||
self.password = password
|
||||
self.user = user
|
||||
self.port = port
|
||||
self.host = host
|
||||
self._loop = loop
|
||||
self._sql_pool = None
|
||||
Log.debug(f'获取数据库配置 [host]: {self.host}')
|
||||
Log.debug(f'获取数据库配置 [port]: {self.port}')
|
||||
Log.debug(f'获取数据库配置 [user]: {self.user}')
|
||||
Log.debug(f'获取数据库配置 [password][len]: {len(self.password)}')
|
||||
Log.debug(f'获取数据库配置 [db]: {self.database}')
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
try:
|
||||
Log.info("正在创建数据库LOOP")
|
||||
self._loop.run_until_complete(self.create_pool())
|
||||
Log.info("创建数据库LOOP成功")
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
pass
|
||||
except Exception as exc:
|
||||
Log.error("创建数据库LOOP发生严重错误")
|
||||
raise exc
|
||||
|
||||
async def wait_closed(self):
|
||||
if self._sql_pool is None:
|
||||
return
|
||||
pool = self._sql_pool
|
||||
pool.close()
|
||||
await pool.wait_closed()
|
||||
|
||||
async def create_pool(self):
|
||||
self._sql_pool = await aiomysql.create_pool(
|
||||
host=self.host, port=self.port,
|
||||
user=self.user, password=self.password,
|
||||
db=self.database, loop=self._loop)
|
||||
|
||||
async def _get_pool(self):
|
||||
if self._sql_pool is None:
|
||||
raise RuntimeError("mysql pool is none")
|
||||
return self._sql_pool
|
||||
|
||||
async def executemany(self, query, query_args):
|
||||
pool = await self._get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
sql_cur = await conn.cursor()
|
||||
await sql_cur.executemany(query, query_args)
|
||||
rowcount = sql_cur.rowcount
|
||||
await sql_cur.close()
|
||||
await conn.commit()
|
||||
return rowcount
|
||||
|
||||
async def execute_and_fetchall(self, query, query_args):
|
||||
pool = await self._get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
sql_cur = await conn.cursor()
|
||||
await sql_cur.execute(query, query_args)
|
||||
result = await sql_cur.fetchall()
|
||||
await sql_cur.close()
|
||||
await conn.commit()
|
||||
return result
|
36
utils/redisdb.py
Normal file
36
utils/redisdb.py
Normal file
@ -0,0 +1,36 @@
|
||||
import asyncio
|
||||
|
||||
from logger import Log
|
||||
from redis import asyncio as aioredis
|
||||
|
||||
|
||||
class RedisDB:
|
||||
|
||||
def __init__(self, host="127.0.0.1", port=6379, db=0, loop=None):
|
||||
Log.debug(f'获取Redis配置 [host]: {host}')
|
||||
Log.debug(f'获取Redis配置 [host]: {port}')
|
||||
Log.debug(f'获取Redis配置 [host]: {db}')
|
||||
self.client = aioredis.Redis(host=host, port=port, db=db)
|
||||
self.ttl = 600
|
||||
self.key_prefix = "paimon_bot"
|
||||
self._loop = loop
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
try:
|
||||
Log.info("正在尝试建立与Redis连接")
|
||||
self._loop.run_until_complete(self.ping())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
pass
|
||||
except Exception as exc:
|
||||
Log.error("尝试连接Redis失败 \n")
|
||||
raise exc
|
||||
|
||||
async def ping(self):
|
||||
if await self.client.ping():
|
||||
Log.info("连接Redis成功")
|
||||
else:
|
||||
Log.info("连接Redis失败")
|
||||
raise RuntimeError("连接Redis失败")
|
||||
|
||||
async def close(self):
|
||||
await self.client.close()
|
Loading…
Reference in New Issue
Block a user