diff --git a/config.py b/config.py index 73c3a64..5c33d3e 100644 --- a/config.py +++ b/config.py @@ -145,6 +145,8 @@ class ApplicationConfig(Settings): channels: List[int] = [] """文章推送群组""" + channels_helper: Optional[int] = None + """消息帮助频道""" verify_groups: Set[int] = set() """启用群验证功能的群组""" diff --git a/plugin/methods/__init__.py b/plugin/methods/__init__.py index 41cc56a..8cf6f83 100644 --- a/plugin/methods/__init__.py +++ b/plugin/methods/__init__.py @@ -6,6 +6,7 @@ from .get_chat import GetChat from .get_real_uid_or_offset import GetRealUidOrOffset from .get_real_user_id import GetRealUserId from .get_real_user_name import GetRealUserName +from .inline_use_data import InlineUseData from .log_user import LogUser from .migrate_data import MigrateData @@ -19,6 +20,7 @@ class PluginFuncMethods( GetRealUidOrOffset, GetRealUserId, GetRealUserName, + InlineUseData, LogUser, MigrateData, ): diff --git a/plugin/methods/get_real_uid_or_offset.py b/plugin/methods/get_real_uid_or_offset.py index 78083e2..dc76622 100644 --- a/plugin/methods/get_real_uid_or_offset.py +++ b/plugin/methods/get_real_uid_or_offset.py @@ -7,20 +7,30 @@ if TYPE_CHECKING: REGEX = r"@(\d{10})|@(\d{9})|@(\d)" +def get_real_uid_or_offset_by_text(text: str) -> Tuple[Optional[int], Optional[int]]: + if matches := re.findall(REGEX, text): + if numbers := [int(num) for match in matches for num in match if num != ""]: + if 1 < numbers[0] < 10: + return None, numbers[0] - 1 + elif numbers[0] in [0, 1]: + return None, None + return numbers[0], None + return None, None + + class GetRealUidOrOffset: @staticmethod def get_real_uid_or_offset(update: "Update") -> Tuple[Optional[int], Optional[int]]: message = update.effective_message + ilq = update.inline_query + text = None if not message: - return None, None - text = message.text or message.caption + if ilq: + text = ilq.query + else: + return None, None + else: + text = message.text or message.caption if not text: return None, None - if matches := re.findall(REGEX, text): - if numbers := [int(num) for match in matches for num in match if num != ""]: - if 1 < numbers[0] < 10: - return None, numbers[0] - 1 - elif numbers[0] in [0, 1]: - return None, None - return numbers[0], None - return None, None + return get_real_uid_or_offset_by_text(text) diff --git a/plugin/methods/inline_use_data.py b/plugin/methods/inline_use_data.py new file mode 100644 index 0000000..3d62c79 --- /dev/null +++ b/plugin/methods/inline_use_data.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass +from typing import Optional, List + +from telegram.ext import ContextTypes +from telegram.ext._utils.types import HandlerCallback, UT, CCT, RT + + +@dataclass +class IInlineUseData: + """Inline 使用数据""" + + text: str + hash: str + callback: HandlerCallback[UT, CCT, RT] + cookie: bool = False + player: bool = False + UID_KEY: str = "inline_uid" + + def is_show(self, has_cookie: bool, has_player: bool) -> bool: + """是否显示""" + if self.cookie and not has_cookie: + return False + if self.player and not has_player: + return False + return True + + def get_button_callback_data(self, start: str) -> str: + """获取按钮数据""" + return f"{start}|{self.hash}" + + @staticmethod + def set_uid_to_context(context: "ContextTypes.DEFAULT_TYPE", uid: int) -> None: + """设置 UID 到 Update""" + user_data = context.user_data + user_data[IInlineUseData.UID_KEY] = uid + + @staticmethod + def get_uid_from_context(context: "ContextTypes.DEFAULT_TYPE") -> int: + """从 Update 中获取 UID""" + user_data = context.user_data + if not user_data: + return 0 + return user_data.get(IInlineUseData.UID_KEY, 0) + + +class InlineUseData: + async def get_inline_use_data(self) -> List[Optional[IInlineUseData]]: + """获取 Inline 使用数据""" + return [] diff --git a/services/template/models.py b/services/template/models.py index 2b5c9cd..1d27a48 100644 --- a/services/template/models.py +++ b/services/template/models.py @@ -1,9 +1,10 @@ from enum import Enum from typing import List, Optional, Union -from telegram import InputMediaDocument, InputMediaPhoto, Message +from telegram import InputMediaDocument, InputMediaPhoto, Message, CallbackQuery, Bot from telegram.error import BadRequest +from gram_core.config import config from gram_core.services.template.cache import HtmlToFileIdCache from gram_core.services.template.error import ErrorFileType, FileIdNotFound @@ -106,10 +107,40 @@ class RenderResult: return edit_media - async def cache_file_id(self, reply: Message): - """缓存 telegram 返回的 file_id""" + async def send_photo_to_helper_channel(self, bot: "Bot"): + try: + reply = await bot.send_photo(config.channels_helper, photo=self.photo) + except BadRequest as exc: + if "Wrong file identifier" in exc.message and isinstance(self.photo, str): + await self._cache.delete_data(self.html, self.file_type.name) + raise BadRequest(message="Wrong file identifier specified") + raise exc + + await self.cache_file_id(reply) + + return reply + + async def edit_inline_media(self, callback_query: CallbackQuery, *args, **kwargs) -> bool: + """是 `message.edit_media` 的封装,上传成功后,缓存 telegram 返回的 file_id,方便重复使用""" + if self.file_type != FileType.PHOTO: + raise ErrorFileType + bot = callback_query.get_bot() + + reply = await self.send_photo_to_helper_channel(bot) + file_id = self.get_file_id(reply) + media = InputMediaPhoto(media=file_id, caption=self.caption, parse_mode=self.parse_mode, filename=self.filename) + + try: + return await callback_query.edit_message_media(media, *args, **kwargs) + except BadRequest as exc: + if "Wrong file identifier" in exc.message and isinstance(self.photo, str): + await self._cache.delete_data(self.html, self.file_type.name) + raise BadRequest(message="Wrong file identifier specified") + raise exc + + def get_file_id(self, reply: Message) -> str: if self.is_file_id(): - return + return self.photo if self.file_type == FileType.PHOTO and reply.photo: file_id = reply.photo[0].file_id @@ -117,6 +148,13 @@ class RenderResult: file_id = reply.document.file_id else: raise FileIdNotFound + return file_id + + async def cache_file_id(self, reply: Message): + """缓存 telegram 返回的 file_id""" + if self.is_file_id(): + return + file_id = self.get_file_id(reply) await self._cache.set_data(self.html, self.file_type.name, file_id, self.ttl) def is_file_id(self) -> bool: