Add support for in-memory uploads in send_media_group (#519)

* Add support for in-memory uploads for send_media_group

* update input_media_photo docs

* update type hints

Co-authored-by: Dan <14043624+delivrance@users.noreply.github.com>
This commit is contained in:
Dan 2021-06-01 13:57:31 +02:00
parent 0d12d8c1bb
commit 523ed3e7cb
6 changed files with 232 additions and 125 deletions

View File

@ -116,7 +116,7 @@ class SaveFile(Scaffold):
else: else:
raise ValueError("Invalid file. Expected a file path as string or a binary (not text) file pointer") raise ValueError("Invalid file. Expected a file path as string or a binary (not text) file pointer")
file_name = fp.name file_name = getattr(fp, "name", "file.jpg")
fp.seek(0, os.SEEK_END) fp.seek(0, os.SEEK_END)
file_size = fp.tell() file_size = fp.tell()
@ -148,53 +148,52 @@ class SaveFile(Scaffold):
for session in pool: for session in pool:
await session.start() await session.start()
with fp: fp.seek(part_size * file_part)
fp.seek(part_size * file_part)
while True: while True:
chunk = fp.read(part_size) chunk = fp.read(part_size)
if not chunk:
if not is_big and not is_missing_part:
md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
break
if is_big:
rpc = raw.functions.upload.SaveBigFilePart(
file_id=file_id,
file_part=file_part,
file_total_parts=file_total_parts,
bytes=chunk
)
else:
rpc = raw.functions.upload.SaveFilePart(
file_id=file_id,
file_part=file_part,
bytes=chunk
)
await queue.put(rpc)
if is_missing_part:
return
if not chunk:
if not is_big and not is_missing_part: if not is_big and not is_missing_part:
md5_sum.update(chunk) md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
break
file_part += 1 if is_big:
rpc = raw.functions.upload.SaveBigFilePart(
file_id=file_id,
file_part=file_part,
file_total_parts=file_total_parts,
bytes=chunk
)
else:
rpc = raw.functions.upload.SaveFilePart(
file_id=file_id,
file_part=file_part,
bytes=chunk
)
if progress: await queue.put(rpc)
func = functools.partial(
progress,
min(file_part * part_size, file_size),
file_size,
*progress_args
)
if inspect.iscoroutinefunction(progress): if is_missing_part:
await func() return
else:
await self.loop.run_in_executor(self.executor, func) if not is_big and not is_missing_part:
md5_sum.update(chunk)
file_part += 1
if progress:
func = functools.partial(
progress,
min(file_part * part_size, file_size),
file_size,
*progress_args
)
if inspect.iscoroutinefunction(progress):
await func()
else:
await self.loop.run_in_executor(self.executor, func)
except StopTransmission: except StopTransmission:
raise raise
except Exception as e: except Exception as e:
@ -222,3 +221,6 @@ class SaveFile(Scaffold):
for session in pool: for session in pool:
await session.stop() await session.stop()
if isinstance(path, (str, PurePath)):
fp.close()

View File

@ -19,7 +19,6 @@
import logging import logging
import os import os
import re import re
import io
from typing import Union, List from typing import Union, List
from pyrogram import raw from pyrogram import raw
@ -88,7 +87,44 @@ class SendMediaGroup(Scaffold):
for i in media: for i in media:
if isinstance(i, types.InputMediaPhoto): if isinstance(i, types.InputMediaPhoto):
if os.path.isfile(i.media) or isinstance(i.media, io.IOBase): if isinstance(i.media, str):
if os.path.isfile(i.media):
media = await self.send(
raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaUploadedPhoto(
file=await self.save_file(i.media)
)
)
)
media = raw.types.InputMediaPhoto(
id=raw.types.InputPhoto(
id=media.photo.id,
access_hash=media.photo.access_hash,
file_reference=media.photo.file_reference
)
)
elif re.match("^https?://", i.media):
media = await self.send(
raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaPhotoExternal(
url=i.media
)
)
)
media = raw.types.InputMediaPhoto(
id=raw.types.InputPhoto(
id=media.photo.id,
access_hash=media.photo.access_hash,
file_reference=media.photo.file_reference
)
)
else:
media = utils.get_input_media_from_file_id(i.media, FileType.PHOTO)
else:
media = await self.send( media = await self.send(
raw.functions.messages.UploadMedia( raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
@ -105,34 +141,63 @@ class SendMediaGroup(Scaffold):
file_reference=media.photo.file_reference file_reference=media.photo.file_reference
) )
) )
elif re.match("^https?://", i.media): elif isinstance(i, types.InputMediaVideo):
media = await self.send( if isinstance(i.media, str):
raw.functions.messages.UploadMedia( if os.path.isfile(i.media):
peer=await self.resolve_peer(chat_id), media = await self.send(
media=raw.types.InputMediaPhotoExternal( raw.functions.messages.UploadMedia(
url=i.media peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaUploadedDocument(
file=await self.save_file(i.media),
thumb=await self.save_file(i.thumb),
mime_type=self.guess_mime_type(i.media) or "video/mp4",
attributes=[
raw.types.DocumentAttributeVideo(
supports_streaming=i.supports_streaming or None,
duration=i.duration,
w=i.width,
h=i.height
),
raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media))
]
)
) )
) )
)
media = raw.types.InputMediaPhoto( media = raw.types.InputMediaDocument(
id=raw.types.InputPhoto( id=raw.types.InputDocument(
id=media.photo.id, id=media.document.id,
access_hash=media.photo.access_hash, access_hash=media.document.access_hash,
file_reference=media.photo.file_reference file_reference=media.document.file_reference
)
) )
) elif re.match("^https?://", i.media):
media = await self.send(
raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaDocumentExternal(
url=i.media
)
)
)
media = raw.types.InputMediaDocument(
id=raw.types.InputDocument(
id=media.document.id,
access_hash=media.document.access_hash,
file_reference=media.document.file_reference
)
)
else:
media = utils.get_input_media_from_file_id(i.media, FileType.VIDEO)
else: else:
media = utils.get_input_media_from_file_id(i.media, FileType.PHOTO)
elif isinstance(i, types.InputMediaVideo):
if os.path.isfile(i.media) or isinstance(i.media, io.IOBase):
media = await self.send( media = await self.send(
raw.functions.messages.UploadMedia( raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaUploadedDocument( media=raw.types.InputMediaUploadedDocument(
file=await self.save_file(i.media), file=await self.save_file(i.media),
thumb=await self.save_file(i.thumb), thumb=await self.save_file(i.thumb),
mime_type=self.guess_mime_type(i.media) or "video/mp4", mime_type=self.guess_mime_type(getattr(i.media, "name", "video.mp4")) or "video/mp4",
attributes=[ attributes=[
raw.types.DocumentAttributeVideo( raw.types.DocumentAttributeVideo(
supports_streaming=i.supports_streaming or None, supports_streaming=i.supports_streaming or None,
@ -140,7 +205,7 @@ class SendMediaGroup(Scaffold):
w=i.width, w=i.width,
h=i.height h=i.height
), ),
raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media)) raw.types.DocumentAttributeFilename(file_name=getattr(i.media, "name", "video.mp4"))
] ]
) )
) )
@ -153,32 +218,60 @@ class SendMediaGroup(Scaffold):
file_reference=media.document.file_reference file_reference=media.document.file_reference
) )
) )
elif re.match("^https?://", i.media): elif isinstance(i, types.InputMediaAudio):
media = await self.send( if isinstance(i.media, str):
raw.functions.messages.UploadMedia( if os.path.isfile(i.media):
peer=await self.resolve_peer(chat_id), media = await self.send(
media=raw.types.InputMediaDocumentExternal( raw.functions.messages.UploadMedia(
url=i.media peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(i.media) or "audio/mpeg",
file=await self.save_file(i.media),
thumb=await self.save_file(i.thumb),
attributes=[
raw.types.DocumentAttributeAudio(
duration=i.duration,
performer=i.performer,
title=i.title
),
raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media))
]
)
) )
) )
)
media = raw.types.InputMediaDocument( media = raw.types.InputMediaDocument(
id=raw.types.InputDocument( id=raw.types.InputDocument(
id=media.document.id, id=media.document.id,
access_hash=media.document.access_hash, access_hash=media.document.access_hash,
file_reference=media.document.file_reference file_reference=media.document.file_reference
)
) )
) elif re.match("^https?://", i.media):
media = await self.send(
raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaDocumentExternal(
url=i.media
)
)
)
media = raw.types.InputMediaDocument(
id=raw.types.InputDocument(
id=media.document.id,
access_hash=media.document.access_hash,
file_reference=media.document.file_reference
)
)
else:
media = utils.get_input_media_from_file_id(i.media, FileType.AUDIO)
else: else:
media = utils.get_input_media_from_file_id(i.media, FileType.VIDEO)
elif isinstance(i, types.InputMediaAudio):
if os.path.isfile(i.media):
media = await self.send( media = await self.send(
raw.functions.messages.UploadMedia( raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaUploadedDocument( media=raw.types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(i.media) or "audio/mpeg", mime_type=self.guess_mime_type(getattr(i.media, "name", "audio.mp3")) or "audio/mpeg",
file=await self.save_file(i.media), file=await self.save_file(i.media),
thumb=await self.save_file(i.thumb), thumb=await self.save_file(i.thumb),
attributes=[ attributes=[
@ -187,7 +280,7 @@ class SendMediaGroup(Scaffold):
performer=i.performer, performer=i.performer,
title=i.title title=i.title
), ),
raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media)) raw.types.DocumentAttributeFilename(file_name=getattr(i.media, "name", "audio.mp3"))
] ]
) )
) )
@ -200,36 +293,61 @@ class SendMediaGroup(Scaffold):
file_reference=media.document.file_reference file_reference=media.document.file_reference
) )
) )
elif re.match("^https?://", i.media): elif isinstance(i, types.InputMediaDocument):
media = await self.send( if isinstance(i.media, str):
raw.functions.messages.UploadMedia( if os.path.isfile(i.media):
peer=await self.resolve_peer(chat_id), media = await self.send(
media=raw.types.InputMediaDocumentExternal( raw.functions.messages.UploadMedia(
url=i.media peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(i.media) or "application/zip",
file=await self.save_file(i.media),
thumb=await self.save_file(i.thumb),
attributes=[
raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media))
]
)
) )
) )
)
media = raw.types.InputMediaDocument( media = raw.types.InputMediaDocument(
id=raw.types.InputDocument( id=raw.types.InputDocument(
id=media.document.id, id=media.document.id,
access_hash=media.document.access_hash, access_hash=media.document.access_hash,
file_reference=media.document.file_reference file_reference=media.document.file_reference
)
) )
) elif re.match("^https?://", i.media):
media = await self.send(
raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaDocumentExternal(
url=i.media
)
)
)
media = raw.types.InputMediaDocument(
id=raw.types.InputDocument(
id=media.document.id,
access_hash=media.document.access_hash,
file_reference=media.document.file_reference
)
)
else:
media = utils.get_input_media_from_file_id(i.media, FileType.DOCUMENT)
else: else:
media = utils.get_input_media_from_file_id(i.media, FileType.AUDIO)
elif isinstance(i, types.InputMediaDocument):
if os.path.isfile(i.media):
media = await self.send( media = await self.send(
raw.functions.messages.UploadMedia( raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaUploadedDocument( media=raw.types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(i.media) or "application/zip", mime_type=self.guess_mime_type(
getattr(i.media, "name", "file.zip")
) or "application/zip",
file=await self.save_file(i.media), file=await self.save_file(i.media),
thumb=await self.save_file(i.thumb), thumb=await self.save_file(i.thumb),
attributes=[ attributes=[
raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media)) raw.types.DocumentAttributeFilename(file_name=getattr(i.media, "name", "file.zip"))
] ]
) )
) )
@ -242,25 +360,8 @@ class SendMediaGroup(Scaffold):
file_reference=media.document.file_reference file_reference=media.document.file_reference
) )
) )
elif re.match("^https?://", i.media): else:
media = await self.send( raise ValueError(f"{i.__class__.__name__} is not a supported type for send_media_group")
raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaDocumentExternal(
url=i.media
)
)
)
media = raw.types.InputMediaDocument(
id=raw.types.InputDocument(
id=media.document.id,
access_hash=media.document.access_hash,
file_reference=media.document.file_reference
)
)
else:
media = utils.get_input_media_from_file_id(i.media, FileType.DOCUMENT)
multi_media.append( multi_media.append(
raw.types.InputSingleMedia( raw.types.InputSingleMedia(

View File

@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Optional, List from typing import Optional, List, Union, BinaryIO
from .input_media import InputMedia from .input_media import InputMedia
from ..messages_and_media import MessageEntity from ..messages_and_media import MessageEntity
@ -30,6 +30,7 @@ class InputMediaAnimation(InputMedia):
Animation to send. Animation to send.
Pass a file_id as string to send a file that exists on the Telegram servers or Pass a file_id as string to send a file that exists on the Telegram servers or
pass a file path as string to upload a new file that exists on your local machine or pass a file path as string to upload a new file that exists on your local machine or
pass a binary file-like object with its attribute .name set for in-memory uploads or
pass an HTTP URL as a string for Telegram to get an animation from the Internet. pass an HTTP URL as a string for Telegram to get an animation from the Internet.
thumb (``str``, *optional*): thumb (``str``, *optional*):
@ -64,7 +65,7 @@ class InputMediaAnimation(InputMedia):
def __init__( def __init__(
self, self,
media: str, media: Union[str, BinaryIO],
thumb: str = None, thumb: str = None,
caption: str = "", caption: str = "",
parse_mode: Optional[str] = object, parse_mode: Optional[str] = object,

View File

@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Optional, List from typing import Optional, List, BinaryIO, Union
from .input_media import InputMedia from .input_media import InputMedia
from ..messages_and_media import MessageEntity from ..messages_and_media import MessageEntity
@ -32,6 +32,7 @@ class InputMediaAudio(InputMedia):
Audio to send. Audio to send.
Pass a file_id as string to send an audio that exists on the Telegram servers or Pass a file_id as string to send an audio that exists on the Telegram servers or
pass a file path as string to upload a new audio that exists on your local machine or pass a file path as string to upload a new audio that exists on your local machine or
pass a binary file-like object with its attribute .name set for in-memory uploads or
pass an HTTP URL as a string for Telegram to get an audio file from the Internet. pass an HTTP URL as a string for Telegram to get an audio file from the Internet.
thumb (``str``, *optional*): thumb (``str``, *optional*):
@ -66,7 +67,7 @@ class InputMediaAudio(InputMedia):
def __init__( def __init__(
self, self,
media: str, media: Union[str, BinaryIO],
thumb: str = None, thumb: str = None,
caption: str = "", caption: str = "",
parse_mode: Optional[str] = object, parse_mode: Optional[str] = object,

View File

@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Optional, List from typing import Optional, List, Union, BinaryIO
from .input_media import InputMedia from .input_media import InputMedia
from ..messages_and_media import MessageEntity from ..messages_and_media import MessageEntity
@ -30,6 +30,7 @@ class InputMediaDocument(InputMedia):
File to send. File to send.
Pass a file_id as string to send a file that exists on the Telegram servers or Pass a file_id as string to send a file that exists on the Telegram servers or
pass a file path as string to upload a new file that exists on your local machine or pass a file path as string to upload a new file that exists on your local machine or
pass a binary file-like object with its attribute .name set for in-memory uploads or
pass an HTTP URL as a string for Telegram to get a file from the Internet. pass an HTTP URL as a string for Telegram to get a file from the Internet.
thumb (``str``): thumb (``str``):
@ -55,7 +56,7 @@ class InputMediaDocument(InputMedia):
def __init__( def __init__(
self, self,
media: str, media: Union[str, BinaryIO],
thumb: str = None, thumb: str = None,
caption: str = "", caption: str = "",
parse_mode: Optional[str] = object, parse_mode: Optional[str] = object,

View File

@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Optional, List from typing import Optional, List, Union, BinaryIO
from .input_media import InputMedia from .input_media import InputMedia
from ..messages_and_media import MessageEntity from ..messages_and_media import MessageEntity
@ -31,6 +31,7 @@ class InputMediaVideo(InputMedia):
Video to send. Video to send.
Pass a file_id as string to send a video that exists on the Telegram servers or Pass a file_id as string to send a video that exists on the Telegram servers or
pass a file path as string to upload a new video that exists on your local machine or pass a file path as string to upload a new video that exists on your local machine or
pass a binary file-like object with its attribute .name set for in-memory uploads or
pass an HTTP URL as a string for Telegram to get a video from the Internet. pass an HTTP URL as a string for Telegram to get a video from the Internet.
thumb (``str``): thumb (``str``):
@ -68,7 +69,7 @@ class InputMediaVideo(InputMedia):
def __init__( def __init__(
self, self,
media: str, media: Union[str, BinaryIO],
thumb: str = None, thumb: str = None,
caption: str = "", caption: str = "",
parse_mode: Optional[str] = object, parse_mode: Optional[str] = object,