From a89c56f6a93e0a38d7c2640f3ad9b3c45586d7ed Mon Sep 17 00:00:00 2001 From: xtaodada Date: Mon, 22 Jan 2024 23:13:55 +0800 Subject: [PATCH] feat: sticker export --- defs/sticker_download.py | 194 ++++++++++++++++++++++++++++++++++++ modules/sticker_download.py | 116 +++++++++++++++++++++ requirements.txt | 1 + 3 files changed, 311 insertions(+) create mode 100644 defs/sticker_download.py create mode 100644 modules/sticker_download.py diff --git a/defs/sticker_download.py b/defs/sticker_download.py new file mode 100644 index 0000000..7378967 --- /dev/null +++ b/defs/sticker_download.py @@ -0,0 +1,194 @@ +import asyncio +import contextlib +import os +import shutil +import tempfile +import zipfile +from concurrent.futures import ThreadPoolExecutor +from sys import executable +from typing import TYPE_CHECKING, Union + +from pathlib import Path + +import aiofiles +from pyrogram.enums import MessageEntityType +from pyrogram.filters import create as create_filter +from pyrogram.file_id import FileType, FileId +from pyrogram.raw.functions.messages import GetStickerSet +from pyrogram.raw.types import InputStickerSetShortName +from pyrogram.raw.types.messages import StickerSet +from pyrogram.types import Message, Sticker + +from init import logs + +if TYPE_CHECKING: + from pyrogram import Client + +temp_path = Path("data/cache") +temp_path.mkdir(parents=True, exist_ok=True) +lottie_path = Path(executable).with_name("lottie_convert.py") + + +async def _custom_emoji_filter(_, __, message: Message): + entities = message.entities or message.caption_entities + if not entities: + return False + for entity in entities: + if entity.type == MessageEntityType.CUSTOM_EMOJI: + return True + return False + + +custom_emoji_filter = create_filter(_custom_emoji_filter) + + +def get_target_file_path(src: Path) -> Path: + old_ext = src.suffix + if old_ext in [".jpeg", ".jpg", ".png", ".webp"]: + return src.with_suffix(".png") + elif old_ext in [".mp4", ".webm", ".tgs"]: + return src.with_suffix(".gif") + else: + return src.with_suffix(".mp4") + + +async def converter(src_file: Union[Path, str]) -> Path: + src_file = Path(src_file) + target_file = get_target_file_path(src_file) + if src_file.suffix == ".tgs": + process = await asyncio.create_subprocess_exec( + executable, + lottie_path, + src_file, + target_file, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + else: + process = await asyncio.create_subprocess_exec( + "ffmpeg", + "-i", + src_file, + target_file, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + _, stderr = await process.communicate() + + if process.returncode == 0: + src_file.unlink(missing_ok=True) + else: + logs.error("转换 %s -> %s 时出错: %s", src_file.name, target_file.name, stderr.decode("utf-8")) + raise ValueError + return target_file + + +def get_file_id(doc, set_id, set_hash) -> FileId: + return FileId( + file_type=FileType.STICKER, + dc_id=doc.dc_id, + file_reference=doc.file_reference, + media_id=doc.id, + access_hash=doc.access_hash, + sticker_set_id=set_id, + sticker_set_access_hash=set_hash, + ) + + +def get_ext_from_mime(mime: str) -> str: + if mime == "image/jpeg": + return ".jpg" + elif mime == "image/png": + return ".png" + elif mime == "image/webp": + return ".webp" + elif mime == "video/mp4": + return ".mp4" + elif mime == "video/webm": + return ".webm" + elif mime == "application/x-tgsticker": + return ".tgs" + else: + return "" + + +def zip_dir(dir_path: str, zip_filepath: Path): + zipf = zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) + for root, dirs, files in os.walk(dir_path): + for file in files: + file_path = Path(root).joinpath(file) + file_name = file_path.relative_to(dir_path) + zipf.write(file_path, file_name) + zipf.close() + + +async def run_zip_dir(dir_path: str, zip_filepath: Path): + loop = asyncio.get_event_loop() + with ThreadPoolExecutor() as executor: + await loop.run_in_executor( + executor, zip_dir, dir_path, zip_filepath, + ) + + +async def edit_message(reply: "Message", text: str) -> "Message": + with contextlib.suppress(Exception): + return await reply.edit(text) + + +async def get_from_sticker_set(short_name: str, uid: int, client: "Client", reply: "Message") -> Path: + s = InputStickerSetShortName(short_name=short_name) + packs: "StickerSet" = await client.invoke(GetStickerSet(stickerset=s, hash=0)) + tempdir = tempfile.mkdtemp() + logs.info("下载贴纸包 %s", short_name) + for doc in packs.documents: + file_id = get_file_id(doc, packs.set.id, packs.set.access_hash) + ext = get_ext_from_mime(doc.mime_type) + file_path = Path(tempdir) / f"{doc.id}{ext}" + async with aiofiles.open(file_path, "wb") as file: + async for chunk in client.get_file(file_id): + await file.write(chunk) + logs.info("转换贴纸包 %s", short_name) + await edit_message(reply, "正在转换贴纸包...请耐心等待") + for f in Path(tempdir).glob("*"): + await converter(f) + logs.info("打包贴纸包 %s", short_name) + await edit_message(reply, "正在打包贴纸包...请耐心等待") + zip_file_path = temp_path / f"{uid}_{short_name}.zip" + await run_zip_dir(tempdir, zip_file_path) + shutil.rmtree(tempdir) + logs.info("发送贴纸包 %s", short_name) + await edit_message(reply, "正在发送贴纸包...请耐心等待") + return zip_file_path + + +async def get_from_sticker(client: "Client", message: "Message") -> Path: + sticker_path = temp_path / f"{message.sticker.file_unique_id}.webp" + await client.download_media(message, file_name=sticker_path.as_posix()) + return await converter(sticker_path) + + +async def get_from_custom_emoji(client: "Client", sticker: "Sticker") -> Path: + sticker_path = temp_path / f"{sticker.file_unique_id}.webp" + await client.download_media(sticker.file_id, file_name=sticker_path.as_posix()) + return await converter(sticker_path) + + +async def export_add(tempdir: str, sticker: Sticker, client: "Client"): + file_id = sticker.file_id + file_unique_id = sticker.file_unique_id + ext = sticker.file_name.split(".")[-1] + filepath: "Path" = Path(tempdir).joinpath(f"{file_unique_id}.{ext}") + await client.download_media(file_id, file_name=filepath.as_posix()) + await converter(filepath) + + +async def export_end(uid: int, tempdir: str, reply: "Message") -> Path: + if not Path(tempdir).glob("*"): + raise FileNotFoundError + logs.info("打包 %s 的批量导出的贴纸包", uid) + zip_file_path = temp_path / f"{uid}.zip" + await run_zip_dir(tempdir, zip_file_path) + shutil.rmtree(tempdir) + logs.info("发送 %s 的批量导出的贴纸包", uid) + await edit_message(reply, "正在发送贴纸包...请耐心等待") + return zip_file_path diff --git a/modules/sticker_download.py b/modules/sticker_download.py new file mode 100644 index 0000000..4e648a1 --- /dev/null +++ b/modules/sticker_download.py @@ -0,0 +1,116 @@ +import contextlib +import tempfile +from typing import TYPE_CHECKING + +from cashews import cache +from pyrogram import filters, ContinuePropagation +from pyrogram.enums import ChatAction, MessageEntityType +from pyrogram.errors import StickersetInvalid + +from defs.sticker_download import get_from_sticker_set, get_from_sticker, custom_emoji_filter, \ + get_from_custom_emoji, export_end, export_add +from init import bot + +if TYPE_CHECKING: + from pyrogram import Client + from pyrogram.types import Message + + +@bot.on_message( + filters.private & filters.text & filters.incoming & filters.regex(r"^https://t.me/addstickers/.*") +) +async def process_sticker_set(client: "Client", message: "Message"): + cid = message.from_user.id + short_name = message.text.replace("https://t.me/addstickers/", "") + file = None + try: + reply = await message.reply("正在下载贴纸包...请耐心等待", quote=True) + file = await get_from_sticker_set(short_name, cid, client, reply) + with contextlib.suppress(Exception): + await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT) + await message.reply_document(file.as_posix(), quote=True) + with contextlib.suppress(Exception): + await reply.delete() + except StickersetInvalid: + await message.reply("无效的贴纸包", quote=True) + finally: + if file: + file.unlink(missing_ok=True) + raise ContinuePropagation + + +@bot.on_message(filters.private & filters.sticker & filters.incoming) +async def process_single_sticker(client: "Client", message: "Message"): + await message.reply_chat_action(ChatAction.TYPING) + if temp := await cache.get(f"sticker:export:{message.from_user.id}"): + await export_add(temp, message.sticker, client) + await message.reply_text("成功加入导出列表,结束选择请输入 /sticker_export_end", quote=True) + else: + reply = await message.reply("正在转换贴纸...请耐心等待", quote=True) + target_file = None + try: + target_file = await get_from_sticker(client, message) + await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT) + await message.reply_document(target_file.as_posix(), quote=True) + finally: + if target_file: + target_file.unlink(missing_ok=True) + with contextlib.suppress(Exception): + await reply.delete() + raise ContinuePropagation + + +@bot.on_message(filters.private & custom_emoji_filter & filters.incoming) +async def process_custom_emoji(client: "Client", message: "Message"): + try: + stickers = await client.get_custom_emoji_stickers( + [i.custom_emoji_id for i in message.entities if i and i.type == MessageEntityType.CUSTOM_EMOJI] + ) + except AttributeError: + await message.reply("无法获取贴纸", quote=True) + raise ContinuePropagation + reply = await message.reply(f"正在下载 {len(stickers)} 个 emoji ...请耐心等待", quote=True) + for sticker in stickers: + target_file = None + try: + target_file = await get_from_custom_emoji(client, sticker) + await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT) + await message.reply_document(target_file.as_posix(), quote=True) + finally: + if target_file: + target_file.unlink(missing_ok=True) + with contextlib.suppress(Exception): + await reply.delete() + raise ContinuePropagation + + +@bot.on_message( + filters.private & filters.incoming & filters.command( + ["sticker_export_start", "sticker_export_end"] + ) +) +async def batch_start(_: "Client", message: "Message"): + uid = message.from_user.id + if "start" in message.command[0].lower(): + if await cache.get(f"sticker:export:{uid}"): + await message.reply("已经开始批量导出贴纸,请直接发送贴纸,完成选择请输入 /sticker_export_end", quote=True) + return + await cache.set(f"sticker:export:{uid}", tempfile.mkdtemp()) + await message.reply("开始批量导出贴纸,请直接发送贴纸,完成选择请输入 /sticker_export_end", quote=True) + else: + target_dir = await cache.get(f"sticker:export:{uid}") + if not target_dir: + await message.reply("未开始批量导出贴纸,请先使用命令 /sticker_export_start", quote=True) + return + file = None + try: + reply = await message.reply("正在打包贴纸包...请耐心等待", quote=True) + file = await export_end(uid, target_dir, reply) + await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT) + await message.reply_document(file.as_posix()) + except FileNotFoundError: + await message.reply("没有选择贴纸,导出失败", quote=True) + finally: + await cache.delete(f"sticker:export:{uid}") + if file: + file.unlink(missing_ok=True) diff --git a/requirements.txt b/requirements.txt index 8e53876..540139f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,4 @@ sqlmodel aiosqlite aiofiles pydantic +lottie