From c26016561a37af6f0480234003c97d883ee6ecdd Mon Sep 17 00:00:00 2001 From: omg-xtao <100690902+omg-xtao@users.noreply.github.com> Date: Sat, 18 Feb 2023 15:41:10 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Post=20support=20gif?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apihelper/client/components/hyperion.py | 7 +++-- modules/apihelper/models/genshin/hyperion.py | 26 +++++++++++++------ plugins/other/post.py | 9 +++---- tests/test_hyperion_bbs.py | 14 ++++++++++ 4 files changed, 41 insertions(+), 15 deletions(-) diff --git a/modules/apihelper/client/components/hyperion.py b/modules/apihelper/client/components/hyperion.py index e1e80fa8..ccecc003 100644 --- a/modules/apihelper/client/components/hyperion.py +++ b/modules/apihelper/client/components/hyperion.py @@ -118,8 +118,11 @@ class Hyperion: return art_list async def download_image(self, art_id: int, url: str, page: int = 0) -> ArtworkImage: - response = await self.client.get(url, params=self.get_images_params(resize=2000), timeout=10, de_json=False) - return ArtworkImage(art_id=art_id, page=page, data=response.content) + image = url.endswith(".jpg") or url.endswith(".png") + response = await self.client.get( + url, params=self.get_images_params(resize=2000) if image else None, timeout=10, de_json=False + ) + return ArtworkImage(art_id=art_id, page=page, ext=url.split(".")[-1], data=response.content) async def get_new_list(self, gids: int, type_id: int, page_size: int = 20): """ diff --git a/modules/apihelper/models/genshin/hyperion.py b/modules/apihelper/models/genshin/hyperion.py index a224cd44..a396e229 100644 --- a/modules/apihelper/models/genshin/hyperion.py +++ b/modules/apihelper/models/genshin/hyperion.py @@ -1,23 +1,31 @@ import imghdr -from typing import List, Any +from typing import List, Any, Union from pydantic import BaseModel, PrivateAttr __all__ = ("ArtworkImage", "PostInfo") +from telegram import InputMediaPhoto, InputMediaVideo, InputMediaDocument + class ArtworkImage(BaseModel): art_id: int page: int = 0 data: bytes = b"" + ext: str = "jpg" is_error: bool = False @property def format(self) -> str: - if self.is_error: - return "" - else: - imghdr.what(None, self.data) + return "" if self.is_error else (imghdr.what(None, self.data) or self.ext) + + def input_media(self, *args, **kwargs) -> Union[None, InputMediaDocument, InputMediaPhoto, InputMediaVideo]: + file_type = self.format + if file_type in {"jpg", "jpeg", "png", "webp"}: + return InputMediaPhoto(self.data, *args, **kwargs) + if file_type in {"gif", "mp4", "mov", "avi", "mkv", "webm", "flv"}: + return InputMediaVideo(self.data, *args, **kwargs) + return InputMediaDocument(self.data, *args, **kwargs) class PostInfo(BaseModel): @@ -26,6 +34,7 @@ class PostInfo(BaseModel): user_uid: int subject: str image_urls: List[str] + video_urls: List[str] created_at: int def __init__(self, _data: dict, **data: Any): @@ -34,14 +43,14 @@ class PostInfo(BaseModel): @classmethod def paste_data(cls, data: dict) -> "PostInfo": - image_urls = [] _data_post = data["post"] post = _data_post["post"] post_id = post["post_id"] subject = post["subject"] image_list = _data_post["image_list"] - for image in image_list: - image_urls.append(image["url"]) + image_urls = [image["url"] for image in image_list] + vod_list = _data_post["vod_list"] + video_urls = [vod["resolutions"][-1]["url"] for vod in vod_list] created_at = post["created_at"] user = _data_post["user"] # 用户数据 user_uid = user["uid"] # 用户ID @@ -51,6 +60,7 @@ class PostInfo(BaseModel): user_uid=user_uid, subject=subject, image_urls=image_urls, + video_urls=video_urls, created_at=created_at, ) diff --git a/plugins/other/post.py b/plugins/other/post.py index 24594331..d0d7468e 100644 --- a/plugins/other/post.py +++ b/plugins/other/post.py @@ -5,7 +5,6 @@ from telegram import ( Update, ReplyKeyboardMarkup, ReplyKeyboardRemove, - InputMediaPhoto, InlineKeyboardButton, InlineKeyboardMarkup, Message, @@ -210,8 +209,8 @@ class Post(Plugin.Conversation, BasePlugin.Conversation): await message.reply_text(f"警告!图片字符描述已经超过 {MessageLimit.CAPTION_LENGTH} 个字,已经切割") try: if len(post_images) > 1: - media = [InputMediaPhoto(img_info.data) for img_info in post_images] - media[0] = InputMediaPhoto(post_images[0].data, caption=post_text, parse_mode=ParseMode.MARKDOWN_V2) + media = [img_info.input_media() for img_info in post_images if img_info.format] + media[0] = post_images[0].input_media(caption=post_text, parse_mode=ParseMode.MARKDOWN_V2) if len(media) > 10: media = media[:10] await message.reply_text("获取到的图片已经超过10张,为了保证发送成功,已经删除一部分图片") @@ -388,8 +387,8 @@ class Post(Plugin.Conversation, BasePlugin.Conversation): post_text += f" \\#{tag}" try: if len(post_images) > 1: - media = [InputMediaPhoto(img_info.data) for img_info in post_images] - media[0] = InputMediaPhoto(post_images[0].data, caption=post_text, parse_mode=ParseMode.MARKDOWN_V2) + media = [img_info.input_media() for img_info in post_images if img_info.format] + media[0] = post_images[0].input_media(caption=post_text, parse_mode=ParseMode.MARKDOWN_V2) await context.bot.send_media_group(channel_id, media=media) elif len(post_images) == 1: image = post_images[0] diff --git a/tests/test_hyperion_bbs.py b/tests/test_hyperion_bbs.py index b0e6d825..54eb5863 100644 --- a/tests/test_hyperion_bbs.py +++ b/tests/test_hyperion_bbs.py @@ -37,6 +37,20 @@ async def test_get_post_info(hyperion): assert post_soup.find_all("p") +# noinspection PyShadowingNames +@pytest.mark.asyncio +@flaky(3, 1) +async def test_get_video_post_info(hyperion): + post_info = await hyperion.get_post_info(2, 33846648) + assert post_info + assert isinstance(post_info, PostInfo) + assert post_info["post"]["post"]["post_id"] == "33846648" + assert post_info.post_id == 33846648 + assert post_info["post"]["post"]["subject"] == "当然是原神了" + assert post_info.subject == "当然是原神了" + assert len(post_info.video_urls) == 1 + + # noinspection PyShadowingNames @pytest.mark.asyncio @flaky(3, 1)