From 523ed3e7cb7a0e90e64bedd641973460fe7b4dfb Mon Sep 17 00:00:00 2001
From: Dan <14043624+delivrance@users.noreply.github.com>
Date: Tue, 1 Jun 2021 13:57:31 +0200
Subject: [PATCH] Add support for in-memory uploads in send_media_group (#519)
* Add support for in-memory uploads for send_media_group
* update input_media_photo docs
* update type hints
Co-authored-by: Dan <14043624+delivrance@users.noreply.github.com>
---
pyrogram/methods/advanced/save_file.py | 86 +++---
pyrogram/methods/messages/send_media_group.py | 251 ++++++++++++------
.../input_media/input_media_animation.py | 5 +-
.../types/input_media/input_media_audio.py | 5 +-
.../types/input_media/input_media_document.py | 5 +-
.../types/input_media/input_media_video.py | 5 +-
6 files changed, 232 insertions(+), 125 deletions(-)
diff --git a/pyrogram/methods/advanced/save_file.py b/pyrogram/methods/advanced/save_file.py
index bc99b859..84361565 100644
--- a/pyrogram/methods/advanced/save_file.py
+++ b/pyrogram/methods/advanced/save_file.py
@@ -116,7 +116,7 @@ class SaveFile(Scaffold):
else:
raise ValueError("Invalid file. Expected a file path as string or a binary (not text) file pointer")
- file_name = fp.name
+ file_name = getattr(fp, "name", "file.jpg")
fp.seek(0, os.SEEK_END)
file_size = fp.tell()
@@ -148,53 +148,52 @@ class SaveFile(Scaffold):
for session in pool:
await session.start()
- with fp:
- fp.seek(part_size * file_part)
+ fp.seek(part_size * file_part)
- while True:
- chunk = fp.read(part_size)
-
- if not chunk:
- if not is_big and not is_missing_part:
- md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
- break
-
- if is_big:
- rpc = raw.functions.upload.SaveBigFilePart(
- file_id=file_id,
- file_part=file_part,
- file_total_parts=file_total_parts,
- bytes=chunk
- )
- else:
- rpc = raw.functions.upload.SaveFilePart(
- file_id=file_id,
- file_part=file_part,
- bytes=chunk
- )
-
- await queue.put(rpc)
-
- if is_missing_part:
- return
+ while True:
+ chunk = fp.read(part_size)
+ if not chunk:
if not is_big and not is_missing_part:
- md5_sum.update(chunk)
+ md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
+ break
- file_part += 1
+ if is_big:
+ rpc = raw.functions.upload.SaveBigFilePart(
+ file_id=file_id,
+ file_part=file_part,
+ file_total_parts=file_total_parts,
+ bytes=chunk
+ )
+ else:
+ rpc = raw.functions.upload.SaveFilePart(
+ file_id=file_id,
+ file_part=file_part,
+ bytes=chunk
+ )
- if progress:
- func = functools.partial(
- progress,
- min(file_part * part_size, file_size),
- file_size,
- *progress_args
- )
+ await queue.put(rpc)
- if inspect.iscoroutinefunction(progress):
- await func()
- else:
- await self.loop.run_in_executor(self.executor, func)
+ if is_missing_part:
+ return
+
+ if not is_big and not is_missing_part:
+ md5_sum.update(chunk)
+
+ file_part += 1
+
+ if progress:
+ func = functools.partial(
+ progress,
+ min(file_part * part_size, file_size),
+ file_size,
+ *progress_args
+ )
+
+ if inspect.iscoroutinefunction(progress):
+ await func()
+ else:
+ await self.loop.run_in_executor(self.executor, func)
except StopTransmission:
raise
except Exception as e:
@@ -222,3 +221,6 @@ class SaveFile(Scaffold):
for session in pool:
await session.stop()
+
+ if isinstance(path, (str, PurePath)):
+ fp.close()
diff --git a/pyrogram/methods/messages/send_media_group.py b/pyrogram/methods/messages/send_media_group.py
index 1669ce32..4073ddec 100644
--- a/pyrogram/methods/messages/send_media_group.py
+++ b/pyrogram/methods/messages/send_media_group.py
@@ -19,7 +19,6 @@
import logging
import os
import re
-import io
from typing import Union, List
from pyrogram import raw
@@ -88,7 +87,44 @@ class SendMediaGroup(Scaffold):
for i in media:
if isinstance(i, types.InputMediaPhoto):
- if os.path.isfile(i.media) or isinstance(i.media, io.IOBase):
+ if isinstance(i.media, str):
+ if os.path.isfile(i.media):
+ media = await self.send(
+ raw.functions.messages.UploadMedia(
+ peer=await self.resolve_peer(chat_id),
+ media=raw.types.InputMediaUploadedPhoto(
+ file=await self.save_file(i.media)
+ )
+ )
+ )
+
+ media = raw.types.InputMediaPhoto(
+ id=raw.types.InputPhoto(
+ id=media.photo.id,
+ access_hash=media.photo.access_hash,
+ file_reference=media.photo.file_reference
+ )
+ )
+ elif re.match("^https?://", i.media):
+ media = await self.send(
+ raw.functions.messages.UploadMedia(
+ peer=await self.resolve_peer(chat_id),
+ media=raw.types.InputMediaPhotoExternal(
+ url=i.media
+ )
+ )
+ )
+
+ media = raw.types.InputMediaPhoto(
+ id=raw.types.InputPhoto(
+ id=media.photo.id,
+ access_hash=media.photo.access_hash,
+ file_reference=media.photo.file_reference
+ )
+ )
+ else:
+ media = utils.get_input_media_from_file_id(i.media, FileType.PHOTO)
+ else:
media = await self.send(
raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id),
@@ -105,34 +141,63 @@ class SendMediaGroup(Scaffold):
file_reference=media.photo.file_reference
)
)
- elif re.match("^https?://", i.media):
- media = await self.send(
- raw.functions.messages.UploadMedia(
- peer=await self.resolve_peer(chat_id),
- media=raw.types.InputMediaPhotoExternal(
- url=i.media
+ elif isinstance(i, types.InputMediaVideo):
+ if isinstance(i.media, str):
+ if os.path.isfile(i.media):
+ media = await self.send(
+ raw.functions.messages.UploadMedia(
+ peer=await self.resolve_peer(chat_id),
+ media=raw.types.InputMediaUploadedDocument(
+ file=await self.save_file(i.media),
+ thumb=await self.save_file(i.thumb),
+ mime_type=self.guess_mime_type(i.media) or "video/mp4",
+ attributes=[
+ raw.types.DocumentAttributeVideo(
+ supports_streaming=i.supports_streaming or None,
+ duration=i.duration,
+ w=i.width,
+ h=i.height
+ ),
+ raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media))
+ ]
+ )
)
)
- )
- media = raw.types.InputMediaPhoto(
- id=raw.types.InputPhoto(
- id=media.photo.id,
- access_hash=media.photo.access_hash,
- file_reference=media.photo.file_reference
+ media = raw.types.InputMediaDocument(
+ id=raw.types.InputDocument(
+ id=media.document.id,
+ access_hash=media.document.access_hash,
+ file_reference=media.document.file_reference
+ )
)
- )
+ elif re.match("^https?://", i.media):
+ media = await self.send(
+ raw.functions.messages.UploadMedia(
+ peer=await self.resolve_peer(chat_id),
+ media=raw.types.InputMediaDocumentExternal(
+ url=i.media
+ )
+ )
+ )
+
+ media = raw.types.InputMediaDocument(
+ id=raw.types.InputDocument(
+ id=media.document.id,
+ access_hash=media.document.access_hash,
+ file_reference=media.document.file_reference
+ )
+ )
+ else:
+ media = utils.get_input_media_from_file_id(i.media, FileType.VIDEO)
else:
- media = utils.get_input_media_from_file_id(i.media, FileType.PHOTO)
- elif isinstance(i, types.InputMediaVideo):
- if os.path.isfile(i.media) or isinstance(i.media, io.IOBase):
media = await self.send(
raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaUploadedDocument(
file=await self.save_file(i.media),
thumb=await self.save_file(i.thumb),
- mime_type=self.guess_mime_type(i.media) or "video/mp4",
+ mime_type=self.guess_mime_type(getattr(i.media, "name", "video.mp4")) or "video/mp4",
attributes=[
raw.types.DocumentAttributeVideo(
supports_streaming=i.supports_streaming or None,
@@ -140,7 +205,7 @@ class SendMediaGroup(Scaffold):
w=i.width,
h=i.height
),
- raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media))
+ raw.types.DocumentAttributeFilename(file_name=getattr(i.media, "name", "video.mp4"))
]
)
)
@@ -153,32 +218,60 @@ class SendMediaGroup(Scaffold):
file_reference=media.document.file_reference
)
)
- elif re.match("^https?://", i.media):
- media = await self.send(
- raw.functions.messages.UploadMedia(
- peer=await self.resolve_peer(chat_id),
- media=raw.types.InputMediaDocumentExternal(
- url=i.media
+ elif isinstance(i, types.InputMediaAudio):
+ if isinstance(i.media, str):
+ if os.path.isfile(i.media):
+ media = await self.send(
+ raw.functions.messages.UploadMedia(
+ peer=await self.resolve_peer(chat_id),
+ media=raw.types.InputMediaUploadedDocument(
+ mime_type=self.guess_mime_type(i.media) or "audio/mpeg",
+ file=await self.save_file(i.media),
+ thumb=await self.save_file(i.thumb),
+ attributes=[
+ raw.types.DocumentAttributeAudio(
+ duration=i.duration,
+ performer=i.performer,
+ title=i.title
+ ),
+ raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media))
+ ]
+ )
)
)
- )
- media = raw.types.InputMediaDocument(
- id=raw.types.InputDocument(
- id=media.document.id,
- access_hash=media.document.access_hash,
- file_reference=media.document.file_reference
+ media = raw.types.InputMediaDocument(
+ id=raw.types.InputDocument(
+ id=media.document.id,
+ access_hash=media.document.access_hash,
+ file_reference=media.document.file_reference
+ )
)
- )
+ elif re.match("^https?://", i.media):
+ media = await self.send(
+ raw.functions.messages.UploadMedia(
+ peer=await self.resolve_peer(chat_id),
+ media=raw.types.InputMediaDocumentExternal(
+ url=i.media
+ )
+ )
+ )
+
+ media = raw.types.InputMediaDocument(
+ id=raw.types.InputDocument(
+ id=media.document.id,
+ access_hash=media.document.access_hash,
+ file_reference=media.document.file_reference
+ )
+ )
+ else:
+ media = utils.get_input_media_from_file_id(i.media, FileType.AUDIO)
else:
- media = utils.get_input_media_from_file_id(i.media, FileType.VIDEO)
- elif isinstance(i, types.InputMediaAudio):
- if os.path.isfile(i.media):
media = await self.send(
raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaUploadedDocument(
- mime_type=self.guess_mime_type(i.media) or "audio/mpeg",
+ mime_type=self.guess_mime_type(getattr(i.media, "name", "audio.mp3")) or "audio/mpeg",
file=await self.save_file(i.media),
thumb=await self.save_file(i.thumb),
attributes=[
@@ -187,7 +280,7 @@ class SendMediaGroup(Scaffold):
performer=i.performer,
title=i.title
),
- raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media))
+ raw.types.DocumentAttributeFilename(file_name=getattr(i.media, "name", "audio.mp3"))
]
)
)
@@ -200,36 +293,61 @@ class SendMediaGroup(Scaffold):
file_reference=media.document.file_reference
)
)
- elif re.match("^https?://", i.media):
- media = await self.send(
- raw.functions.messages.UploadMedia(
- peer=await self.resolve_peer(chat_id),
- media=raw.types.InputMediaDocumentExternal(
- url=i.media
+ elif isinstance(i, types.InputMediaDocument):
+ if isinstance(i.media, str):
+ if os.path.isfile(i.media):
+ media = await self.send(
+ raw.functions.messages.UploadMedia(
+ peer=await self.resolve_peer(chat_id),
+ media=raw.types.InputMediaUploadedDocument(
+ mime_type=self.guess_mime_type(i.media) or "application/zip",
+ file=await self.save_file(i.media),
+ thumb=await self.save_file(i.thumb),
+ attributes=[
+ raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media))
+ ]
+ )
)
)
- )
- media = raw.types.InputMediaDocument(
- id=raw.types.InputDocument(
- id=media.document.id,
- access_hash=media.document.access_hash,
- file_reference=media.document.file_reference
+ media = raw.types.InputMediaDocument(
+ id=raw.types.InputDocument(
+ id=media.document.id,
+ access_hash=media.document.access_hash,
+ file_reference=media.document.file_reference
+ )
)
- )
+ elif re.match("^https?://", i.media):
+ media = await self.send(
+ raw.functions.messages.UploadMedia(
+ peer=await self.resolve_peer(chat_id),
+ media=raw.types.InputMediaDocumentExternal(
+ url=i.media
+ )
+ )
+ )
+
+ media = raw.types.InputMediaDocument(
+ id=raw.types.InputDocument(
+ id=media.document.id,
+ access_hash=media.document.access_hash,
+ file_reference=media.document.file_reference
+ )
+ )
+ else:
+ media = utils.get_input_media_from_file_id(i.media, FileType.DOCUMENT)
else:
- media = utils.get_input_media_from_file_id(i.media, FileType.AUDIO)
- elif isinstance(i, types.InputMediaDocument):
- if os.path.isfile(i.media):
media = await self.send(
raw.functions.messages.UploadMedia(
peer=await self.resolve_peer(chat_id),
media=raw.types.InputMediaUploadedDocument(
- mime_type=self.guess_mime_type(i.media) or "application/zip",
+ mime_type=self.guess_mime_type(
+ getattr(i.media, "name", "file.zip")
+ ) or "application/zip",
file=await self.save_file(i.media),
thumb=await self.save_file(i.thumb),
attributes=[
- raw.types.DocumentAttributeFilename(file_name=os.path.basename(i.media))
+ raw.types.DocumentAttributeFilename(file_name=getattr(i.media, "name", "file.zip"))
]
)
)
@@ -242,25 +360,8 @@ class SendMediaGroup(Scaffold):
file_reference=media.document.file_reference
)
)
- elif re.match("^https?://", i.media):
- media = await self.send(
- raw.functions.messages.UploadMedia(
- peer=await self.resolve_peer(chat_id),
- media=raw.types.InputMediaDocumentExternal(
- url=i.media
- )
- )
- )
-
- media = raw.types.InputMediaDocument(
- id=raw.types.InputDocument(
- id=media.document.id,
- access_hash=media.document.access_hash,
- file_reference=media.document.file_reference
- )
- )
- else:
- media = utils.get_input_media_from_file_id(i.media, FileType.DOCUMENT)
+ else:
+ raise ValueError(f"{i.__class__.__name__} is not a supported type for send_media_group")
multi_media.append(
raw.types.InputSingleMedia(
diff --git a/pyrogram/types/input_media/input_media_animation.py b/pyrogram/types/input_media/input_media_animation.py
index e39a148e..9c5767fe 100644
--- a/pyrogram/types/input_media/input_media_animation.py
+++ b/pyrogram/types/input_media/input_media_animation.py
@@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
-from typing import Optional, List
+from typing import Optional, List, Union, BinaryIO
from .input_media import InputMedia
from ..messages_and_media import MessageEntity
@@ -30,6 +30,7 @@ class InputMediaAnimation(InputMedia):
Animation to send.
Pass a file_id as string to send a file that exists on the Telegram servers or
pass a file path as string to upload a new file that exists on your local machine or
+ pass a binary file-like object with its attribute “.name” set for in-memory uploads or
pass an HTTP URL as a string for Telegram to get an animation from the Internet.
thumb (``str``, *optional*):
@@ -64,7 +65,7 @@ class InputMediaAnimation(InputMedia):
def __init__(
self,
- media: str,
+ media: Union[str, BinaryIO],
thumb: str = None,
caption: str = "",
parse_mode: Optional[str] = object,
diff --git a/pyrogram/types/input_media/input_media_audio.py b/pyrogram/types/input_media/input_media_audio.py
index e47bc818..3a66b5ee 100644
--- a/pyrogram/types/input_media/input_media_audio.py
+++ b/pyrogram/types/input_media/input_media_audio.py
@@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
-from typing import Optional, List
+from typing import Optional, List, BinaryIO, Union
from .input_media import InputMedia
from ..messages_and_media import MessageEntity
@@ -32,6 +32,7 @@ class InputMediaAudio(InputMedia):
Audio to send.
Pass a file_id as string to send an audio that exists on the Telegram servers or
pass a file path as string to upload a new audio that exists on your local machine or
+ pass a binary file-like object with its attribute “.name” set for in-memory uploads or
pass an HTTP URL as a string for Telegram to get an audio file from the Internet.
thumb (``str``, *optional*):
@@ -66,7 +67,7 @@ class InputMediaAudio(InputMedia):
def __init__(
self,
- media: str,
+ media: Union[str, BinaryIO],
thumb: str = None,
caption: str = "",
parse_mode: Optional[str] = object,
diff --git a/pyrogram/types/input_media/input_media_document.py b/pyrogram/types/input_media/input_media_document.py
index 12b9fc94..31460060 100644
--- a/pyrogram/types/input_media/input_media_document.py
+++ b/pyrogram/types/input_media/input_media_document.py
@@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
-from typing import Optional, List
+from typing import Optional, List, Union, BinaryIO
from .input_media import InputMedia
from ..messages_and_media import MessageEntity
@@ -30,6 +30,7 @@ class InputMediaDocument(InputMedia):
File to send.
Pass a file_id as string to send a file that exists on the Telegram servers or
pass a file path as string to upload a new file that exists on your local machine or
+ pass a binary file-like object with its attribute “.name” set for in-memory uploads or
pass an HTTP URL as a string for Telegram to get a file from the Internet.
thumb (``str``):
@@ -55,7 +56,7 @@ class InputMediaDocument(InputMedia):
def __init__(
self,
- media: str,
+ media: Union[str, BinaryIO],
thumb: str = None,
caption: str = "",
parse_mode: Optional[str] = object,
diff --git a/pyrogram/types/input_media/input_media_video.py b/pyrogram/types/input_media/input_media_video.py
index b0c36830..1199d862 100644
--- a/pyrogram/types/input_media/input_media_video.py
+++ b/pyrogram/types/input_media/input_media_video.py
@@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
-from typing import Optional, List
+from typing import Optional, List, Union, BinaryIO
from .input_media import InputMedia
from ..messages_and_media import MessageEntity
@@ -31,6 +31,7 @@ class InputMediaVideo(InputMedia):
Video to send.
Pass a file_id as string to send a video that exists on the Telegram servers or
pass a file path as string to upload a new video that exists on your local machine or
+ pass a binary file-like object with its attribute “.name” set for in-memory uploads or
pass an HTTP URL as a string for Telegram to get a video from the Internet.
thumb (``str``):
@@ -68,7 +69,7 @@ class InputMediaVideo(InputMedia):
def __init__(
self,
- media: str,
+ media: Union[str, BinaryIO],
thumb: str = None,
caption: str = "",
parse_mode: Optional[str] = object,