mirror of
https://github.com/PaiGramTeam/PaiGram.git
synced 2024-11-16 04:35:49 +00:00
添加 utils
模块
添加 utils 模块 移除 job_queue 模块 修改获取配置文件信息函数 重构部分类以及函数
This commit is contained in:
parent
db96c70180
commit
63d71262fe
@ -18,6 +18,7 @@ class Config:
|
|||||||
self.DEBUG = False
|
self.DEBUG = False
|
||||||
self.ADMINISTRATORS = self.get_config("administrators")
|
self.ADMINISTRATORS = self.get_config("administrators")
|
||||||
self.MYSQL = self.get_config("mysql")
|
self.MYSQL = self.get_config("mysql")
|
||||||
|
self.REDIS = self.get_config("redis")
|
||||||
self.TELEGRAM = self.get_config("telegram")
|
self.TELEGRAM = self.get_config("telegram")
|
||||||
self.FUNCTION = self.get_config("function")
|
self.FUNCTION = self.get_config("function")
|
||||||
|
|
||||||
|
@ -11,7 +11,6 @@ from plugins.errorhandler import error_handler
|
|||||||
from plugins.gacha import Gacha
|
from plugins.gacha import Gacha
|
||||||
from plugins.help import Help
|
from plugins.help import Help
|
||||||
from plugins.inline import Inline
|
from plugins.inline import Inline
|
||||||
from plugins.job_queue import JobQueue
|
|
||||||
from plugins.post import Post
|
from plugins.post import Post
|
||||||
from plugins.quiz import Quiz
|
from plugins.quiz import Quiz
|
||||||
from plugins.sign import Sign
|
from plugins.sign import Sign
|
||||||
@ -77,7 +76,5 @@ def register_handlers(application, service: BaseService = None):
|
|||||||
application.add_handler(post_handler)
|
application.add_handler(post_handler)
|
||||||
inline = Inline(service)
|
inline = Inline(service)
|
||||||
application.add_handler(InlineQueryHandler(inline.inline_query, block=False))
|
application.add_handler(InlineQueryHandler(inline.inline_query, block=False))
|
||||||
job_queue = JobQueue(service)
|
|
||||||
application.job_queue.run_once(job_queue.start_job, when=3, name="start_job")
|
|
||||||
application.add_handler(MessageHandler(filters.COMMAND & filters.ChatType.PRIVATE, unknown_command))
|
application.add_handler(MessageHandler(filters.COMMAND & filters.ChatType.PRIVATE, unknown_command))
|
||||||
application.add_error_handler(error_handler, block=False)
|
application.add_error_handler(error_handler, block=False)
|
||||||
|
27
main.py
27
main.py
@ -8,8 +8,9 @@ from config import config
|
|||||||
from handler import register_handlers
|
from handler import register_handlers
|
||||||
from logger import Log
|
from logger import Log
|
||||||
from service import StartService
|
from service import StartService
|
||||||
from service.cache import RedisCache
|
from utils.aiobrowser import AioBrowser
|
||||||
from service.repository import AsyncRepository
|
from utils.mysql import MySQL
|
||||||
|
from utils.redisdb import RedisDB
|
||||||
|
|
||||||
# 无视相关警告
|
# 无视相关警告
|
||||||
# 该警告说明在官方GITHUB的WIKI中Frequently Asked Questions里的What do the per_* settings in ConversationHandler do?
|
# 该警告说明在官方GITHUB的WIKI中Frequently Asked Questions里的What do the per_* settings in ConversationHandler do?
|
||||||
@ -21,20 +22,20 @@ def main() -> None:
|
|||||||
|
|
||||||
# 初始化数据库
|
# 初始化数据库
|
||||||
Log.info("初始化数据库")
|
Log.info("初始化数据库")
|
||||||
repository = AsyncRepository(mysql_host=config.MYSQL["host"],
|
mysql = MySQL(host=config.MYSQL["host"], user=config.MYSQL["user"], password=config.MYSQL["password"],
|
||||||
mysql_user=config.MYSQL["user"],
|
port=config.MYSQL["port"], database=config.MYSQL["database"])
|
||||||
mysql_password=config.MYSQL["password"],
|
|
||||||
mysql_port=config.MYSQL["port"],
|
|
||||||
mysql_database=config.MYSQL["database"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 初始化Redis缓存
|
# 初始化Redis缓存
|
||||||
Log.info("初始化Redis缓存")
|
Log.info("初始化Redis缓存")
|
||||||
cache = RedisCache(db=6)
|
redis = RedisDB(host=config.REDIS["host"], port=config.REDIS["port"], db=config.REDIS["database"])
|
||||||
|
|
||||||
|
# 初始化Playwright
|
||||||
|
Log.info("初始化Playwright")
|
||||||
|
browser = AioBrowser()
|
||||||
|
|
||||||
# 传入服务并启动
|
# 传入服务并启动
|
||||||
Log.info("传入服务并启动")
|
Log.info("传入服务并启动")
|
||||||
service = StartService(repository, cache)
|
service = StartService(mysql, redis, browser)
|
||||||
|
|
||||||
# 构建BOT
|
# 构建BOT
|
||||||
application = Application.builder().token(config.TELEGRAM["token"]).build()
|
application = Application.builder().token(config.TELEGRAM["token"]).build()
|
||||||
@ -58,13 +59,13 @@ def main() -> None:
|
|||||||
try:
|
try:
|
||||||
# 需要关闭数据库连接
|
# 需要关闭数据库连接
|
||||||
Log.info("正在关闭数据库连接")
|
Log.info("正在关闭数据库连接")
|
||||||
loop.run_until_complete(repository.wait_closed())
|
loop.run_until_complete(mysql.wait_closed())
|
||||||
# 关闭Redis连接
|
# 关闭Redis连接
|
||||||
Log.info("正在关闭Redis连接")
|
Log.info("正在关闭Redis连接")
|
||||||
loop.run_until_complete(cache.close())
|
loop.run_until_complete(redis.close())
|
||||||
# 关闭playwright
|
# 关闭playwright
|
||||||
Log.info("正在关闭Playwright")
|
Log.info("正在关闭Playwright")
|
||||||
loop.run_until_complete(service.template.close())
|
loop.run_until_complete(browser.close())
|
||||||
except (KeyboardInterrupt, SystemExit):
|
except (KeyboardInterrupt, SystemExit):
|
||||||
pass
|
pass
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
@ -1,31 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from telegram.ext import CallbackContext
|
|
||||||
|
|
||||||
from logger import Log
|
|
||||||
from plugins.base import BasePlugins
|
|
||||||
from service import BaseService
|
|
||||||
|
|
||||||
|
|
||||||
class JobQueue(BasePlugins):
|
|
||||||
|
|
||||||
def __init__(self, service: BaseService):
|
|
||||||
super().__init__(service)
|
|
||||||
self.new_post_id_list_cache: List[int] = []
|
|
||||||
|
|
||||||
async def start_job(self, _: CallbackContext) -> None:
|
|
||||||
Log.info("正在检查化必要模块是否正常工作")
|
|
||||||
Log.info("正在检查Playwright")
|
|
||||||
try:
|
|
||||||
# 尝试获取 browser 如果获取失败尝试初始化
|
|
||||||
await self.service.template.get_browser()
|
|
||||||
except TimeoutError as err:
|
|
||||||
Log.error("初始化Playwright超时,请检查日记查看错误 \n", err)
|
|
||||||
except AttributeError as err:
|
|
||||||
Log.error("初始化Playwright时变量为空,请检查日记查看错误 \n", err)
|
|
||||||
else:
|
|
||||||
Log.info("检查Playwright成功")
|
|
||||||
Log.info("检查完成")
|
|
||||||
|
|
||||||
async def check_cookie(self, _: CallbackContext):
|
|
||||||
pass
|
|
@ -1,6 +1,7 @@
|
|||||||
# service 目录说明
|
# service 目录说明
|
||||||
|
|
||||||
## 文件说明
|
## 文件说明
|
||||||
|
|
||||||
| FileName | Contribution |
|
| FileName | Contribution |
|
||||||
|:----------:|--------------|
|
|:----------:|--------------|
|
||||||
| init | 服务初始化 |
|
| init | 服务初始化 |
|
||||||
|
@ -6,20 +6,23 @@ from service.quiz import QuizService
|
|||||||
from service.repository import AsyncRepository
|
from service.repository import AsyncRepository
|
||||||
from service.template import TemplateService
|
from service.template import TemplateService
|
||||||
from service.user import UserInfoFormDB
|
from service.user import UserInfoFormDB
|
||||||
|
from utils.aiobrowser import AioBrowser
|
||||||
|
from utils.mysql import MySQL
|
||||||
|
from utils.redisdb import RedisDB
|
||||||
|
|
||||||
|
|
||||||
class BaseService:
|
class BaseService:
|
||||||
def __init__(self, async_repository: AsyncRepository, async_cache: RedisCache):
|
def __init__(self, mysql: MySQL, redis: RedisDB, browser: AioBrowser):
|
||||||
self.repository = async_repository
|
self.repository = AsyncRepository(mysql)
|
||||||
self.cache = async_cache
|
self.cache = RedisCache(redis)
|
||||||
self.user_service_db = UserInfoFormDB(self.repository)
|
self.user_service_db = UserInfoFormDB(self.repository)
|
||||||
self.quiz_service = QuizService(self.repository, self.cache)
|
self.quiz_service = QuizService(self.repository, self.cache)
|
||||||
self.get_game_info = GetGameInfo(self.repository, self.cache)
|
self.get_game_info = GetGameInfo(self.repository, self.cache)
|
||||||
self.gacha = GachaService(self.repository, self.cache)
|
self.gacha = GachaService(self.repository, self.cache)
|
||||||
self.admin = AdminService(self.repository, self.cache)
|
self.admin = AdminService(self.repository, self.cache)
|
||||||
self.template = TemplateService()
|
self.template = TemplateService(browser)
|
||||||
|
|
||||||
|
|
||||||
class StartService(BaseService):
|
class StartService(BaseService):
|
||||||
def __init__(self, async_repository: AsyncRepository, async_cache: RedisCache):
|
def __init__(self, mysql: MySQL, redis: RedisDB, browser: AioBrowser):
|
||||||
super().__init__(async_repository, async_cache)
|
super().__init__(mysql, redis, browser)
|
||||||
|
@ -1,74 +1,45 @@
|
|||||||
import asyncio
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import ujson
|
import ujson
|
||||||
from redis import asyncio as aioredis
|
|
||||||
|
|
||||||
from logger import Log
|
|
||||||
from service.base import QuestionData, AnswerData
|
from service.base import QuestionData, AnswerData
|
||||||
|
from utils.redisdb import RedisDB
|
||||||
|
|
||||||
|
|
||||||
class RedisCache:
|
class RedisCache:
|
||||||
|
|
||||||
def __init__(self, host="127.0.0.1", port=6379, db=0, loop=None):
|
def __init__(self, redis: RedisDB):
|
||||||
self._loop = asyncio.get_event_loop()
|
self.client = redis.client
|
||||||
# Redis 官方文档显示 默认创建POOL连接池
|
|
||||||
Log.debug(f'获取Redis配置 [host]: {host}')
|
|
||||||
Log.debug(f'获取Redis配置 [host]: {port}')
|
|
||||||
Log.debug(f'获取Redis配置 [host]: {db}')
|
|
||||||
self.rdb = aioredis.Redis(host=host, port=port, db=db)
|
|
||||||
self.ttl = 600
|
|
||||||
self.key_prefix = "paimon_bot"
|
|
||||||
self._loop = loop
|
|
||||||
if self._loop is None:
|
|
||||||
self._loop = asyncio.get_event_loop()
|
|
||||||
try:
|
|
||||||
Log.info("正在尝试建立与Redis连接")
|
|
||||||
self._loop.run_until_complete(self.ping())
|
|
||||||
except (KeyboardInterrupt, SystemExit):
|
|
||||||
pass
|
|
||||||
except Exception as exc:
|
|
||||||
Log.error("尝试连接Redis失败 \n")
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
async def ping(self):
|
|
||||||
if await self.rdb.ping():
|
|
||||||
Log.info("连接Redis成功")
|
|
||||||
else:
|
|
||||||
Log.info("连接Redis失败")
|
|
||||||
raise RuntimeError("连接Redis失败")
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
await self.rdb.close()
|
|
||||||
|
|
||||||
async def get_chat_admin(self, char_id: int):
|
async def get_chat_admin(self, char_id: int):
|
||||||
qname = f"group:admin_list:{char_id}"
|
qname = f"group:admin_list:{char_id}"
|
||||||
return [int(str_id) for str_id in await self.rdb.lrange(qname, 0, -1)]
|
return [int(str_id) for str_id in await self.client.lrange(qname, 0, -1)]
|
||||||
|
|
||||||
async def set_chat_admin(self, char_id: int, admin_list: List[int]):
|
async def set_chat_admin(self, char_id: int, admin_list: List[int]):
|
||||||
qname = f"group:admin_list:{char_id}"
|
qname = f"group:admin_list:{char_id}"
|
||||||
await self.rdb.ltrim(qname, 1, 0)
|
await self.client.ltrim(qname, 1, 0)
|
||||||
await self.rdb.lpush(qname, *admin_list)
|
await self.client.lpush(qname, *admin_list)
|
||||||
await self.rdb.expire(qname, 60)
|
await self.client.expire(qname, 60)
|
||||||
count = await self.rdb.llen(qname)
|
count = await self.client.llen(qname)
|
||||||
return count
|
return count
|
||||||
|
|
||||||
async def get_all_question(self) -> List[str]:
|
async def get_all_question(self) -> List[str]:
|
||||||
qname = "quiz:question"
|
qname = "quiz:question"
|
||||||
data_list = [qname + f":{question_id}" for question_id in await self.rdb.lrange(qname + "id_list", 0, -1)]
|
data_list = [qname + f":{question_id}" for question_id in
|
||||||
return await self.rdb.mget(data_list)
|
await self.client.lrange(qname + "id_list", 0, -1)]
|
||||||
|
return await self.client.mget(data_list)
|
||||||
|
|
||||||
async def get_all_question_id_list(self) -> List[str]:
|
async def get_all_question_id_list(self) -> List[str]:
|
||||||
qname = "quiz:question:id_list"
|
qname = "quiz:question:id_list"
|
||||||
return await self.rdb.lrange(qname, 0, -1)
|
return await self.client.lrange(qname, 0, -1)
|
||||||
|
|
||||||
async def get_one_question(self, question_id: int) -> str:
|
async def get_one_question(self, question_id: int) -> str:
|
||||||
qname = f"quiz:question:{question_id}"
|
qname = f"quiz:question:{question_id}"
|
||||||
return await self.rdb.get(qname)
|
return await self.client.get(qname)
|
||||||
|
|
||||||
async def get_one_answer(self, answer_id: int) -> str:
|
async def get_one_answer(self, answer_id: int) -> str:
|
||||||
qname = f"quiz:answer:{answer_id}"
|
qname = f"quiz:answer:{answer_id}"
|
||||||
return await self.rdb.get(qname)
|
return await self.client.get(qname)
|
||||||
|
|
||||||
async def set_question(self, question_list: List[QuestionData] = None):
|
async def set_question(self, question_list: List[QuestionData] = None):
|
||||||
qname = "quiz:question"
|
qname = "quiz:question"
|
||||||
@ -82,25 +53,25 @@ class RedisCache:
|
|||||||
return ujson.dumps(data)
|
return ujson.dumps(data)
|
||||||
|
|
||||||
for question in question_list:
|
for question in question_list:
|
||||||
await self.rdb.set(qname + f":{question.question_id}", json_dumps(question))
|
await self.client.set(qname + f":{question.question_id}", json_dumps(question))
|
||||||
|
|
||||||
question_id_list = [question.question_id for question in question_list]
|
question_id_list = [question.question_id for question in question_list]
|
||||||
await self.rdb.lpush(qname + ":id_list", *question_id_list)
|
await self.client.lpush(qname + ":id_list", *question_id_list)
|
||||||
return await self.rdb.llen(qname + ":id_list")
|
return await self.client.llen(qname + ":id_list")
|
||||||
|
|
||||||
async def del_all_question(self, answer_list: List[AnswerData] = None):
|
async def del_all_question(self, answer_list: List[AnswerData] = None):
|
||||||
qname = "quiz:question"
|
qname = "quiz:question"
|
||||||
keys = await self.rdb.keys(qname + "*")
|
keys = await self.client.keys(qname + "*")
|
||||||
if keys is not None:
|
if keys is not None:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
await self.rdb.delete(key)
|
await self.client.delete(key)
|
||||||
|
|
||||||
async def del_all_answer(self, answer_list: List[AnswerData] = None):
|
async def del_all_answer(self, answer_list: List[AnswerData] = None):
|
||||||
qname = "quiz:answer"
|
qname = "quiz:answer"
|
||||||
keys = await self.rdb.keys(qname + "*")
|
keys = await self.client.keys(qname + "*")
|
||||||
if keys is not None:
|
if keys is not None:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
await self.rdb.delete(key)
|
await self.client.delete(key)
|
||||||
|
|
||||||
async def set_answer(self, answer_list: List[AnswerData] = None):
|
async def set_answer(self, answer_list: List[AnswerData] = None):
|
||||||
qname = "quiz:answer"
|
qname = "quiz:answer"
|
||||||
@ -109,30 +80,30 @@ class RedisCache:
|
|||||||
return ujson.dumps(obj=_answer.__dict__)
|
return ujson.dumps(obj=_answer.__dict__)
|
||||||
|
|
||||||
for answer in answer_list:
|
for answer in answer_list:
|
||||||
await self.rdb.set(qname + f":{answer.answer_id}", json_dumps(answer))
|
await self.client.set(qname + f":{answer.answer_id}", json_dumps(answer))
|
||||||
|
|
||||||
answer_id_list = [answer.answer_id for answer in answer_list]
|
answer_id_list = [answer.answer_id for answer in answer_list]
|
||||||
await self.rdb.lpush(qname + ":id_list", *answer_id_list)
|
await self.client.lpush(qname + ":id_list", *answer_id_list)
|
||||||
return await self.rdb.llen(qname + ":id_list")
|
return await self.client.llen(qname + ":id_list")
|
||||||
|
|
||||||
async def get_str_list(self, qname: str):
|
async def get_str_list(self, qname: str):
|
||||||
return [str(str_data, encoding="utf-8") for str_data in await self.rdb.lrange(qname, 0, -1)]
|
return [str(str_data, encoding="utf-8") for str_data in await self.client.lrange(qname, 0, -1)]
|
||||||
|
|
||||||
async def set_str_list(self, qname: str, str_list: List[str], ttl: int = 60):
|
async def set_str_list(self, qname: str, str_list: List[str], ttl: int = 60):
|
||||||
await self.rdb.ltrim(qname, 1, 0)
|
await self.client.ltrim(qname, 1, 0)
|
||||||
await self.rdb.lpush(qname, *str_list)
|
await self.client.lpush(qname, *str_list)
|
||||||
if ttl != -1:
|
if ttl != -1:
|
||||||
await self.rdb.expire(qname, ttl)
|
await self.client.expire(qname, ttl)
|
||||||
count = await self.rdb.llen(qname)
|
count = await self.client.llen(qname)
|
||||||
return count
|
return count
|
||||||
|
|
||||||
async def get_int_list(self, qname: str):
|
async def get_int_list(self, qname: str):
|
||||||
return [int(str_data) for str_data in await self.rdb.lrange(qname, 0, -1)]
|
return [int(str_data) for str_data in await self.client.lrange(qname, 0, -1)]
|
||||||
|
|
||||||
async def set_int_list(self, qname: str, str_list: List[int], ttl: int = 60):
|
async def set_int_list(self, qname: str, str_list: List[int], ttl: int = 60):
|
||||||
await self.rdb.ltrim(qname, 1, 0)
|
await self.client.ltrim(qname, 1, 0)
|
||||||
await self.rdb.lpush(qname, *str_list)
|
await self.client.lpush(qname, *str_list)
|
||||||
if ttl != -1:
|
if ttl != -1:
|
||||||
await self.rdb.expire(qname, ttl)
|
await self.client.expire(qname, ttl)
|
||||||
count = await self.rdb.llen(qname)
|
count = await self.client.llen(qname)
|
||||||
return count
|
return count
|
||||||
|
@ -1,78 +1,14 @@
|
|||||||
import asyncio
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import aiomysql
|
|
||||||
|
|
||||||
from logger import Log
|
|
||||||
from model.base import ServiceEnum
|
from model.base import ServiceEnum
|
||||||
from service.base import CreateUserInfoDBDataFromSQLData, UserInfoData, CreatCookieDictFromSQLData, \
|
from service.base import CreateUserInfoDBDataFromSQLData, UserInfoData, CreatCookieDictFromSQLData, \
|
||||||
CreatQuestionFromSQLData, QuestionData, AnswerData, CreatAnswerFromSQLData
|
CreatQuestionFromSQLData, QuestionData, AnswerData, CreatAnswerFromSQLData
|
||||||
|
from utils.mysql import MySQL
|
||||||
|
|
||||||
|
|
||||||
class AsyncRepository:
|
class AsyncRepository:
|
||||||
def __init__(self, mysql_host: str = "127.0.0.1", mysql_port: int = 3306, mysql_user: str = "root",
|
def __init__(self, mysql: MySQL):
|
||||||
mysql_password: str = "", mysql_database: str = "", loop=None):
|
self.mysql = mysql
|
||||||
self._mysql_database = mysql_database
|
|
||||||
self._mysql_password = mysql_password
|
|
||||||
self._mysql_user = mysql_user
|
|
||||||
self._mysql_port = mysql_port
|
|
||||||
self._mysql_host = mysql_host
|
|
||||||
self._loop = loop
|
|
||||||
self._sql_pool = None
|
|
||||||
Log.debug(f'获取数据库配置 [host]: {self._mysql_host}')
|
|
||||||
Log.debug(f'获取数据库配置 [port]: {self._mysql_port}')
|
|
||||||
Log.debug(f'获取数据库配置 [user]: {self._mysql_user}')
|
|
||||||
Log.debug(f'获取数据库配置 [password][len]: {len(self._mysql_password)}')
|
|
||||||
Log.debug(f'获取数据库配置 [db]: {self._mysql_database}')
|
|
||||||
if self._loop is None:
|
|
||||||
self._loop = asyncio.get_event_loop()
|
|
||||||
try:
|
|
||||||
Log.info("正在创建数据库LOOP")
|
|
||||||
self._loop.run_until_complete(self.create_pool())
|
|
||||||
Log.info("创建数据库LOOP成功")
|
|
||||||
except (KeyboardInterrupt, SystemExit):
|
|
||||||
pass
|
|
||||||
except Exception as exc:
|
|
||||||
Log.error("创建数据库LOOP发生严重错误")
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
async def create_pool(self):
|
|
||||||
self._sql_pool = await aiomysql.create_pool(
|
|
||||||
host=self._mysql_host, port=self._mysql_port,
|
|
||||||
user=self._mysql_user, password=self._mysql_password,
|
|
||||||
db=self._mysql_database, loop=self._loop)
|
|
||||||
|
|
||||||
async def wait_closed(self):
|
|
||||||
if self._sql_pool is None:
|
|
||||||
return
|
|
||||||
pool = self._sql_pool
|
|
||||||
pool.close()
|
|
||||||
await pool.wait_closed()
|
|
||||||
|
|
||||||
async def _get_pool(self):
|
|
||||||
if self._sql_pool is None:
|
|
||||||
raise RuntimeError("mysql pool is none")
|
|
||||||
return self._sql_pool
|
|
||||||
|
|
||||||
async def _executemany(self, query, query_args):
|
|
||||||
pool = await self._get_pool()
|
|
||||||
async with pool.acquire() as conn:
|
|
||||||
sql_cur = await conn.cursor()
|
|
||||||
await sql_cur.executemany(query, query_args)
|
|
||||||
rowcount = sql_cur.rowcount
|
|
||||||
await sql_cur.close()
|
|
||||||
await conn.commit()
|
|
||||||
return rowcount
|
|
||||||
|
|
||||||
async def _execute_and_fetchall(self, query, query_args):
|
|
||||||
pool = await self._get_pool()
|
|
||||||
async with pool.acquire() as conn:
|
|
||||||
sql_cur = await conn.cursor()
|
|
||||||
await sql_cur.execute(query, query_args)
|
|
||||||
result = await sql_cur.fetchall()
|
|
||||||
await sql_cur.close()
|
|
||||||
await conn.commit()
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def update_cookie(self, user_id: int, cookie: str, service: ServiceEnum):
|
async def update_cookie(self, user_id: int, cookie: str, service: ServiceEnum):
|
||||||
if service == ServiceEnum.MIHOYOBBS:
|
if service == ServiceEnum.MIHOYOBBS:
|
||||||
@ -90,7 +26,7 @@ class AsyncRepository:
|
|||||||
else:
|
else:
|
||||||
query = ""
|
query = ""
|
||||||
query_args = (cookie, user_id)
|
query_args = (cookie, user_id)
|
||||||
await self._execute_and_fetchall(query, query_args)
|
await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
|
|
||||||
async def set_cookie(self, user_id: int, cookie: str, service: ServiceEnum):
|
async def set_cookie(self, user_id: int, cookie: str, service: ServiceEnum):
|
||||||
if service == ServiceEnum.MIHOYOBBS:
|
if service == ServiceEnum.MIHOYOBBS:
|
||||||
@ -114,7 +50,7 @@ class AsyncRepository:
|
|||||||
else:
|
else:
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
query_args = (user_id, cookie)
|
query_args = (user_id, cookie)
|
||||||
await self._execute_and_fetchall(query, query_args)
|
await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
|
|
||||||
async def set_user_info(self, user_id: int, mihoyo_game_uid: int, hoyoverse_game_uid: int, service: int):
|
async def set_user_info(self, user_id: int, mihoyo_game_uid: int, hoyoverse_game_uid: int, service: int):
|
||||||
query = """
|
query = """
|
||||||
@ -128,7 +64,7 @@ class AsyncRepository:
|
|||||||
service=VALUES(service);
|
service=VALUES(service);
|
||||||
"""
|
"""
|
||||||
query_args = (user_id, mihoyo_game_uid, hoyoverse_game_uid, service)
|
query_args = (user_id, mihoyo_game_uid, hoyoverse_game_uid, service)
|
||||||
await self._execute_and_fetchall(query, query_args)
|
await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
|
|
||||||
async def read_mihoyo_cookie(self, user_id) -> dict:
|
async def read_mihoyo_cookie(self, user_id) -> dict:
|
||||||
query = """
|
query = """
|
||||||
@ -137,7 +73,7 @@ class AsyncRepository:
|
|||||||
WHERE user_id=%s;
|
WHERE user_id=%s;
|
||||||
"""
|
"""
|
||||||
query_args = (user_id,)
|
query_args = (user_id,)
|
||||||
data = await self._execute_and_fetchall(query, query_args)
|
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
return {}
|
return {}
|
||||||
return CreatCookieDictFromSQLData(data[0])
|
return CreatCookieDictFromSQLData(data[0])
|
||||||
@ -149,7 +85,7 @@ class AsyncRepository:
|
|||||||
WHERE user_id=%s;
|
WHERE user_id=%s;
|
||||||
"""
|
"""
|
||||||
query_args = (user_id,)
|
query_args = (user_id,)
|
||||||
data = await self._execute_and_fetchall(query, query_args)
|
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
return {}
|
return {}
|
||||||
return CreatCookieDictFromSQLData(data[0])
|
return CreatCookieDictFromSQLData(data[0])
|
||||||
@ -161,7 +97,7 @@ class AsyncRepository:
|
|||||||
WHERE user_id=%s;
|
WHERE user_id=%s;
|
||||||
"""
|
"""
|
||||||
query_args = (user_id,)
|
query_args = (user_id,)
|
||||||
data = await self._execute_and_fetchall(query, query_args)
|
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
return UserInfoData()
|
return UserInfoData()
|
||||||
return CreateUserInfoDBDataFromSQLData(data[0])
|
return CreateUserInfoDBDataFromSQLData(data[0])
|
||||||
@ -172,7 +108,7 @@ class AsyncRepository:
|
|||||||
FROM `question`
|
FROM `question`
|
||||||
"""
|
"""
|
||||||
query_args = ()
|
query_args = ()
|
||||||
data = await self._execute_and_fetchall(query, query_args)
|
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
return CreatQuestionFromSQLData(data)
|
return CreatQuestionFromSQLData(data)
|
||||||
|
|
||||||
async def read_answer_form_question_id(self, question_id: int) -> List[AnswerData]:
|
async def read_answer_form_question_id(self, question_id: int) -> List[AnswerData]:
|
||||||
@ -182,7 +118,7 @@ class AsyncRepository:
|
|||||||
WHERE question_id=%s;
|
WHERE question_id=%s;
|
||||||
"""
|
"""
|
||||||
query_args = (question_id,)
|
query_args = (question_id,)
|
||||||
data = await self._execute_and_fetchall(query, query_args)
|
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
return CreatAnswerFromSQLData(data)
|
return CreatAnswerFromSQLData(data)
|
||||||
|
|
||||||
async def save_question(self, question: str):
|
async def save_question(self, question: str):
|
||||||
@ -193,7 +129,7 @@ class AsyncRepository:
|
|||||||
(%s)
|
(%s)
|
||||||
"""
|
"""
|
||||||
query_args = (question,)
|
query_args = (question,)
|
||||||
await self._execute_and_fetchall(query, query_args)
|
await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
|
|
||||||
async def read_question(self, question: str) -> QuestionData:
|
async def read_question(self, question: str) -> QuestionData:
|
||||||
query = """
|
query = """
|
||||||
@ -202,7 +138,7 @@ class AsyncRepository:
|
|||||||
WHERE question=%s;
|
WHERE question=%s;
|
||||||
"""
|
"""
|
||||||
query_args = (question,)
|
query_args = (question,)
|
||||||
data = await self._execute_and_fetchall(query, query_args)
|
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
return CreatQuestionFromSQLData(data)[0]
|
return CreatQuestionFromSQLData(data)[0]
|
||||||
|
|
||||||
async def save_answer(self, question_id: int, is_correct: int, answer: str):
|
async def save_answer(self, question_id: int, is_correct: int, answer: str):
|
||||||
@ -213,7 +149,7 @@ class AsyncRepository:
|
|||||||
(%s,%s,%s)
|
(%s,%s,%s)
|
||||||
"""
|
"""
|
||||||
query_args = (question_id, is_correct, answer)
|
query_args = (question_id, is_correct, answer)
|
||||||
await self._execute_and_fetchall(query, query_args)
|
await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
|
|
||||||
async def delete_question(self, question_id: int):
|
async def delete_question(self, question_id: int):
|
||||||
query = """
|
query = """
|
||||||
@ -221,7 +157,7 @@ class AsyncRepository:
|
|||||||
WHERE id=%s;
|
WHERE id=%s;
|
||||||
"""
|
"""
|
||||||
query_args = (question_id,)
|
query_args = (question_id,)
|
||||||
await self._execute_and_fetchall(query, query_args)
|
await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
|
|
||||||
async def delete_answer(self, answer_id: int):
|
async def delete_answer(self, answer_id: int):
|
||||||
query = """
|
query = """
|
||||||
@ -229,7 +165,7 @@ class AsyncRepository:
|
|||||||
WHERE id=%s;
|
WHERE id=%s;
|
||||||
"""
|
"""
|
||||||
query_args = (answer_id,)
|
query_args = (answer_id,)
|
||||||
await self._execute_and_fetchall(query, query_args)
|
await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
|
|
||||||
async def delete_admin(self, user_id: int):
|
async def delete_admin(self, user_id: int):
|
||||||
query = """
|
query = """
|
||||||
@ -237,7 +173,7 @@ class AsyncRepository:
|
|||||||
WHERE user_id=%s;
|
WHERE user_id=%s;
|
||||||
"""
|
"""
|
||||||
query_args = (user_id,)
|
query_args = (user_id,)
|
||||||
await self._execute_and_fetchall(query, query_args)
|
await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
|
|
||||||
async def add_admin(self, user_id: int):
|
async def add_admin(self, user_id: int):
|
||||||
query = """
|
query = """
|
||||||
@ -247,7 +183,7 @@ class AsyncRepository:
|
|||||||
(%s)
|
(%s)
|
||||||
"""
|
"""
|
||||||
query_args = (user_id,)
|
query_args = (user_id,)
|
||||||
await self._execute_and_fetchall(query, query_args)
|
await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
|
|
||||||
async def get_admin(self) -> List[int]:
|
async def get_admin(self) -> List[int]:
|
||||||
query = """
|
query = """
|
||||||
@ -255,7 +191,7 @@ class AsyncRepository:
|
|||||||
FROM `admin`
|
FROM `admin`
|
||||||
"""
|
"""
|
||||||
query_args = ()
|
query_args = ()
|
||||||
data = await self._execute_and_fetchall(query, query_args)
|
data = await self.mysql.execute_and_fetchall(query, query_args)
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
return []
|
return []
|
||||||
return list(data[0])
|
return list(data[0])
|
||||||
|
@ -1,64 +1,24 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from jinja2 import PackageLoader, Environment, Template
|
from jinja2 import PackageLoader, Environment, Template
|
||||||
from playwright.async_api import async_playwright, Browser, ViewportSize, Playwright
|
from playwright.async_api import ViewportSize
|
||||||
|
|
||||||
from config import config
|
from config import config
|
||||||
from logger import Log
|
from logger import Log
|
||||||
|
from utils.aiobrowser import AioBrowser
|
||||||
|
|
||||||
|
|
||||||
class TemplateService:
|
class TemplateService:
|
||||||
def __init__(self, template_package_name: str = "resources", cache_dir_name: str = "cache", loop=None):
|
def __init__(self, browser: AioBrowser, template_package_name: str = "resources", cache_dir_name: str = "cache"):
|
||||||
|
self.browser = browser
|
||||||
self._template_package_name = template_package_name
|
self._template_package_name = template_package_name
|
||||||
self._browser: Optional[Browser] = None
|
|
||||||
self._playwright: Optional[Playwright] = None
|
|
||||||
self._current_dir = os.getcwd()
|
self._current_dir = os.getcwd()
|
||||||
self._output_dir = os.path.join(self._current_dir, cache_dir_name)
|
self._output_dir = os.path.join(self._current_dir, cache_dir_name)
|
||||||
if not os.path.exists(self._output_dir):
|
if not os.path.exists(self._output_dir):
|
||||||
os.mkdir(self._output_dir)
|
os.mkdir(self._output_dir)
|
||||||
self._jinja2_env = {}
|
self._jinja2_env = {}
|
||||||
self._jinja2_template = {}
|
self._jinja2_template = {}
|
||||||
self._loop = loop
|
|
||||||
if self._loop is None:
|
|
||||||
self._loop = asyncio.get_event_loop()
|
|
||||||
try:
|
|
||||||
Log.info("正在尝试启动Playwright")
|
|
||||||
self._loop.run_until_complete(self._browser_init())
|
|
||||||
Log.info("启动Playwright成功")
|
|
||||||
except (KeyboardInterrupt, SystemExit):
|
|
||||||
pass
|
|
||||||
except Exception as exc:
|
|
||||||
Log.error("启动浏览器失败 \n")
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
async def _browser_init(self) -> Browser:
|
|
||||||
if self._playwright is None:
|
|
||||||
self._playwright = await async_playwright().start()
|
|
||||||
try:
|
|
||||||
self._browser = await self._playwright.chromium.launch(timeout=5000)
|
|
||||||
except TimeoutError as err:
|
|
||||||
raise err
|
|
||||||
else:
|
|
||||||
if self._browser is None:
|
|
||||||
try:
|
|
||||||
self._browser = await self._playwright.chromium.launch(timeout=10000)
|
|
||||||
except TimeoutError as err:
|
|
||||||
raise err
|
|
||||||
return self._browser
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
if self._browser is not None:
|
|
||||||
await self._browser.close()
|
|
||||||
if self._playwright is not None:
|
|
||||||
await self._playwright.stop()
|
|
||||||
|
|
||||||
async def get_browser(self) -> Browser:
|
|
||||||
if self._browser is None:
|
|
||||||
return await self._browser_init()
|
|
||||||
return self._browser
|
|
||||||
|
|
||||||
def get_template(self, package_path: str, template_name: str) -> Template:
|
def get_template(self, package_path: str, template_name: str) -> Template:
|
||||||
if config.DEBUG:
|
if config.DEBUG:
|
||||||
|
51
utils/aiobrowser.py
Normal file
51
utils/aiobrowser.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from playwright.async_api import async_playwright, Browser, Playwright
|
||||||
|
|
||||||
|
from logger import Log
|
||||||
|
|
||||||
|
|
||||||
|
class AioBrowser:
|
||||||
|
def __init__(self, loop=None):
|
||||||
|
self.browser: Optional[Browser] = None
|
||||||
|
self._playwright: Optional[Playwright] = None
|
||||||
|
self._loop = loop
|
||||||
|
if self._loop is None:
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
Log.info("正在尝试启动Playwright")
|
||||||
|
self._loop.run_until_complete(self._browser_init())
|
||||||
|
Log.info("启动Playwright成功")
|
||||||
|
except (KeyboardInterrupt, SystemExit):
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
Log.error("启动浏览器失败 \n")
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
async def _browser_init(self) -> Browser:
|
||||||
|
if self._playwright is None:
|
||||||
|
self._playwright = await async_playwright().start()
|
||||||
|
try:
|
||||||
|
self.browser = await self._playwright.chromium.launch(timeout=5000)
|
||||||
|
except TimeoutError as err:
|
||||||
|
raise err
|
||||||
|
else:
|
||||||
|
if self.browser is None:
|
||||||
|
try:
|
||||||
|
self.browser = await self._playwright.chromium.launch(timeout=10000)
|
||||||
|
except TimeoutError as err:
|
||||||
|
raise err
|
||||||
|
return self.browser
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
if self.browser is not None:
|
||||||
|
await self.browser.close()
|
||||||
|
if self._playwright is not None:
|
||||||
|
await self._playwright.stop()
|
||||||
|
|
||||||
|
async def get_browser(self) -> Browser:
|
||||||
|
if self.browser is None:
|
||||||
|
raise RuntimeError("browser is not None")
|
||||||
|
return self.browser
|
71
utils/mysql.py
Normal file
71
utils/mysql.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import aiomysql
|
||||||
|
|
||||||
|
from logger import Log
|
||||||
|
|
||||||
|
|
||||||
|
class MySQL:
|
||||||
|
def __init__(self, host: str = "127.0.0.1", port: int = 3306, user: str = "root",
|
||||||
|
password: str = "", database: str = "", loop=None):
|
||||||
|
self.database = database
|
||||||
|
self.password = password
|
||||||
|
self.user = user
|
||||||
|
self.port = port
|
||||||
|
self.host = host
|
||||||
|
self._loop = loop
|
||||||
|
self._sql_pool = None
|
||||||
|
Log.debug(f'获取数据库配置 [host]: {self.host}')
|
||||||
|
Log.debug(f'获取数据库配置 [port]: {self.port}')
|
||||||
|
Log.debug(f'获取数据库配置 [user]: {self.user}')
|
||||||
|
Log.debug(f'获取数据库配置 [password][len]: {len(self.password)}')
|
||||||
|
Log.debug(f'获取数据库配置 [db]: {self.database}')
|
||||||
|
if self._loop is None:
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
Log.info("正在创建数据库LOOP")
|
||||||
|
self._loop.run_until_complete(self.create_pool())
|
||||||
|
Log.info("创建数据库LOOP成功")
|
||||||
|
except (KeyboardInterrupt, SystemExit):
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
Log.error("创建数据库LOOP发生严重错误")
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
async def wait_closed(self):
|
||||||
|
if self._sql_pool is None:
|
||||||
|
return
|
||||||
|
pool = self._sql_pool
|
||||||
|
pool.close()
|
||||||
|
await pool.wait_closed()
|
||||||
|
|
||||||
|
async def create_pool(self):
|
||||||
|
self._sql_pool = await aiomysql.create_pool(
|
||||||
|
host=self.host, port=self.port,
|
||||||
|
user=self.user, password=self.password,
|
||||||
|
db=self.database, loop=self._loop)
|
||||||
|
|
||||||
|
async def _get_pool(self):
|
||||||
|
if self._sql_pool is None:
|
||||||
|
raise RuntimeError("mysql pool is none")
|
||||||
|
return self._sql_pool
|
||||||
|
|
||||||
|
async def executemany(self, query, query_args):
|
||||||
|
pool = await self._get_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
sql_cur = await conn.cursor()
|
||||||
|
await sql_cur.executemany(query, query_args)
|
||||||
|
rowcount = sql_cur.rowcount
|
||||||
|
await sql_cur.close()
|
||||||
|
await conn.commit()
|
||||||
|
return rowcount
|
||||||
|
|
||||||
|
async def execute_and_fetchall(self, query, query_args):
|
||||||
|
pool = await self._get_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
sql_cur = await conn.cursor()
|
||||||
|
await sql_cur.execute(query, query_args)
|
||||||
|
result = await sql_cur.fetchall()
|
||||||
|
await sql_cur.close()
|
||||||
|
await conn.commit()
|
||||||
|
return result
|
36
utils/redisdb.py
Normal file
36
utils/redisdb.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from logger import Log
|
||||||
|
from redis import asyncio as aioredis
|
||||||
|
|
||||||
|
|
||||||
|
class RedisDB:
|
||||||
|
|
||||||
|
def __init__(self, host="127.0.0.1", port=6379, db=0, loop=None):
|
||||||
|
Log.debug(f'获取Redis配置 [host]: {host}')
|
||||||
|
Log.debug(f'获取Redis配置 [host]: {port}')
|
||||||
|
Log.debug(f'获取Redis配置 [host]: {db}')
|
||||||
|
self.client = aioredis.Redis(host=host, port=port, db=db)
|
||||||
|
self.ttl = 600
|
||||||
|
self.key_prefix = "paimon_bot"
|
||||||
|
self._loop = loop
|
||||||
|
if self._loop is None:
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
Log.info("正在尝试建立与Redis连接")
|
||||||
|
self._loop.run_until_complete(self.ping())
|
||||||
|
except (KeyboardInterrupt, SystemExit):
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
Log.error("尝试连接Redis失败 \n")
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
async def ping(self):
|
||||||
|
if await self.client.ping():
|
||||||
|
Log.info("连接Redis成功")
|
||||||
|
else:
|
||||||
|
Log.info("连接Redis失败")
|
||||||
|
raise RuntimeError("连接Redis失败")
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
await self.client.close()
|
Loading…
Reference in New Issue
Block a user