From 523ed3e7cb7a0e90e64bedd641973460fe7b4dfb Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Tue, 1 Jun 2021 13:57:31 +0200 Subject: [PATCH] Add support for in-memory uploads in send_media_group (#519) * Add support for in-memory uploads for send_media_group * update input_media_photo docs * update type hints Co-authored-by: Dan <14043624+delivrance@users.noreply.github.com> --- pyrogram/methods/advanced/save_file.py | 86 +++--- pyrogram/methods/messages/send_media_group.py | 251 ++++++++++++------ .../input_media/input_media_animation.py | 5 +- .../types/input_media/input_media_audio.py | 5 +- .../types/input_media/input_media_document.py | 5 +- .../types/input_media/input_media_video.py | 5 +- 6 files changed, 232 insertions(+), 125 deletions(-) diff --git a/pyrogram/methods/advanced/save_file.py b/pyrogram/methods/advanced/save_file.py index bc99b859..84361565 100644 --- a/pyrogram/methods/advanced/save_file.py +++ b/pyrogram/methods/advanced/save_file.py @@ -116,7 +116,7 @@ class SaveFile(Scaffold): else: raise ValueError("Invalid file. Expected a file path as string or a binary (not text) file pointer") - file_name = fp.name + file_name = getattr(fp, "name", "file.jpg") fp.seek(0, os.SEEK_END) file_size = fp.tell() @@ -148,53 +148,52 @@ class SaveFile(Scaffold): for session in pool: await session.start() - with fp: - fp.seek(part_size * file_part) + fp.seek(part_size * file_part) - while True: - chunk = fp.read(part_size) - - if not chunk: - if not is_big and not is_missing_part: - md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()]) - break - - if is_big: - rpc = raw.functions.upload.SaveBigFilePart( - file_id=file_id, - file_part=file_part, - file_total_parts=file_total_parts, - bytes=chunk - ) - else: - rpc = raw.functions.upload.SaveFilePart( - file_id=file_id, - file_part=file_part, - bytes=chunk - ) - - await queue.put(rpc) - - if is_missing_part: - return + while True: + chunk = fp.read(part_size) + if not chunk: if not is_big and not is_missing_part: - md5_sum.update(chunk) + md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()]) + break - file_part += 1 + if is_big: + rpc = raw.functions.upload.SaveBigFilePart( + file_id=file_id, + file_part=file_part, + file_total_parts=file_total_parts, + bytes=chunk + ) + else: + rpc = raw.functions.upload.SaveFilePart( + file_id=file_id, + file_part=file_part, + bytes=chunk + ) - if progress: - func = functools.partial( - progress, - min(file_part * part_size, file_size), - file_size, - *progress_args - ) + await queue.put(rpc) - if inspect.iscoroutinefunction(progress): - await func() - else: - await self.loop.run_in_executor(self.executor, func) + if is_missing_part: + return + + if not is_big and not is_missing_part: + md5_sum.update(chunk) + + file_part += 1 + + if progress: + func = functools.partial( + progress, + min(file_part * part_size, file_size), + file_size, + *progress_args + ) + + if inspect.iscoroutinefunction(progress): + await func() + else: + await self.loop.run_in_executor(self.executor, func) except StopTransmission: raise except Exception as e: @@ -222,3 +221,6 @@ class SaveFile(Scaffold): for session in pool: await session.stop() + + if isinstance(path, (str, PurePath)): + fp.close() diff --git a/pyrogram/methods/messages/send_media_group.py b/pyrogram/methods/messages/send_media_group.py index 1669ce32..4073ddec 100644 --- a/pyrogram/methods/messages/send_media_group.py +++ b/pyrogram/methods/messages/send_media_group.py @@ -19,7 +19,6 @@ import logging import os import re -import io from typing import Union, List from pyrogram import raw @@ -88,7 +87,44 @@ class SendMediaGroup(Scaffold): for i in media: if isinstance(i, types.InputMediaPhoto): - if os.path.isfile(i.media) or isinstance(i.media, io.IOBase): + if isinstance(i.media, str): + if os.path.isfile(i.media): + media = await self.send( + raw.functions.messages.UploadMedia( + peer=await self.resolve_peer(chat_id), + media=raw.types.InputMediaUploadedPhoto( + file=await self.save_file(i.media) + ) + ) + ) + + media = raw.types.InputMediaPhoto( + id=raw.types.InputPhoto( + id=media.photo.id, + access_hash=media.photo.access_hash, + file_reference=media.photo.file_reference + ) + ) + elif re.match("^https?://", i.media): + media = await self.send( + raw.functions.messages.UploadMedia( + peer=await self.resolve_peer(chat_id), + media=raw.types.InputMediaPhotoExternal( + url=i.media + ) + ) + ) + + media = raw.types.InputMediaPhoto( + id=raw.types.InputPhoto( + id=media.photo.id, + access_hash=media.photo.access_hash, + file_reference=media.photo.file_reference + ) + ) + else: + media = utils.get_input_media_from_file_id(i.media, FileType.PHOTO) + else: media = await self.send( raw.functions.messages.UploadMedia( peer=await self.resolve_peer(chat_id), @@ -105,34 +141,63 @@ class SendMediaGroup(Scaffold): file_reference=media.photo.file_reference ) ) - elif re.match("^https?://", i.media): - media = await self.send( - raw.functions.messages.UploadMedia( - peer=await self.resolve_peer(chat_id), - media=raw.types.InputMediaPhotoExternal( - url=i.media + elif isinstance(i, types.InputMediaVideo): + if isinstance(i.media, str): + if os.path.isfile(i.media): + media = await self.send( + raw.functions.messages.UploadMedia( + peer=await self.resolve_peer(chat_id), + media=raw.types.InputMediaUploadedDocument( + file=await self.save_file(i.media), + thumb=await self.save_file(i.thumb), + mime_type=self.guess_mime_type(i.media) or "video/mp4", + attributes=[ + raw.types.DocumentAttributeVideo( + supports_streaming=i.supports_streaming or None, + duration=i.duration, + w=i.width, + h=i.height + ), + raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media)) + ] + ) ) ) - ) - media = raw.types.InputMediaPhoto( - id=raw.types.InputPhoto( - id=media.photo.id, - access_hash=media.photo.access_hash, - file_reference=media.photo.file_reference + media = raw.types.InputMediaDocument( + id=raw.types.InputDocument( + id=media.document.id, + access_hash=media.document.access_hash, + file_reference=media.document.file_reference + ) ) - ) + elif re.match("^https?://", i.media): + media = await self.send( + raw.functions.messages.UploadMedia( + peer=await self.resolve_peer(chat_id), + media=raw.types.InputMediaDocumentExternal( + url=i.media + ) + ) + ) + + media = raw.types.InputMediaDocument( + id=raw.types.InputDocument( + id=media.document.id, + access_hash=media.document.access_hash, + file_reference=media.document.file_reference + ) + ) + else: + media = utils.get_input_media_from_file_id(i.media, FileType.VIDEO) else: - media = utils.get_input_media_from_file_id(i.media, FileType.PHOTO) - elif isinstance(i, types.InputMediaVideo): - if os.path.isfile(i.media) or isinstance(i.media, io.IOBase): media = await self.send( raw.functions.messages.UploadMedia( peer=await self.resolve_peer(chat_id), media=raw.types.InputMediaUploadedDocument( file=await self.save_file(i.media), thumb=await self.save_file(i.thumb), - mime_type=self.guess_mime_type(i.media) or "video/mp4", + mime_type=self.guess_mime_type(getattr(i.media, "name", "video.mp4")) or "video/mp4", attributes=[ raw.types.DocumentAttributeVideo( supports_streaming=i.supports_streaming or None, @@ -140,7 +205,7 @@ class SendMediaGroup(Scaffold): w=i.width, h=i.height ), - raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media)) + raw.types.DocumentAttributeFilename(file_name=getattr(i.media, "name", "video.mp4")) ] ) ) @@ -153,32 +218,60 @@ class SendMediaGroup(Scaffold): file_reference=media.document.file_reference ) ) - elif re.match("^https?://", i.media): - media = await self.send( - raw.functions.messages.UploadMedia( - peer=await self.resolve_peer(chat_id), - media=raw.types.InputMediaDocumentExternal( - url=i.media + elif isinstance(i, types.InputMediaAudio): + if isinstance(i.media, str): + if os.path.isfile(i.media): + media = await self.send( + raw.functions.messages.UploadMedia( + peer=await self.resolve_peer(chat_id), + media=raw.types.InputMediaUploadedDocument( + mime_type=self.guess_mime_type(i.media) or "audio/mpeg", + file=await self.save_file(i.media), + thumb=await self.save_file(i.thumb), + attributes=[ + raw.types.DocumentAttributeAudio( + duration=i.duration, + performer=i.performer, + title=i.title + ), + raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media)) + ] + ) ) ) - ) - media = raw.types.InputMediaDocument( - id=raw.types.InputDocument( - id=media.document.id, - access_hash=media.document.access_hash, - file_reference=media.document.file_reference + media = raw.types.InputMediaDocument( + id=raw.types.InputDocument( + id=media.document.id, + access_hash=media.document.access_hash, + file_reference=media.document.file_reference + ) ) - ) + elif re.match("^https?://", i.media): + media = await self.send( + raw.functions.messages.UploadMedia( + peer=await self.resolve_peer(chat_id), + media=raw.types.InputMediaDocumentExternal( + url=i.media + ) + ) + ) + + media = raw.types.InputMediaDocument( + id=raw.types.InputDocument( + id=media.document.id, + access_hash=media.document.access_hash, + file_reference=media.document.file_reference + ) + ) + else: + media = utils.get_input_media_from_file_id(i.media, FileType.AUDIO) else: - media = utils.get_input_media_from_file_id(i.media, FileType.VIDEO) - elif isinstance(i, types.InputMediaAudio): - if os.path.isfile(i.media): media = await self.send( raw.functions.messages.UploadMedia( peer=await self.resolve_peer(chat_id), media=raw.types.InputMediaUploadedDocument( - mime_type=self.guess_mime_type(i.media) or "audio/mpeg", + mime_type=self.guess_mime_type(getattr(i.media, "name", "audio.mp3")) or "audio/mpeg", file=await self.save_file(i.media), thumb=await self.save_file(i.thumb), attributes=[ @@ -187,7 +280,7 @@ class SendMediaGroup(Scaffold): performer=i.performer, title=i.title ), - raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media)) + raw.types.DocumentAttributeFilename(file_name=getattr(i.media, "name", "audio.mp3")) ] ) ) @@ -200,36 +293,61 @@ class SendMediaGroup(Scaffold): file_reference=media.document.file_reference ) ) - elif re.match("^https?://", i.media): - media = await self.send( - raw.functions.messages.UploadMedia( - peer=await self.resolve_peer(chat_id), - media=raw.types.InputMediaDocumentExternal( - url=i.media + elif isinstance(i, types.InputMediaDocument): + if isinstance(i.media, str): + if os.path.isfile(i.media): + media = await self.send( + raw.functions.messages.UploadMedia( + peer=await self.resolve_peer(chat_id), + media=raw.types.InputMediaUploadedDocument( + mime_type=self.guess_mime_type(i.media) or "application/zip", + file=await self.save_file(i.media), + thumb=await self.save_file(i.thumb), + attributes=[ + raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media)) + ] + ) ) ) - ) - media = raw.types.InputMediaDocument( - id=raw.types.InputDocument( - id=media.document.id, - access_hash=media.document.access_hash, - file_reference=media.document.file_reference + media = raw.types.InputMediaDocument( + id=raw.types.InputDocument( + id=media.document.id, + access_hash=media.document.access_hash, + file_reference=media.document.file_reference + ) ) - ) + elif re.match("^https?://", i.media): + media = await self.send( + raw.functions.messages.UploadMedia( + peer=await self.resolve_peer(chat_id), + media=raw.types.InputMediaDocumentExternal( + url=i.media + ) + ) + ) + + media = raw.types.InputMediaDocument( + id=raw.types.InputDocument( + id=media.document.id, + access_hash=media.document.access_hash, + file_reference=media.document.file_reference + ) + ) + else: + media = utils.get_input_media_from_file_id(i.media, FileType.DOCUMENT) else: - media = utils.get_input_media_from_file_id(i.media, FileType.AUDIO) - elif isinstance(i, types.InputMediaDocument): - if os.path.isfile(i.media): media = await self.send( raw.functions.messages.UploadMedia( peer=await self.resolve_peer(chat_id), media=raw.types.InputMediaUploadedDocument( - mime_type=self.guess_mime_type(i.media) or "application/zip", + mime_type=self.guess_mime_type( + getattr(i.media, "name", "file.zip") + ) or "application/zip", file=await self.save_file(i.media), thumb=await self.save_file(i.thumb), attributes=[ - raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media)) + raw.types.DocumentAttributeFilename(file_name=getattr(i.media, "name", "file.zip")) ] ) ) @@ -242,25 +360,8 @@ class SendMediaGroup(Scaffold): file_reference=media.document.file_reference ) ) - elif re.match("^https?://", i.media): - media = await self.send( - raw.functions.messages.UploadMedia( - peer=await self.resolve_peer(chat_id), - media=raw.types.InputMediaDocumentExternal( - url=i.media - ) - ) - ) - - media = raw.types.InputMediaDocument( - id=raw.types.InputDocument( - id=media.document.id, - access_hash=media.document.access_hash, - file_reference=media.document.file_reference - ) - ) - else: - media = utils.get_input_media_from_file_id(i.media, FileType.DOCUMENT) + else: + raise ValueError(f"{i.__class__.__name__} is not a supported type for send_media_group") multi_media.append( raw.types.InputSingleMedia( diff --git a/pyrogram/types/input_media/input_media_animation.py b/pyrogram/types/input_media/input_media_animation.py index e39a148e..9c5767fe 100644 --- a/pyrogram/types/input_media/input_media_animation.py +++ b/pyrogram/types/input_media/input_media_animation.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from typing import Optional, List +from typing import Optional, List, Union, BinaryIO from .input_media import InputMedia from ..messages_and_media import MessageEntity @@ -30,6 +30,7 @@ class InputMediaAnimation(InputMedia): Animation to send. Pass a file_id as string to send a file that exists on the Telegram servers or pass a file path as string to upload a new file that exists on your local machine or + pass a binary file-like object with its attribute “.name” set for in-memory uploads or pass an HTTP URL as a string for Telegram to get an animation from the Internet. thumb (``str``, *optional*): @@ -64,7 +65,7 @@ class InputMediaAnimation(InputMedia): def __init__( self, - media: str, + media: Union[str, BinaryIO], thumb: str = None, caption: str = "", parse_mode: Optional[str] = object, diff --git a/pyrogram/types/input_media/input_media_audio.py b/pyrogram/types/input_media/input_media_audio.py index e47bc818..3a66b5ee 100644 --- a/pyrogram/types/input_media/input_media_audio.py +++ b/pyrogram/types/input_media/input_media_audio.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from typing import Optional, List +from typing import Optional, List, BinaryIO, Union from .input_media import InputMedia from ..messages_and_media import MessageEntity @@ -32,6 +32,7 @@ class InputMediaAudio(InputMedia): Audio to send. Pass a file_id as string to send an audio that exists on the Telegram servers or pass a file path as string to upload a new audio that exists on your local machine or + pass a binary file-like object with its attribute “.name” set for in-memory uploads or pass an HTTP URL as a string for Telegram to get an audio file from the Internet. thumb (``str``, *optional*): @@ -66,7 +67,7 @@ class InputMediaAudio(InputMedia): def __init__( self, - media: str, + media: Union[str, BinaryIO], thumb: str = None, caption: str = "", parse_mode: Optional[str] = object, diff --git a/pyrogram/types/input_media/input_media_document.py b/pyrogram/types/input_media/input_media_document.py index 12b9fc94..31460060 100644 --- a/pyrogram/types/input_media/input_media_document.py +++ b/pyrogram/types/input_media/input_media_document.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from typing import Optional, List +from typing import Optional, List, Union, BinaryIO from .input_media import InputMedia from ..messages_and_media import MessageEntity @@ -30,6 +30,7 @@ class InputMediaDocument(InputMedia): File to send. Pass a file_id as string to send a file that exists on the Telegram servers or pass a file path as string to upload a new file that exists on your local machine or + pass a binary file-like object with its attribute “.name” set for in-memory uploads or pass an HTTP URL as a string for Telegram to get a file from the Internet. thumb (``str``): @@ -55,7 +56,7 @@ class InputMediaDocument(InputMedia): def __init__( self, - media: str, + media: Union[str, BinaryIO], thumb: str = None, caption: str = "", parse_mode: Optional[str] = object, diff --git a/pyrogram/types/input_media/input_media_video.py b/pyrogram/types/input_media/input_media_video.py index b0c36830..1199d862 100644 --- a/pyrogram/types/input_media/input_media_video.py +++ b/pyrogram/types/input_media/input_media_video.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from typing import Optional, List +from typing import Optional, List, Union, BinaryIO from .input_media import InputMedia from ..messages_and_media import MessageEntity @@ -31,6 +31,7 @@ class InputMediaVideo(InputMedia): Video to send. Pass a file_id as string to send a video that exists on the Telegram servers or pass a file path as string to upload a new video that exists on your local machine or + pass a binary file-like object with its attribute “.name” set for in-memory uploads or pass an HTTP URL as a string for Telegram to get a video from the Internet. thumb (``str``): @@ -68,7 +69,7 @@ class InputMediaVideo(InputMedia): def __init__( self, - media: str, + media: Union[str, BinaryIO], thumb: str = None, caption: str = "", parse_mode: Optional[str] = object,