feat: sticker export

This commit is contained in:
xtaodada 2024-01-22 23:13:55 +08:00
parent bd62630a7f
commit a89c56f6a9
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
3 changed files with 311 additions and 0 deletions

194
defs/sticker_download.py Normal file
View File

@ -0,0 +1,194 @@
import asyncio
import contextlib
import os
import shutil
import tempfile
import zipfile
from concurrent.futures import ThreadPoolExecutor
from sys import executable
from typing import TYPE_CHECKING, Union
from pathlib import Path
import aiofiles
from pyrogram.enums import MessageEntityType
from pyrogram.filters import create as create_filter
from pyrogram.file_id import FileType, FileId
from pyrogram.raw.functions.messages import GetStickerSet
from pyrogram.raw.types import InputStickerSetShortName
from pyrogram.raw.types.messages import StickerSet
from pyrogram.types import Message, Sticker
from init import logs
if TYPE_CHECKING:
from pyrogram import Client
temp_path = Path("data/cache")
temp_path.mkdir(parents=True, exist_ok=True)
lottie_path = Path(executable).with_name("lottie_convert.py")
async def _custom_emoji_filter(_, __, message: Message):
entities = message.entities or message.caption_entities
if not entities:
return False
for entity in entities:
if entity.type == MessageEntityType.CUSTOM_EMOJI:
return True
return False
custom_emoji_filter = create_filter(_custom_emoji_filter)
def get_target_file_path(src: Path) -> Path:
old_ext = src.suffix
if old_ext in [".jpeg", ".jpg", ".png", ".webp"]:
return src.with_suffix(".png")
elif old_ext in [".mp4", ".webm", ".tgs"]:
return src.with_suffix(".gif")
else:
return src.with_suffix(".mp4")
async def converter(src_file: Union[Path, str]) -> Path:
src_file = Path(src_file)
target_file = get_target_file_path(src_file)
if src_file.suffix == ".tgs":
process = await asyncio.create_subprocess_exec(
executable,
lottie_path,
src_file,
target_file,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
else:
process = await asyncio.create_subprocess_exec(
"ffmpeg",
"-i",
src_file,
target_file,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
_, stderr = await process.communicate()
if process.returncode == 0:
src_file.unlink(missing_ok=True)
else:
logs.error("转换 %s -> %s 时出错: %s", src_file.name, target_file.name, stderr.decode("utf-8"))
raise ValueError
return target_file
def get_file_id(doc, set_id, set_hash) -> FileId:
return FileId(
file_type=FileType.STICKER,
dc_id=doc.dc_id,
file_reference=doc.file_reference,
media_id=doc.id,
access_hash=doc.access_hash,
sticker_set_id=set_id,
sticker_set_access_hash=set_hash,
)
def get_ext_from_mime(mime: str) -> str:
if mime == "image/jpeg":
return ".jpg"
elif mime == "image/png":
return ".png"
elif mime == "image/webp":
return ".webp"
elif mime == "video/mp4":
return ".mp4"
elif mime == "video/webm":
return ".webm"
elif mime == "application/x-tgsticker":
return ".tgs"
else:
return ""
def zip_dir(dir_path: str, zip_filepath: Path):
zipf = zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED)
for root, dirs, files in os.walk(dir_path):
for file in files:
file_path = Path(root).joinpath(file)
file_name = file_path.relative_to(dir_path)
zipf.write(file_path, file_name)
zipf.close()
async def run_zip_dir(dir_path: str, zip_filepath: Path):
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as executor:
await loop.run_in_executor(
executor, zip_dir, dir_path, zip_filepath,
)
async def edit_message(reply: "Message", text: str) -> "Message":
with contextlib.suppress(Exception):
return await reply.edit(text)
async def get_from_sticker_set(short_name: str, uid: int, client: "Client", reply: "Message") -> Path:
s = InputStickerSetShortName(short_name=short_name)
packs: "StickerSet" = await client.invoke(GetStickerSet(stickerset=s, hash=0))
tempdir = tempfile.mkdtemp()
logs.info("下载贴纸包 %s", short_name)
for doc in packs.documents:
file_id = get_file_id(doc, packs.set.id, packs.set.access_hash)
ext = get_ext_from_mime(doc.mime_type)
file_path = Path(tempdir) / f"{doc.id}{ext}"
async with aiofiles.open(file_path, "wb") as file:
async for chunk in client.get_file(file_id):
await file.write(chunk)
logs.info("转换贴纸包 %s", short_name)
await edit_message(reply, "正在转换贴纸包...请耐心等待")
for f in Path(tempdir).glob("*"):
await converter(f)
logs.info("打包贴纸包 %s", short_name)
await edit_message(reply, "正在打包贴纸包...请耐心等待")
zip_file_path = temp_path / f"{uid}_{short_name}.zip"
await run_zip_dir(tempdir, zip_file_path)
shutil.rmtree(tempdir)
logs.info("发送贴纸包 %s", short_name)
await edit_message(reply, "正在发送贴纸包...请耐心等待")
return zip_file_path
async def get_from_sticker(client: "Client", message: "Message") -> Path:
sticker_path = temp_path / f"{message.sticker.file_unique_id}.webp"
await client.download_media(message, file_name=sticker_path.as_posix())
return await converter(sticker_path)
async def get_from_custom_emoji(client: "Client", sticker: "Sticker") -> Path:
sticker_path = temp_path / f"{sticker.file_unique_id}.webp"
await client.download_media(sticker.file_id, file_name=sticker_path.as_posix())
return await converter(sticker_path)
async def export_add(tempdir: str, sticker: Sticker, client: "Client"):
file_id = sticker.file_id
file_unique_id = sticker.file_unique_id
ext = sticker.file_name.split(".")[-1]
filepath: "Path" = Path(tempdir).joinpath(f"{file_unique_id}.{ext}")
await client.download_media(file_id, file_name=filepath.as_posix())
await converter(filepath)
async def export_end(uid: int, tempdir: str, reply: "Message") -> Path:
if not Path(tempdir).glob("*"):
raise FileNotFoundError
logs.info("打包 %s 的批量导出的贴纸包", uid)
zip_file_path = temp_path / f"{uid}.zip"
await run_zip_dir(tempdir, zip_file_path)
shutil.rmtree(tempdir)
logs.info("发送 %s 的批量导出的贴纸包", uid)
await edit_message(reply, "正在发送贴纸包...请耐心等待")
return zip_file_path

116
modules/sticker_download.py Normal file
View File

@ -0,0 +1,116 @@
import contextlib
import tempfile
from typing import TYPE_CHECKING
from cashews import cache
from pyrogram import filters, ContinuePropagation
from pyrogram.enums import ChatAction, MessageEntityType
from pyrogram.errors import StickersetInvalid
from defs.sticker_download import get_from_sticker_set, get_from_sticker, custom_emoji_filter, \
get_from_custom_emoji, export_end, export_add
from init import bot
if TYPE_CHECKING:
from pyrogram import Client
from pyrogram.types import Message
@bot.on_message(
filters.private & filters.text & filters.incoming & filters.regex(r"^https://t.me/addstickers/.*")
)
async def process_sticker_set(client: "Client", message: "Message"):
cid = message.from_user.id
short_name = message.text.replace("https://t.me/addstickers/", "")
file = None
try:
reply = await message.reply("正在下载贴纸包...请耐心等待", quote=True)
file = await get_from_sticker_set(short_name, cid, client, reply)
with contextlib.suppress(Exception):
await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT)
await message.reply_document(file.as_posix(), quote=True)
with contextlib.suppress(Exception):
await reply.delete()
except StickersetInvalid:
await message.reply("无效的贴纸包", quote=True)
finally:
if file:
file.unlink(missing_ok=True)
raise ContinuePropagation
@bot.on_message(filters.private & filters.sticker & filters.incoming)
async def process_single_sticker(client: "Client", message: "Message"):
await message.reply_chat_action(ChatAction.TYPING)
if temp := await cache.get(f"sticker:export:{message.from_user.id}"):
await export_add(temp, message.sticker, client)
await message.reply_text("成功加入导出列表,结束选择请输入 /sticker_export_end", quote=True)
else:
reply = await message.reply("正在转换贴纸...请耐心等待", quote=True)
target_file = None
try:
target_file = await get_from_sticker(client, message)
await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT)
await message.reply_document(target_file.as_posix(), quote=True)
finally:
if target_file:
target_file.unlink(missing_ok=True)
with contextlib.suppress(Exception):
await reply.delete()
raise ContinuePropagation
@bot.on_message(filters.private & custom_emoji_filter & filters.incoming)
async def process_custom_emoji(client: "Client", message: "Message"):
try:
stickers = await client.get_custom_emoji_stickers(
[i.custom_emoji_id for i in message.entities if i and i.type == MessageEntityType.CUSTOM_EMOJI]
)
except AttributeError:
await message.reply("无法获取贴纸", quote=True)
raise ContinuePropagation
reply = await message.reply(f"正在下载 {len(stickers)} 个 emoji ...请耐心等待", quote=True)
for sticker in stickers:
target_file = None
try:
target_file = await get_from_custom_emoji(client, sticker)
await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT)
await message.reply_document(target_file.as_posix(), quote=True)
finally:
if target_file:
target_file.unlink(missing_ok=True)
with contextlib.suppress(Exception):
await reply.delete()
raise ContinuePropagation
@bot.on_message(
filters.private & filters.incoming & filters.command(
["sticker_export_start", "sticker_export_end"]
)
)
async def batch_start(_: "Client", message: "Message"):
uid = message.from_user.id
if "start" in message.command[0].lower():
if await cache.get(f"sticker:export:{uid}"):
await message.reply("已经开始批量导出贴纸,请直接发送贴纸,完成选择请输入 /sticker_export_end", quote=True)
return
await cache.set(f"sticker:export:{uid}", tempfile.mkdtemp())
await message.reply("开始批量导出贴纸,请直接发送贴纸,完成选择请输入 /sticker_export_end", quote=True)
else:
target_dir = await cache.get(f"sticker:export:{uid}")
if not target_dir:
await message.reply("未开始批量导出贴纸,请先使用命令 /sticker_export_start", quote=True)
return
file = None
try:
reply = await message.reply("正在打包贴纸包...请耐心等待", quote=True)
file = await export_end(uid, target_dir, reply)
await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT)
await message.reply_document(file.as_posix())
except FileNotFoundError:
await message.reply("没有选择贴纸,导出失败", quote=True)
finally:
await cache.delete(f"sticker:export:{uid}")
if file:
file.unlink(missing_ok=True)

View File

@ -18,3 +18,4 @@ sqlmodel
aiosqlite
aiofiles
pydantic
lottie