diff --git a/config.py b/config.py index c5d04358..a0f45d20 100644 --- a/config.py +++ b/config.py @@ -18,6 +18,7 @@ class Config: self.DEBUG = False self.ADMINISTRATORS = self.get_config("administrators") self.MYSQL = self.get_config("mysql") + self.REDIS = self.get_config("redis") self.TELEGRAM = self.get_config("telegram") self.FUNCTION = self.get_config("function") diff --git a/handler.py b/handler.py index 66a9bd72..86bb9d5b 100644 --- a/handler.py +++ b/handler.py @@ -11,7 +11,6 @@ from plugins.errorhandler import error_handler from plugins.gacha import Gacha from plugins.help import Help from plugins.inline import Inline -from plugins.job_queue import JobQueue from plugins.post import Post from plugins.quiz import Quiz from plugins.sign import Sign @@ -77,7 +76,5 @@ def register_handlers(application, service: BaseService = None): application.add_handler(post_handler) inline = Inline(service) application.add_handler(InlineQueryHandler(inline.inline_query, block=False)) - job_queue = JobQueue(service) - application.job_queue.run_once(job_queue.start_job, when=3, name="start_job") application.add_handler(MessageHandler(filters.COMMAND & filters.ChatType.PRIVATE, unknown_command)) application.add_error_handler(error_handler, block=False) diff --git a/main.py b/main.py index e20c040c..351fff62 100644 --- a/main.py +++ b/main.py @@ -8,8 +8,9 @@ from config import config from handler import register_handlers from logger import Log from service import StartService -from service.cache import RedisCache -from service.repository import AsyncRepository +from utils.aiobrowser import AioBrowser +from utils.mysql import MySQL +from utils.redisdb import RedisDB # 无视相关警告 # 该警告说明在官方GITHUB的WIKI中Frequently Asked Questions里的What do the per_* settings in ConversationHandler do? @@ -21,20 +22,20 @@ def main() -> None: # 初始化数据库 Log.info("初始化数据库") - repository = AsyncRepository(mysql_host=config.MYSQL["host"], - mysql_user=config.MYSQL["user"], - mysql_password=config.MYSQL["password"], - mysql_port=config.MYSQL["port"], - mysql_database=config.MYSQL["database"] - ) + mysql = MySQL(host=config.MYSQL["host"], user=config.MYSQL["user"], password=config.MYSQL["password"], + port=config.MYSQL["port"], database=config.MYSQL["database"]) # 初始化Redis缓存 Log.info("初始化Redis缓存") - cache = RedisCache(db=6) + redis = RedisDB(host=config.REDIS["host"], port=config.REDIS["port"], db=config.REDIS["database"]) + + # 初始化Playwright + Log.info("初始化Playwright") + browser = AioBrowser() # 传入服务并启动 Log.info("传入服务并启动") - service = StartService(repository, cache) + service = StartService(mysql, redis, browser) # 构建BOT application = Application.builder().token(config.TELEGRAM["token"]).build() @@ -58,13 +59,13 @@ def main() -> None: try: # 需要关闭数据库连接 Log.info("正在关闭数据库连接") - loop.run_until_complete(repository.wait_closed()) + loop.run_until_complete(mysql.wait_closed()) # 关闭Redis连接 Log.info("正在关闭Redis连接") - loop.run_until_complete(cache.close()) + loop.run_until_complete(redis.close()) # 关闭playwright Log.info("正在关闭Playwright") - loop.run_until_complete(service.template.close()) + loop.run_until_complete(browser.close()) except (KeyboardInterrupt, SystemExit): pass except Exception as exc: diff --git a/plugins/job_queue.py b/plugins/job_queue.py deleted file mode 100644 index 6e77fd6e..00000000 --- a/plugins/job_queue.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import List - -from telegram.ext import CallbackContext - -from logger import Log -from plugins.base import BasePlugins -from service import BaseService - - -class JobQueue(BasePlugins): - - def __init__(self, service: BaseService): - super().__init__(service) - self.new_post_id_list_cache: List[int] = [] - - async def start_job(self, _: CallbackContext) -> None: - Log.info("正在检查化必要模块是否正常工作") - Log.info("正在检查Playwright") - try: - # 尝试获取 browser 如果获取失败尝试初始化 - await self.service.template.get_browser() - except TimeoutError as err: - Log.error("初始化Playwright超时,请检查日记查看错误 \n", err) - except AttributeError as err: - Log.error("初始化Playwright时变量为空,请检查日记查看错误 \n", err) - else: - Log.info("检查Playwright成功") - Log.info("检查完成") - - async def check_cookie(self, _: CallbackContext): - pass diff --git a/service/README.md b/service/README.md index fd81b19a..7d3989d2 100644 --- a/service/README.md +++ b/service/README.md @@ -1,6 +1,7 @@ # service 目录说明 ## 文件说明 + | FileName | Contribution | |:----------:|--------------| | init | 服务初始化 | diff --git a/service/__init__.py b/service/__init__.py index e70755f2..11061b5a 100644 --- a/service/__init__.py +++ b/service/__init__.py @@ -6,20 +6,23 @@ from service.quiz import QuizService from service.repository import AsyncRepository from service.template import TemplateService from service.user import UserInfoFormDB +from utils.aiobrowser import AioBrowser +from utils.mysql import MySQL +from utils.redisdb import RedisDB class BaseService: - def __init__(self, async_repository: AsyncRepository, async_cache: RedisCache): - self.repository = async_repository - self.cache = async_cache + def __init__(self, mysql: MySQL, redis: RedisDB, browser: AioBrowser): + self.repository = AsyncRepository(mysql) + self.cache = RedisCache(redis) self.user_service_db = UserInfoFormDB(self.repository) self.quiz_service = QuizService(self.repository, self.cache) self.get_game_info = GetGameInfo(self.repository, self.cache) self.gacha = GachaService(self.repository, self.cache) self.admin = AdminService(self.repository, self.cache) - self.template = TemplateService() + self.template = TemplateService(browser) class StartService(BaseService): - def __init__(self, async_repository: AsyncRepository, async_cache: RedisCache): - super().__init__(async_repository, async_cache) + def __init__(self, mysql: MySQL, redis: RedisDB, browser: AioBrowser): + super().__init__(mysql, redis, browser) diff --git a/service/cache.py b/service/cache.py index 995902c2..0b876390 100644 --- a/service/cache.py +++ b/service/cache.py @@ -1,74 +1,45 @@ -import asyncio from typing import List import ujson -from redis import asyncio as aioredis -from logger import Log from service.base import QuestionData, AnswerData +from utils.redisdb import RedisDB class RedisCache: - def __init__(self, host="127.0.0.1", port=6379, db=0, loop=None): - self._loop = asyncio.get_event_loop() - # Redis 官方文档显示 默认创建POOL连接池 - Log.debug(f'获取Redis配置 [host]: {host}') - Log.debug(f'获取Redis配置 [host]: {port}') - Log.debug(f'获取Redis配置 [host]: {db}') - self.rdb = aioredis.Redis(host=host, port=port, db=db) - self.ttl = 600 - self.key_prefix = "paimon_bot" - self._loop = loop - if self._loop is None: - self._loop = asyncio.get_event_loop() - try: - Log.info("正在尝试建立与Redis连接") - self._loop.run_until_complete(self.ping()) - except (KeyboardInterrupt, SystemExit): - pass - except Exception as exc: - Log.error("尝试连接Redis失败 \n") - raise exc - - async def ping(self): - if await self.rdb.ping(): - Log.info("连接Redis成功") - else: - Log.info("连接Redis失败") - raise RuntimeError("连接Redis失败") - - async def close(self): - await self.rdb.close() + def __init__(self, redis: RedisDB): + self.client = redis.client async def get_chat_admin(self, char_id: int): qname = f"group:admin_list:{char_id}" - return [int(str_id) for str_id in await self.rdb.lrange(qname, 0, -1)] + return [int(str_id) for str_id in await self.client.lrange(qname, 0, -1)] async def set_chat_admin(self, char_id: int, admin_list: List[int]): qname = f"group:admin_list:{char_id}" - await self.rdb.ltrim(qname, 1, 0) - await self.rdb.lpush(qname, *admin_list) - await self.rdb.expire(qname, 60) - count = await self.rdb.llen(qname) + await self.client.ltrim(qname, 1, 0) + await self.client.lpush(qname, *admin_list) + await self.client.expire(qname, 60) + count = await self.client.llen(qname) return count async def get_all_question(self) -> List[str]: qname = "quiz:question" - data_list = [qname + f":{question_id}" for question_id in await self.rdb.lrange(qname + "id_list", 0, -1)] - return await self.rdb.mget(data_list) + data_list = [qname + f":{question_id}" for question_id in + await self.client.lrange(qname + "id_list", 0, -1)] + return await self.client.mget(data_list) async def get_all_question_id_list(self) -> List[str]: qname = "quiz:question:id_list" - return await self.rdb.lrange(qname, 0, -1) + return await self.client.lrange(qname, 0, -1) async def get_one_question(self, question_id: int) -> str: qname = f"quiz:question:{question_id}" - return await self.rdb.get(qname) + return await self.client.get(qname) async def get_one_answer(self, answer_id: int) -> str: qname = f"quiz:answer:{answer_id}" - return await self.rdb.get(qname) + return await self.client.get(qname) async def set_question(self, question_list: List[QuestionData] = None): qname = "quiz:question" @@ -82,25 +53,25 @@ class RedisCache: return ujson.dumps(data) for question in question_list: - await self.rdb.set(qname + f":{question.question_id}", json_dumps(question)) + await self.client.set(qname + f":{question.question_id}", json_dumps(question)) question_id_list = [question.question_id for question in question_list] - await self.rdb.lpush(qname + ":id_list", *question_id_list) - return await self.rdb.llen(qname + ":id_list") + await self.client.lpush(qname + ":id_list", *question_id_list) + return await self.client.llen(qname + ":id_list") async def del_all_question(self, answer_list: List[AnswerData] = None): qname = "quiz:question" - keys = await self.rdb.keys(qname + "*") + keys = await self.client.keys(qname + "*") if keys is not None: for key in keys: - await self.rdb.delete(key) + await self.client.delete(key) async def del_all_answer(self, answer_list: List[AnswerData] = None): qname = "quiz:answer" - keys = await self.rdb.keys(qname + "*") + keys = await self.client.keys(qname + "*") if keys is not None: for key in keys: - await self.rdb.delete(key) + await self.client.delete(key) async def set_answer(self, answer_list: List[AnswerData] = None): qname = "quiz:answer" @@ -109,30 +80,30 @@ class RedisCache: return ujson.dumps(obj=_answer.__dict__) for answer in answer_list: - await self.rdb.set(qname + f":{answer.answer_id}", json_dumps(answer)) + await self.client.set(qname + f":{answer.answer_id}", json_dumps(answer)) answer_id_list = [answer.answer_id for answer in answer_list] - await self.rdb.lpush(qname + ":id_list", *answer_id_list) - return await self.rdb.llen(qname + ":id_list") + await self.client.lpush(qname + ":id_list", *answer_id_list) + return await self.client.llen(qname + ":id_list") async def get_str_list(self, qname: str): - return [str(str_data, encoding="utf-8") for str_data in await self.rdb.lrange(qname, 0, -1)] + return [str(str_data, encoding="utf-8") for str_data in await self.client.lrange(qname, 0, -1)] async def set_str_list(self, qname: str, str_list: List[str], ttl: int = 60): - await self.rdb.ltrim(qname, 1, 0) - await self.rdb.lpush(qname, *str_list) + await self.client.ltrim(qname, 1, 0) + await self.client.lpush(qname, *str_list) if ttl != -1: - await self.rdb.expire(qname, ttl) - count = await self.rdb.llen(qname) + await self.client.expire(qname, ttl) + count = await self.client.llen(qname) return count async def get_int_list(self, qname: str): - return [int(str_data) for str_data in await self.rdb.lrange(qname, 0, -1)] + return [int(str_data) for str_data in await self.client.lrange(qname, 0, -1)] async def set_int_list(self, qname: str, str_list: List[int], ttl: int = 60): - await self.rdb.ltrim(qname, 1, 0) - await self.rdb.lpush(qname, *str_list) + await self.client.ltrim(qname, 1, 0) + await self.client.lpush(qname, *str_list) if ttl != -1: - await self.rdb.expire(qname, ttl) - count = await self.rdb.llen(qname) + await self.client.expire(qname, ttl) + count = await self.client.llen(qname) return count diff --git a/service/repository.py b/service/repository.py index 825be4c4..c28aa19d 100644 --- a/service/repository.py +++ b/service/repository.py @@ -1,78 +1,14 @@ -import asyncio from typing import List -import aiomysql - -from logger import Log from model.base import ServiceEnum from service.base import CreateUserInfoDBDataFromSQLData, UserInfoData, CreatCookieDictFromSQLData, \ CreatQuestionFromSQLData, QuestionData, AnswerData, CreatAnswerFromSQLData +from utils.mysql import MySQL class AsyncRepository: - def __init__(self, mysql_host: str = "127.0.0.1", mysql_port: int = 3306, mysql_user: str = "root", - mysql_password: str = "", mysql_database: str = "", loop=None): - self._mysql_database = mysql_database - self._mysql_password = mysql_password - self._mysql_user = mysql_user - self._mysql_port = mysql_port - self._mysql_host = mysql_host - self._loop = loop - self._sql_pool = None - Log.debug(f'获取数据库配置 [host]: {self._mysql_host}') - Log.debug(f'获取数据库配置 [port]: {self._mysql_port}') - Log.debug(f'获取数据库配置 [user]: {self._mysql_user}') - Log.debug(f'获取数据库配置 [password][len]: {len(self._mysql_password)}') - Log.debug(f'获取数据库配置 [db]: {self._mysql_database}') - if self._loop is None: - self._loop = asyncio.get_event_loop() - try: - Log.info("正在创建数据库LOOP") - self._loop.run_until_complete(self.create_pool()) - Log.info("创建数据库LOOP成功") - except (KeyboardInterrupt, SystemExit): - pass - except Exception as exc: - Log.error("创建数据库LOOP发生严重错误") - raise exc - - async def create_pool(self): - self._sql_pool = await aiomysql.create_pool( - host=self._mysql_host, port=self._mysql_port, - user=self._mysql_user, password=self._mysql_password, - db=self._mysql_database, loop=self._loop) - - async def wait_closed(self): - if self._sql_pool is None: - return - pool = self._sql_pool - pool.close() - await pool.wait_closed() - - async def _get_pool(self): - if self._sql_pool is None: - raise RuntimeError("mysql pool is none") - return self._sql_pool - - async def _executemany(self, query, query_args): - pool = await self._get_pool() - async with pool.acquire() as conn: - sql_cur = await conn.cursor() - await sql_cur.executemany(query, query_args) - rowcount = sql_cur.rowcount - await sql_cur.close() - await conn.commit() - return rowcount - - async def _execute_and_fetchall(self, query, query_args): - pool = await self._get_pool() - async with pool.acquire() as conn: - sql_cur = await conn.cursor() - await sql_cur.execute(query, query_args) - result = await sql_cur.fetchall() - await sql_cur.close() - await conn.commit() - return result + def __init__(self, mysql: MySQL): + self.mysql = mysql async def update_cookie(self, user_id: int, cookie: str, service: ServiceEnum): if service == ServiceEnum.MIHOYOBBS: @@ -90,7 +26,7 @@ class AsyncRepository: else: query = "" query_args = (cookie, user_id) - await self._execute_and_fetchall(query, query_args) + await self.mysql.execute_and_fetchall(query, query_args) async def set_cookie(self, user_id: int, cookie: str, service: ServiceEnum): if service == ServiceEnum.MIHOYOBBS: @@ -114,7 +50,7 @@ class AsyncRepository: else: raise ValueError() query_args = (user_id, cookie) - await self._execute_and_fetchall(query, query_args) + await self.mysql.execute_and_fetchall(query, query_args) async def set_user_info(self, user_id: int, mihoyo_game_uid: int, hoyoverse_game_uid: int, service: int): query = """ @@ -128,7 +64,7 @@ class AsyncRepository: service=VALUES(service); """ query_args = (user_id, mihoyo_game_uid, hoyoverse_game_uid, service) - await self._execute_and_fetchall(query, query_args) + await self.mysql.execute_and_fetchall(query, query_args) async def read_mihoyo_cookie(self, user_id) -> dict: query = """ @@ -137,7 +73,7 @@ class AsyncRepository: WHERE user_id=%s; """ query_args = (user_id,) - data = await self._execute_and_fetchall(query, query_args) + data = await self.mysql.execute_and_fetchall(query, query_args) if len(data) == 0: return {} return CreatCookieDictFromSQLData(data[0]) @@ -149,7 +85,7 @@ class AsyncRepository: WHERE user_id=%s; """ query_args = (user_id,) - data = await self._execute_and_fetchall(query, query_args) + data = await self.mysql.execute_and_fetchall(query, query_args) if len(data) == 0: return {} return CreatCookieDictFromSQLData(data[0]) @@ -161,7 +97,7 @@ class AsyncRepository: WHERE user_id=%s; """ query_args = (user_id,) - data = await self._execute_and_fetchall(query, query_args) + data = await self.mysql.execute_and_fetchall(query, query_args) if len(data) == 0: return UserInfoData() return CreateUserInfoDBDataFromSQLData(data[0]) @@ -172,7 +108,7 @@ class AsyncRepository: FROM `question` """ query_args = () - data = await self._execute_and_fetchall(query, query_args) + data = await self.mysql.execute_and_fetchall(query, query_args) return CreatQuestionFromSQLData(data) async def read_answer_form_question_id(self, question_id: int) -> List[AnswerData]: @@ -182,7 +118,7 @@ class AsyncRepository: WHERE question_id=%s; """ query_args = (question_id,) - data = await self._execute_and_fetchall(query, query_args) + data = await self.mysql.execute_and_fetchall(query, query_args) return CreatAnswerFromSQLData(data) async def save_question(self, question: str): @@ -193,7 +129,7 @@ class AsyncRepository: (%s) """ query_args = (question,) - await self._execute_and_fetchall(query, query_args) + await self.mysql.execute_and_fetchall(query, query_args) async def read_question(self, question: str) -> QuestionData: query = """ @@ -202,7 +138,7 @@ class AsyncRepository: WHERE question=%s; """ query_args = (question,) - data = await self._execute_and_fetchall(query, query_args) + data = await self.mysql.execute_and_fetchall(query, query_args) return CreatQuestionFromSQLData(data)[0] async def save_answer(self, question_id: int, is_correct: int, answer: str): @@ -213,7 +149,7 @@ class AsyncRepository: (%s,%s,%s) """ query_args = (question_id, is_correct, answer) - await self._execute_and_fetchall(query, query_args) + await self.mysql.execute_and_fetchall(query, query_args) async def delete_question(self, question_id: int): query = """ @@ -221,7 +157,7 @@ class AsyncRepository: WHERE id=%s; """ query_args = (question_id,) - await self._execute_and_fetchall(query, query_args) + await self.mysql.execute_and_fetchall(query, query_args) async def delete_answer(self, answer_id: int): query = """ @@ -229,7 +165,7 @@ class AsyncRepository: WHERE id=%s; """ query_args = (answer_id,) - await self._execute_and_fetchall(query, query_args) + await self.mysql.execute_and_fetchall(query, query_args) async def delete_admin(self, user_id: int): query = """ @@ -237,7 +173,7 @@ class AsyncRepository: WHERE user_id=%s; """ query_args = (user_id,) - await self._execute_and_fetchall(query, query_args) + await self.mysql.execute_and_fetchall(query, query_args) async def add_admin(self, user_id: int): query = """ @@ -247,7 +183,7 @@ class AsyncRepository: (%s) """ query_args = (user_id,) - await self._execute_and_fetchall(query, query_args) + await self.mysql.execute_and_fetchall(query, query_args) async def get_admin(self) -> List[int]: query = """ @@ -255,7 +191,7 @@ class AsyncRepository: FROM `admin` """ query_args = () - data = await self._execute_and_fetchall(query, query_args) + data = await self.mysql.execute_and_fetchall(query, query_args) if len(data) == 0: return [] return list(data[0]) diff --git a/service/template.py b/service/template.py index 7e594fd0..49d6818c 100644 --- a/service/template.py +++ b/service/template.py @@ -1,64 +1,24 @@ -import asyncio import os import time -from typing import Optional from jinja2 import PackageLoader, Environment, Template -from playwright.async_api import async_playwright, Browser, ViewportSize, Playwright +from playwright.async_api import ViewportSize from config import config from logger import Log +from utils.aiobrowser import AioBrowser class TemplateService: - def __init__(self, template_package_name: str = "resources", cache_dir_name: str = "cache", loop=None): + def __init__(self, browser: AioBrowser, template_package_name: str = "resources", cache_dir_name: str = "cache"): + self.browser = browser self._template_package_name = template_package_name - self._browser: Optional[Browser] = None - self._playwright: Optional[Playwright] = None self._current_dir = os.getcwd() self._output_dir = os.path.join(self._current_dir, cache_dir_name) if not os.path.exists(self._output_dir): os.mkdir(self._output_dir) self._jinja2_env = {} self._jinja2_template = {} - self._loop = loop - if self._loop is None: - self._loop = asyncio.get_event_loop() - try: - Log.info("正在尝试启动Playwright") - self._loop.run_until_complete(self._browser_init()) - Log.info("启动Playwright成功") - except (KeyboardInterrupt, SystemExit): - pass - except Exception as exc: - Log.error("启动浏览器失败 \n") - raise exc - - async def _browser_init(self) -> Browser: - if self._playwright is None: - self._playwright = await async_playwright().start() - try: - self._browser = await self._playwright.chromium.launch(timeout=5000) - except TimeoutError as err: - raise err - else: - if self._browser is None: - try: - self._browser = await self._playwright.chromium.launch(timeout=10000) - except TimeoutError as err: - raise err - return self._browser - - async def close(self): - if self._browser is not None: - await self._browser.close() - if self._playwright is not None: - await self._playwright.stop() - - async def get_browser(self) -> Browser: - if self._browser is None: - return await self._browser_init() - return self._browser def get_template(self, package_path: str, template_name: str) -> Template: if config.DEBUG: diff --git a/utils/aiobrowser.py b/utils/aiobrowser.py new file mode 100644 index 00000000..14147e32 --- /dev/null +++ b/utils/aiobrowser.py @@ -0,0 +1,51 @@ +import asyncio +import os +from typing import Optional + +from playwright.async_api import async_playwright, Browser, Playwright + +from logger import Log + + +class AioBrowser: + def __init__(self, loop=None): + self.browser: Optional[Browser] = None + self._playwright: Optional[Playwright] = None + self._loop = loop + if self._loop is None: + self._loop = asyncio.get_event_loop() + try: + Log.info("正在尝试启动Playwright") + self._loop.run_until_complete(self._browser_init()) + Log.info("启动Playwright成功") + except (KeyboardInterrupt, SystemExit): + pass + except Exception as exc: + Log.error("启动浏览器失败 \n") + raise exc + + async def _browser_init(self) -> Browser: + if self._playwright is None: + self._playwright = await async_playwright().start() + try: + self.browser = await self._playwright.chromium.launch(timeout=5000) + except TimeoutError as err: + raise err + else: + if self.browser is None: + try: + self.browser = await self._playwright.chromium.launch(timeout=10000) + except TimeoutError as err: + raise err + return self.browser + + async def close(self): + if self.browser is not None: + await self.browser.close() + if self._playwright is not None: + await self._playwright.stop() + + async def get_browser(self) -> Browser: + if self.browser is None: + raise RuntimeError("browser is not None") + return self.browser diff --git a/utils/mysql.py b/utils/mysql.py new file mode 100644 index 00000000..3d56babb --- /dev/null +++ b/utils/mysql.py @@ -0,0 +1,71 @@ +import asyncio + +import aiomysql + +from logger import Log + + +class MySQL: + def __init__(self, host: str = "127.0.0.1", port: int = 3306, user: str = "root", + password: str = "", database: str = "", loop=None): + self.database = database + self.password = password + self.user = user + self.port = port + self.host = host + self._loop = loop + self._sql_pool = None + Log.debug(f'获取数据库配置 [host]: {self.host}') + Log.debug(f'获取数据库配置 [port]: {self.port}') + Log.debug(f'获取数据库配置 [user]: {self.user}') + Log.debug(f'获取数据库配置 [password][len]: {len(self.password)}') + Log.debug(f'获取数据库配置 [db]: {self.database}') + if self._loop is None: + self._loop = asyncio.get_event_loop() + try: + Log.info("正在创建数据库LOOP") + self._loop.run_until_complete(self.create_pool()) + Log.info("创建数据库LOOP成功") + except (KeyboardInterrupt, SystemExit): + pass + except Exception as exc: + Log.error("创建数据库LOOP发生严重错误") + raise exc + + async def wait_closed(self): + if self._sql_pool is None: + return + pool = self._sql_pool + pool.close() + await pool.wait_closed() + + async def create_pool(self): + self._sql_pool = await aiomysql.create_pool( + host=self.host, port=self.port, + user=self.user, password=self.password, + db=self.database, loop=self._loop) + + async def _get_pool(self): + if self._sql_pool is None: + raise RuntimeError("mysql pool is none") + return self._sql_pool + + async def executemany(self, query, query_args): + pool = await self._get_pool() + async with pool.acquire() as conn: + sql_cur = await conn.cursor() + await sql_cur.executemany(query, query_args) + rowcount = sql_cur.rowcount + await sql_cur.close() + await conn.commit() + return rowcount + + async def execute_and_fetchall(self, query, query_args): + pool = await self._get_pool() + async with pool.acquire() as conn: + sql_cur = await conn.cursor() + await sql_cur.execute(query, query_args) + result = await sql_cur.fetchall() + await sql_cur.close() + await conn.commit() + return result diff --git a/utils/redisdb.py b/utils/redisdb.py new file mode 100644 index 00000000..4bb6315e --- /dev/null +++ b/utils/redisdb.py @@ -0,0 +1,36 @@ +import asyncio + +from logger import Log +from redis import asyncio as aioredis + + +class RedisDB: + + def __init__(self, host="127.0.0.1", port=6379, db=0, loop=None): + Log.debug(f'获取Redis配置 [host]: {host}') + Log.debug(f'获取Redis配置 [host]: {port}') + Log.debug(f'获取Redis配置 [host]: {db}') + self.client = aioredis.Redis(host=host, port=port, db=db) + self.ttl = 600 + self.key_prefix = "paimon_bot" + self._loop = loop + if self._loop is None: + self._loop = asyncio.get_event_loop() + try: + Log.info("正在尝试建立与Redis连接") + self._loop.run_until_complete(self.ping()) + except (KeyboardInterrupt, SystemExit): + pass + except Exception as exc: + Log.error("尝试连接Redis失败 \n") + raise exc + + async def ping(self): + if await self.client.ping(): + Log.info("连接Redis成功") + else: + Log.info("连接Redis失败") + raise RuntimeError("连接Redis失败") + + async def close(self): + await self.client.close()