🐛 Fix reposted post by using redis cache

This commit is contained in:
xtaodada 2024-10-31 19:04:38 +08:00
parent 48bdfde9c0
commit 7b602d503f
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659

View File

@ -3,7 +3,7 @@ import os
import re import re
from asyncio import create_subprocess_shell, subprocess from asyncio import create_subprocess_shell, subprocess
from functools import partial from functools import partial
from typing import List, Optional, Tuple, TYPE_CHECKING, Union, Dict from typing import List, Optional, Tuple, TYPE_CHECKING, Union
import aiofiles import aiofiles
from arkowrapper import ArkoWrapper from arkowrapper import ArkoWrapper
@ -26,6 +26,7 @@ from telegram.helpers import escape_markdown
from core.config import config from core.config import config
from core.plugin import Plugin, conversation, handler from core.plugin import Plugin, conversation, handler
from gram_core.basemodel import Settings from gram_core.basemodel import Settings
from gram_core.dependence.redisdb import RedisDB
from modules.apihelper.client.components.hoyolab import Hoyolab from modules.apihelper.client.components.hoyolab import Hoyolab
from modules.apihelper.client.components.hyperion import Hyperion, HyperionBase from modules.apihelper.client.components.hyperion import Hyperion, HyperionBase
from modules.apihelper.error import APIHelperException from modules.apihelper.error import APIHelperException
@ -69,12 +70,28 @@ class Post(Plugin.Conversation):
[["推送频道", "添加TAG"], ["编辑文字", "删除图片"], ["添加视频", "退出"]], True, True [["推送频道", "添加TAG"], ["编辑文字", "删除图片"], ["添加视频", "退出"]], True, True
) )
def __init__(self): def __init__(self, redis: RedisDB):
self.gids = 6 self.gids = 6
self.short_name = "sr" self.short_name = "sr"
self.last_post_id_list: Dict[PostTypeEnum, List[int]] = {PostTypeEnum.CN: [], PostTypeEnum.OS: []}
self.ffmpeg_enable = False self.ffmpeg_enable = False
self.cache_dir = os.path.join(os.getcwd(), "cache") self.cache_dir = os.path.join(os.getcwd(), "cache")
self.cache = redis.client
self.cache_key = "plugin:post:pushed"
def get_cache_key(self, bbs_type: "PostTypeEnum") -> str:
return f"{self.cache_key}:{bbs_type.value}"
async def is_posted(self, bbs_type: "PostTypeEnum", post_id: int) -> bool:
key = self.get_cache_key(bbs_type)
return await self.cache.sismember(key, post_id)
async def set_posted(self, bbs_type: "PostTypeEnum", post_id: int) -> bool:
key = self.get_cache_key(bbs_type)
return await self.cache.sadd(key, post_id)
async def is_posted_empty(self, bbs_type: "PostTypeEnum") -> bool:
key = self.get_cache_key(bbs_type)
return await self.cache.scard(key) == 0
@staticmethod @staticmethod
def get_bbs_client(bbs_type: "PostTypeEnum") -> "HyperionBase": def get_bbs_client(bbs_type: "PostTypeEnum") -> "HyperionBase":
@ -106,7 +123,6 @@ class Post(Plugin.Conversation):
async def task(self, context: "ContextTypes.DEFAULT_TYPE", post_type: "PostTypeEnum"): async def task(self, context: "ContextTypes.DEFAULT_TYPE", post_type: "PostTypeEnum"):
bbs = self.get_bbs_client(post_type) bbs = self.get_bbs_client(post_type)
temp_post_id_list: List[int] = []
# 请求推荐POST列表并处理 # 请求推荐POST列表并处理
try: try:
@ -115,24 +131,19 @@ class Post(Plugin.Conversation):
logger.error("获取首页推荐信息失败 %s", str(exc)) logger.error("获取首页推荐信息失败 %s", str(exc))
return return
for data_list in official_recommended_posts: temp_post_id_list = [post.post_id for post in official_recommended_posts]
temp_post_id_list.append(data_list.post_id)
last_post_id_list = self.last_post_id_list[post_type]
# 判断是否为空 # 判断是否为空
if len(last_post_id_list) == 0: if self.is_posted_empty(post_type):
for temp_list in temp_post_id_list: for temp_list in temp_post_id_list:
last_post_id_list.append(temp_list) await self.set_posted(post_type, temp_list)
return return
# 筛选出新推送的文章 # 筛选出新推送的文章
last_post_id_list = self.last_post_id_list[post_type] new_post_id_list = [post_id for post_id in temp_post_id_list if not await self.is_posted(post_type, post_id)]
new_post_id_list = set(temp_post_id_list).difference(set(last_post_id_list))
if not new_post_id_list: if not new_post_id_list:
return return
self.last_post_id_list[post_type] = temp_post_id_list
chat_id = post_config.chat_id or config.owner chat_id = post_config.chat_id or config.owner
for post_id in new_post_id_list: for post_id in new_post_id_list:
@ -163,6 +174,7 @@ class Post(Plugin.Conversation):
parse_mode=ParseMode.HTML, parse_mode=ParseMode.HTML,
reply_markup=InlineKeyboardMarkup(buttons), reply_markup=InlineKeyboardMarkup(buttons),
) )
await self.set_posted(post_type, post_id)
except BadRequest as exc: except BadRequest as exc:
logger.error("发送消息失败 %s", exc.message) logger.error("发送消息失败 %s", exc.message)
await bbs.close() await bbs.close()