From a02022d3a3f1a5a89d109e0aeea2c74676783de9 Mon Sep 17 00:00:00 2001 From: xtaodada Date: Wed, 9 Aug 2023 15:10:54 +0800 Subject: [PATCH] feat: support spoiler --- alembic/versions/3cbe5fbdb7e3_config.py | 36 ++++++++++++++ defs/misskey.py | 37 ++++++++++++--- defs/web_app.py | 26 ++++++++++ misskey_init.py | 16 +++++-- models/models/user_config.py | 11 +++++ models/services/user_config.py | 36 ++++++++++++++ modules/start.py | 2 + modules/user_config.py | 63 +++++++++++++++++++++++++ 8 files changed, 216 insertions(+), 11 deletions(-) create mode 100644 alembic/versions/3cbe5fbdb7e3_config.py create mode 100644 defs/web_app.py create mode 100644 models/models/user_config.py create mode 100644 models/services/user_config.py create mode 100644 modules/user_config.py diff --git a/alembic/versions/3cbe5fbdb7e3_config.py b/alembic/versions/3cbe5fbdb7e3_config.py new file mode 100644 index 0000000..39e751b --- /dev/null +++ b/alembic/versions/3cbe5fbdb7e3_config.py @@ -0,0 +1,36 @@ +"""config + +Revision ID: 3cbe5fbdb7e3 +Revises: fcdaa7ac5975 +Create Date: 2023-08-09 14:36:48.093192 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "3cbe5fbdb7e3" +down_revision = "fcdaa7ac5975" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "user_config", + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("timeline_spoiler", sa.Boolean(), nullable=False), + sa.Column("push_spoiler", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("user_id"), + mysql_charset="utf8mb4", + mysql_collate="utf8mb4_general_ci", + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("user_config") + # ### end Alembic commands ### diff --git a/defs/misskey.py b/defs/misskey.py index c5d5157..b20ff56 100644 --- a/defs/misskey.py +++ b/defs/misskey.py @@ -202,6 +202,7 @@ async def send_photo( note: Note, reply_to_message_id: int, show_second: bool, + spoiler: bool, ) -> Message: if not url: return await send_text(host, cid, note, reply_to_message_id, show_second) @@ -213,6 +214,7 @@ async def send_photo( reply_markup=gen_button( host, note, get_user_link(host, note.author), show_second ), + has_spoiler=spoiler, ) @@ -225,6 +227,7 @@ async def send_gif( note: Note, reply_to_message_id: int, show_second: bool, + spoiler: bool, ) -> Message: if not url: return await send_text(host, cid, note, reply_to_message_id, show_second) @@ -236,6 +239,7 @@ async def send_gif( reply_markup=gen_button( host, note, get_user_link(host, note.author), show_second ), + has_spoiler=spoiler, ) @@ -248,6 +252,7 @@ async def send_video( note: Note, reply_to_message_id: int, show_second: bool, + spoiler: bool, ) -> Message: if not url: return await send_text(host, cid, note, reply_to_message_id, show_second) @@ -259,6 +264,7 @@ async def send_video( reply_markup=gen_button( host, note, get_user_link(host, note.author), show_second ), + has_spoiler=spoiler, ) @@ -311,7 +317,7 @@ async def send_document( return msg -async def get_media_group(host: str, files: list[File]) -> list: +async def get_media_group(host: str, files: list[File], spoiler: bool) -> list: media_lists = [] for file_ in files: file_url = await fetch_document(host, file_) @@ -324,6 +330,7 @@ async def get_media_group(host: str, files: list[File]) -> list: InputMediaAnimation( file_url, parse_mode=ParseMode.HTML, + has_spoiler=file_.is_sensitive and spoiler, ) ) else: @@ -331,6 +338,7 @@ async def get_media_group(host: str, files: list[File]) -> list: InputMediaPhoto( file_url, parse_mode=ParseMode.HTML, + has_spoiler=file_.is_sensitive and spoiler, ) ) elif file_type.startswith("video"): @@ -338,6 +346,7 @@ async def get_media_group(host: str, files: list[File]) -> list: InputMediaVideo( file_url, parse_mode=ParseMode.HTML, + has_spoiler=file_.is_sensitive and spoiler, ) ) elif file_type.startswith("audio"): @@ -380,8 +389,9 @@ async def send_group( note: Note, reply_to_message_id: int, show_second: bool, + spoiler: bool, ) -> List[Message]: - groups = await get_media_group(host, files) + groups = await get_media_group(host, files, spoiler) if len(groups) == 0: return [await send_text(host, cid, note, reply_to_message_id, show_second)] photo, video, audio, document, msg_ids = [], [], [], [], [] @@ -405,7 +415,12 @@ async def send_group( async def send_update( - host: str, cid: int, note: Note, topic_id: Optional[int], show_second: bool + host: str, + cid: int, + note: Note, + topic_id: Optional[int], + show_second: bool, + spoiler: bool, ) -> Message | list[Message]: files = list(note.files) if note.reply: @@ -421,13 +436,21 @@ async def send_update( url = await fetch_document(host, file) if file_type.startswith("image"): if "gif" in file_type: - return await send_gif(host, cid, url, note, topic_id, show_second) - return await send_photo(host, cid, url, note, topic_id, show_second) + return await send_gif( + host, cid, url, note, topic_id, show_second, spoiler + ) + return await send_photo( + host, cid, url, note, topic_id, show_second, spoiler + ) elif file_type.startswith("video"): - return await send_video(host, cid, url, note, topic_id, show_second) + return await send_video( + host, cid, url, note, topic_id, show_second, spoiler + ) elif file_type.startswith("audio"): return await send_audio(host, cid, url, note, topic_id, show_second) else: return await send_document(host, cid, url, note, topic_id, show_second) case _: - return await send_group(host, cid, files, note, topic_id, show_second) + return await send_group( + host, cid, files, note, topic_id, show_second, spoiler + ) diff --git a/defs/web_app.py b/defs/web_app.py new file mode 100644 index 0000000..73ca62b --- /dev/null +++ b/defs/web_app.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel +from pyrogram import filters +from pyrogram.types import Message + + +class WebAppUserConfig(BaseModel): + timeline_spoiler: bool + push_spoiler: bool + + +class WebAppData(BaseModel): + path: str + data: dict + code: int + message: str + + @property + def user_config(self) -> WebAppUserConfig: + return WebAppUserConfig(**self.data) + + +async def web_data_filter(_, __, m: Message): + return bool(m.web_app_data) + + +filter_web_data = filters.create(web_data_filter) diff --git a/misskey_init.py b/misskey_init.py index e73dae1..e6b9579 100644 --- a/misskey_init.py +++ b/misskey_init.py @@ -41,23 +41,27 @@ from defs.notice import ( ) from models.models.user import User, TokenStatusEnum +from models.models.user_config import UserConfig from models.services.no_repeat_renote import NoRepeatRenoteAction from models.services.revoke import RevokeAction from models.services.user import UserAction from init import bot, logs, sqlite +from models.services.user_config import UserConfigAction class MisskeyBot(commands.Bot): - def __init__(self, user: User): + def __init__(self, user: User, user_config: UserConfig): super().__init__() self._BotBase__on_error = self.__on_error self.user_id: int = user.user_id self.instance_user_id: str = user.instance_user_id self.tg_user: User = user + self.user_config: UserConfig = user_config self.lock = Lock() async def fetch_offline_notes(self): + return logs.info(f"{self.tg_user.user_id} 开始获取最近十条时间线") data = {"withReplies": False, "limit": 10} data = await self.core.http.request( @@ -110,6 +114,8 @@ class MisskeyBot(commands.Bot): note, self.tg_user.timeline_topic, True, + spoiler=self.user_config + and self.user_config.timeline_spoiler, ) await RevokeAction.push(self.tg_user.user_id, note.id, msgs) if self.check_push(note): @@ -119,6 +125,7 @@ class MisskeyBot(commands.Bot): note, None, False, + spoiler=self.user_config and self.user_config.push_spoiler, ) await RevokeAction.push(self.tg_user.user_id, note.id, msgs) elif notice: @@ -217,14 +224,15 @@ def get_misskey_bot(user_id: int) -> Optional[MisskeyBot]: return None if user_id not in misskey_bot_map else misskey_bot_map[user_id] -async def create_or_get_misskey_bot(user: User) -> MisskeyBot: +async def create_or_get_misskey_bot(user: User, user_config: UserConfig) -> MisskeyBot: if user.user_id not in misskey_bot_map: - misskey_bot_map[user.user_id] = MisskeyBot(user) + misskey_bot_map[user.user_id] = MisskeyBot(user, user_config) return misskey_bot_map[user.user_id] async def run(user: User): - misskey = await create_or_get_misskey_bot(user) + user_config = await UserConfigAction.get_user_config_by_id(user.user_id) + misskey = await create_or_get_misskey_bot(user, user_config) try: logs.info(f"尝试启动 Misskey Bot WS 任务 {user.user_id}") await misskey.start(f"wss://{user.host}/streaming", user.token, log_level=None) diff --git a/models/models/user_config.py b/models/models/user_config.py new file mode 100644 index 0000000..aeeab95 --- /dev/null +++ b/models/models/user_config.py @@ -0,0 +1,11 @@ +import sqlalchemy as sa +from sqlmodel import SQLModel, Field, Column + + +class UserConfig(SQLModel, table=True): + __tablename__ = "user_config" + __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") + + user_id: int = Field(sa_column=Column(sa.BigInteger, primary_key=True)) + timeline_spoiler: bool = Field(default=False) + push_spoiler: bool = Field(default=False) diff --git a/models/services/user_config.py b/models/services/user_config.py new file mode 100644 index 0000000..6ab96a9 --- /dev/null +++ b/models/services/user_config.py @@ -0,0 +1,36 @@ +from typing import cast, Optional + +from sqlalchemy import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from init import sqlite +from models.models.user_config import UserConfig + + +class UserConfigAction: + @staticmethod + async def add_user_config(user_config: UserConfig): + async with sqlite.session() as session: + session = cast(AsyncSession, session) + session.add(user_config) + await session.commit() + + @staticmethod + async def get_user_config_by_id(user_id: int) -> Optional[UserConfig]: + async with sqlite.session() as session: + session = cast(AsyncSession, session) + statement = select(UserConfig).where(UserConfig.user_id == user_id) + results = await session.exec(statement) + return user[0] if (user := results.first()) else None + + @staticmethod + async def update_user_config(user_config: UserConfig): + async with sqlite.session() as session: + session = cast(AsyncSession, session) + session.add(user_config) + await session.commit() + await session.refresh(user_config) + + @staticmethod + def create_user_config(user_id: int) -> UserConfig: + return UserConfig(user_id=user_id) diff --git a/modules/start.py b/modules/start.py index e11287c..8839b18 100644 --- a/modules/start.py +++ b/modules/start.py @@ -19,6 +19,8 @@ des = f"""欢迎使用 {bot.me.first_name},这是一个用于在 Telegram 上 5. [可选] 在私聊中使用 `/bind_push [对话id]` 绑定本人发帖时推送 /unbind_push 解除绑定 +6. [可选] 在私聊中使用 `/config` 设置敏感媒体是否自动设置 Spoiler + 至此,你便可以在 Telegram 接收 Misskey 消息,同时你可以私聊我使用 /status 查看 Bot 运行状态 此 Bot 仅支持 Misskey V13 实例的账号!""" diff --git a/modules/user_config.py b/modules/user_config.py new file mode 100644 index 0000000..0d9ffc2 --- /dev/null +++ b/modules/user_config.py @@ -0,0 +1,63 @@ +import base64 +import json + +from pydantic import ValidationError +from pyrogram import Client, filters +from pyrogram.types import ( + Message, + ReplyKeyboardMarkup, + KeyboardButton, + WebAppInfo, + ReplyKeyboardRemove, +) + +from defs.web_app import WebAppData, WebAppUserConfig, filter_web_data +from glover import web_domain +from misskey_init import rerun_misskey_bot +from models.services.user_config import UserConfigAction + + +@Client.on_message(filters.incoming & filters.private & filter_web_data) +async def process_user_config(_, message: Message): + try: + data = WebAppData(**json.loads(message.web_app_data.data)).user_config + except (json.JSONDecodeError, ValidationError): + await message.reply("数据解析失败,请重试。", quote=True) + return + if user_config := await UserConfigAction.get_user_config_by_id( + message.from_user.id + ): + user_config.timeline_spoiler = data.timeline_spoiler + user_config.push_spoiler = data.push_spoiler + await UserConfigAction.update_user_config(user_config) + else: + user_config = UserConfigAction.create_user_config(message.from_user.id) + user_config.timeline_spoiler = data.timeline_spoiler + user_config.push_spoiler = data.push_spoiler + await UserConfigAction.add_user_config(user_config) + await message.reply("更新设置成功。", quote=True, reply_markup=ReplyKeyboardRemove()) + await rerun_misskey_bot(message.from_user.id) + + +async def get_user_config(user_id: int) -> str: + if user_config := await UserConfigAction.get_user_config_by_id(user_id): + data = WebAppUserConfig( + timeline_spoiler=user_config.timeline_spoiler, + push_spoiler=user_config.push_spoiler, + ).json() + else: + data = "{}" + return base64.b64encode(data.encode()).decode() + + +@Client.on_message(filters.incoming & filters.private & filters.command(["config"])) +async def notice_user_config(_, message: Message): + data = await get_user_config(message.from_user.id) + url = f"https://{web_domain}/config?bot_data={data}" + await message.reply( + "请点击下方按钮,开始设置。", + quote=True, + reply_markup=ReplyKeyboardMarkup( + [[KeyboardButton(text="web config", web_app=WebAppInfo(url=url))]] + ), + )