mirror of
https://github.com/Xtao-Labs/iShotaBot.git
synced 2024-11-24 09:15:51 +00:00
feat: sticker export
This commit is contained in:
parent
bd62630a7f
commit
a89c56f6a9
194
defs/sticker_download.py
Normal file
194
defs/sticker_download.py
Normal file
@ -0,0 +1,194 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from sys import executable
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
from pyrogram.enums import MessageEntityType
|
||||
from pyrogram.filters import create as create_filter
|
||||
from pyrogram.file_id import FileType, FileId
|
||||
from pyrogram.raw.functions.messages import GetStickerSet
|
||||
from pyrogram.raw.types import InputStickerSetShortName
|
||||
from pyrogram.raw.types.messages import StickerSet
|
||||
from pyrogram.types import Message, Sticker
|
||||
|
||||
from init import logs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyrogram import Client
|
||||
|
||||
temp_path = Path("data/cache")
|
||||
temp_path.mkdir(parents=True, exist_ok=True)
|
||||
lottie_path = Path(executable).with_name("lottie_convert.py")
|
||||
|
||||
|
||||
async def _custom_emoji_filter(_, __, message: Message):
|
||||
entities = message.entities or message.caption_entities
|
||||
if not entities:
|
||||
return False
|
||||
for entity in entities:
|
||||
if entity.type == MessageEntityType.CUSTOM_EMOJI:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
custom_emoji_filter = create_filter(_custom_emoji_filter)
|
||||
|
||||
|
||||
def get_target_file_path(src: Path) -> Path:
|
||||
old_ext = src.suffix
|
||||
if old_ext in [".jpeg", ".jpg", ".png", ".webp"]:
|
||||
return src.with_suffix(".png")
|
||||
elif old_ext in [".mp4", ".webm", ".tgs"]:
|
||||
return src.with_suffix(".gif")
|
||||
else:
|
||||
return src.with_suffix(".mp4")
|
||||
|
||||
|
||||
async def converter(src_file: Union[Path, str]) -> Path:
|
||||
src_file = Path(src_file)
|
||||
target_file = get_target_file_path(src_file)
|
||||
if src_file.suffix == ".tgs":
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
executable,
|
||||
lottie_path,
|
||||
src_file,
|
||||
target_file,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
else:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-i",
|
||||
src_file,
|
||||
target_file,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
_, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
src_file.unlink(missing_ok=True)
|
||||
else:
|
||||
logs.error("转换 %s -> %s 时出错: %s", src_file.name, target_file.name, stderr.decode("utf-8"))
|
||||
raise ValueError
|
||||
return target_file
|
||||
|
||||
|
||||
def get_file_id(doc, set_id, set_hash) -> FileId:
|
||||
return FileId(
|
||||
file_type=FileType.STICKER,
|
||||
dc_id=doc.dc_id,
|
||||
file_reference=doc.file_reference,
|
||||
media_id=doc.id,
|
||||
access_hash=doc.access_hash,
|
||||
sticker_set_id=set_id,
|
||||
sticker_set_access_hash=set_hash,
|
||||
)
|
||||
|
||||
|
||||
def get_ext_from_mime(mime: str) -> str:
|
||||
if mime == "image/jpeg":
|
||||
return ".jpg"
|
||||
elif mime == "image/png":
|
||||
return ".png"
|
||||
elif mime == "image/webp":
|
||||
return ".webp"
|
||||
elif mime == "video/mp4":
|
||||
return ".mp4"
|
||||
elif mime == "video/webm":
|
||||
return ".webm"
|
||||
elif mime == "application/x-tgsticker":
|
||||
return ".tgs"
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
def zip_dir(dir_path: str, zip_filepath: Path):
|
||||
zipf = zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED)
|
||||
for root, dirs, files in os.walk(dir_path):
|
||||
for file in files:
|
||||
file_path = Path(root).joinpath(file)
|
||||
file_name = file_path.relative_to(dir_path)
|
||||
zipf.write(file_path, file_name)
|
||||
zipf.close()
|
||||
|
||||
|
||||
async def run_zip_dir(dir_path: str, zip_filepath: Path):
|
||||
loop = asyncio.get_event_loop()
|
||||
with ThreadPoolExecutor() as executor:
|
||||
await loop.run_in_executor(
|
||||
executor, zip_dir, dir_path, zip_filepath,
|
||||
)
|
||||
|
||||
|
||||
async def edit_message(reply: "Message", text: str) -> "Message":
|
||||
with contextlib.suppress(Exception):
|
||||
return await reply.edit(text)
|
||||
|
||||
|
||||
async def get_from_sticker_set(short_name: str, uid: int, client: "Client", reply: "Message") -> Path:
|
||||
s = InputStickerSetShortName(short_name=short_name)
|
||||
packs: "StickerSet" = await client.invoke(GetStickerSet(stickerset=s, hash=0))
|
||||
tempdir = tempfile.mkdtemp()
|
||||
logs.info("下载贴纸包 %s", short_name)
|
||||
for doc in packs.documents:
|
||||
file_id = get_file_id(doc, packs.set.id, packs.set.access_hash)
|
||||
ext = get_ext_from_mime(doc.mime_type)
|
||||
file_path = Path(tempdir) / f"{doc.id}{ext}"
|
||||
async with aiofiles.open(file_path, "wb") as file:
|
||||
async for chunk in client.get_file(file_id):
|
||||
await file.write(chunk)
|
||||
logs.info("转换贴纸包 %s", short_name)
|
||||
await edit_message(reply, "正在转换贴纸包...请耐心等待")
|
||||
for f in Path(tempdir).glob("*"):
|
||||
await converter(f)
|
||||
logs.info("打包贴纸包 %s", short_name)
|
||||
await edit_message(reply, "正在打包贴纸包...请耐心等待")
|
||||
zip_file_path = temp_path / f"{uid}_{short_name}.zip"
|
||||
await run_zip_dir(tempdir, zip_file_path)
|
||||
shutil.rmtree(tempdir)
|
||||
logs.info("发送贴纸包 %s", short_name)
|
||||
await edit_message(reply, "正在发送贴纸包...请耐心等待")
|
||||
return zip_file_path
|
||||
|
||||
|
||||
async def get_from_sticker(client: "Client", message: "Message") -> Path:
|
||||
sticker_path = temp_path / f"{message.sticker.file_unique_id}.webp"
|
||||
await client.download_media(message, file_name=sticker_path.as_posix())
|
||||
return await converter(sticker_path)
|
||||
|
||||
|
||||
async def get_from_custom_emoji(client: "Client", sticker: "Sticker") -> Path:
|
||||
sticker_path = temp_path / f"{sticker.file_unique_id}.webp"
|
||||
await client.download_media(sticker.file_id, file_name=sticker_path.as_posix())
|
||||
return await converter(sticker_path)
|
||||
|
||||
|
||||
async def export_add(tempdir: str, sticker: Sticker, client: "Client"):
|
||||
file_id = sticker.file_id
|
||||
file_unique_id = sticker.file_unique_id
|
||||
ext = sticker.file_name.split(".")[-1]
|
||||
filepath: "Path" = Path(tempdir).joinpath(f"{file_unique_id}.{ext}")
|
||||
await client.download_media(file_id, file_name=filepath.as_posix())
|
||||
await converter(filepath)
|
||||
|
||||
|
||||
async def export_end(uid: int, tempdir: str, reply: "Message") -> Path:
|
||||
if not Path(tempdir).glob("*"):
|
||||
raise FileNotFoundError
|
||||
logs.info("打包 %s 的批量导出的贴纸包", uid)
|
||||
zip_file_path = temp_path / f"{uid}.zip"
|
||||
await run_zip_dir(tempdir, zip_file_path)
|
||||
shutil.rmtree(tempdir)
|
||||
logs.info("发送 %s 的批量导出的贴纸包", uid)
|
||||
await edit_message(reply, "正在发送贴纸包...请耐心等待")
|
||||
return zip_file_path
|
116
modules/sticker_download.py
Normal file
116
modules/sticker_download.py
Normal file
@ -0,0 +1,116 @@
|
||||
import contextlib
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from cashews import cache
|
||||
from pyrogram import filters, ContinuePropagation
|
||||
from pyrogram.enums import ChatAction, MessageEntityType
|
||||
from pyrogram.errors import StickersetInvalid
|
||||
|
||||
from defs.sticker_download import get_from_sticker_set, get_from_sticker, custom_emoji_filter, \
|
||||
get_from_custom_emoji, export_end, export_add
|
||||
from init import bot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyrogram import Client
|
||||
from pyrogram.types import Message
|
||||
|
||||
|
||||
@bot.on_message(
|
||||
filters.private & filters.text & filters.incoming & filters.regex(r"^https://t.me/addstickers/.*")
|
||||
)
|
||||
async def process_sticker_set(client: "Client", message: "Message"):
|
||||
cid = message.from_user.id
|
||||
short_name = message.text.replace("https://t.me/addstickers/", "")
|
||||
file = None
|
||||
try:
|
||||
reply = await message.reply("正在下载贴纸包...请耐心等待", quote=True)
|
||||
file = await get_from_sticker_set(short_name, cid, client, reply)
|
||||
with contextlib.suppress(Exception):
|
||||
await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT)
|
||||
await message.reply_document(file.as_posix(), quote=True)
|
||||
with contextlib.suppress(Exception):
|
||||
await reply.delete()
|
||||
except StickersetInvalid:
|
||||
await message.reply("无效的贴纸包", quote=True)
|
||||
finally:
|
||||
if file:
|
||||
file.unlink(missing_ok=True)
|
||||
raise ContinuePropagation
|
||||
|
||||
|
||||
@bot.on_message(filters.private & filters.sticker & filters.incoming)
|
||||
async def process_single_sticker(client: "Client", message: "Message"):
|
||||
await message.reply_chat_action(ChatAction.TYPING)
|
||||
if temp := await cache.get(f"sticker:export:{message.from_user.id}"):
|
||||
await export_add(temp, message.sticker, client)
|
||||
await message.reply_text("成功加入导出列表,结束选择请输入 /sticker_export_end", quote=True)
|
||||
else:
|
||||
reply = await message.reply("正在转换贴纸...请耐心等待", quote=True)
|
||||
target_file = None
|
||||
try:
|
||||
target_file = await get_from_sticker(client, message)
|
||||
await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT)
|
||||
await message.reply_document(target_file.as_posix(), quote=True)
|
||||
finally:
|
||||
if target_file:
|
||||
target_file.unlink(missing_ok=True)
|
||||
with contextlib.suppress(Exception):
|
||||
await reply.delete()
|
||||
raise ContinuePropagation
|
||||
|
||||
|
||||
@bot.on_message(filters.private & custom_emoji_filter & filters.incoming)
|
||||
async def process_custom_emoji(client: "Client", message: "Message"):
|
||||
try:
|
||||
stickers = await client.get_custom_emoji_stickers(
|
||||
[i.custom_emoji_id for i in message.entities if i and i.type == MessageEntityType.CUSTOM_EMOJI]
|
||||
)
|
||||
except AttributeError:
|
||||
await message.reply("无法获取贴纸", quote=True)
|
||||
raise ContinuePropagation
|
||||
reply = await message.reply(f"正在下载 {len(stickers)} 个 emoji ...请耐心等待", quote=True)
|
||||
for sticker in stickers:
|
||||
target_file = None
|
||||
try:
|
||||
target_file = await get_from_custom_emoji(client, sticker)
|
||||
await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT)
|
||||
await message.reply_document(target_file.as_posix(), quote=True)
|
||||
finally:
|
||||
if target_file:
|
||||
target_file.unlink(missing_ok=True)
|
||||
with contextlib.suppress(Exception):
|
||||
await reply.delete()
|
||||
raise ContinuePropagation
|
||||
|
||||
|
||||
@bot.on_message(
|
||||
filters.private & filters.incoming & filters.command(
|
||||
["sticker_export_start", "sticker_export_end"]
|
||||
)
|
||||
)
|
||||
async def batch_start(_: "Client", message: "Message"):
|
||||
uid = message.from_user.id
|
||||
if "start" in message.command[0].lower():
|
||||
if await cache.get(f"sticker:export:{uid}"):
|
||||
await message.reply("已经开始批量导出贴纸,请直接发送贴纸,完成选择请输入 /sticker_export_end", quote=True)
|
||||
return
|
||||
await cache.set(f"sticker:export:{uid}", tempfile.mkdtemp())
|
||||
await message.reply("开始批量导出贴纸,请直接发送贴纸,完成选择请输入 /sticker_export_end", quote=True)
|
||||
else:
|
||||
target_dir = await cache.get(f"sticker:export:{uid}")
|
||||
if not target_dir:
|
||||
await message.reply("未开始批量导出贴纸,请先使用命令 /sticker_export_start", quote=True)
|
||||
return
|
||||
file = None
|
||||
try:
|
||||
reply = await message.reply("正在打包贴纸包...请耐心等待", quote=True)
|
||||
file = await export_end(uid, target_dir, reply)
|
||||
await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT)
|
||||
await message.reply_document(file.as_posix())
|
||||
except FileNotFoundError:
|
||||
await message.reply("没有选择贴纸,导出失败", quote=True)
|
||||
finally:
|
||||
await cache.delete(f"sticker:export:{uid}")
|
||||
if file:
|
||||
file.unlink(missing_ok=True)
|
@ -18,3 +18,4 @@ sqlmodel
|
||||
aiosqlite
|
||||
aiofiles
|
||||
pydantic
|
||||
lottie
|
||||
|
Loading…
Reference in New Issue
Block a user