Add support for in-memory downloads

This commit is contained in:
Dan 2022-04-24 11:56:07 +02:00
parent 0d054fa9bc
commit 01ca652f8c
3 changed files with 160 additions and 142 deletions

View File

@ -29,10 +29,10 @@ import tempfile
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from hashlib import sha256 from hashlib import sha256
from importlib import import_module from importlib import import_module
from io import StringIO from io import StringIO, BytesIO
from mimetypes import MimeTypes from mimetypes import MimeTypes
from pathlib import Path from pathlib import Path
from typing import Union, List, Optional, Callable from typing import Union, List, Optional, Callable, BinaryIO
import pyrogram import pyrogram
from pyrogram import __version__, __license__ from pyrogram import __version__, __license__
@ -482,34 +482,6 @@ class Client(Methods):
return is_min return is_min
async def handle_download(self, packet):
    """Fetch a file to a temporary path and move it to its final location.

    Parameters:
        packet (``tuple``):
            A 6-item tuple laid out as
            ``(file_id, directory, file_name, file_size, progress, progress_args)``.

    Returns:
        ``str`` | ``None``: The absolute path of the moved file, or None when
        ``get_file`` produced no path or any exception occurred. Exceptions are
        logged rather than raised.
    """
    downloaded_path = ""

    try:
        # Unpacking happens inside the try so a malformed packet is logged, not raised.
        file_id, directory, file_name, file_size, progress, progress_args = packet

        downloaded_path = await self.get_file(
            file_id=file_id,
            file_size=file_size,
            progress=progress,
            progress_args=progress_args
        )

        if not downloaded_path:
            return None

        joined_path = os.path.join(directory, file_name)
        destination = os.path.abspath(re.sub("\\\\", "/", joined_path))

        os.makedirs(directory, exist_ok=True)
        shutil.move(downloaded_path, destination)

        return destination
    except Exception as e:
        log.error(e, exc_info=True)

        # Best-effort removal of whatever temporary file may have been left behind.
        try:
            os.remove(downloaded_path)
        except OSError:
            pass
async def handle_updates(self, updates): async def handle_updates(self, updates):
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)): if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
is_min = (await self.fetch_peers(updates.users)) or (await self.fetch_peers(updates.chats)) is_min = (await self.fetch_peers(updates.users)) or (await self.fetch_peers(updates.chats))
@ -747,13 +719,41 @@ class Client(Methods):
else: else:
log.warning(f'[{self.session_name}] No plugin loaded from "{root}"') log.warning(f'[{self.session_name}] No plugin loaded from "{root}"')
async def handle_download(self, packet):
    """Download a file and hand it back either as a path on disk or in memory.

    Parameters:
        packet (``tuple``):
            A 7-item tuple laid out as
            ``(file_id, directory, file_name, in_memory, file_size, progress, progress_args)``.

    Returns:
        ``str`` | ``BinaryIO`` | ``None``: When ``in_memory`` is false, the absolute
        path the downloaded file was moved to. When ``in_memory`` is true, the
        file-like object returned by ``get_file`` with its ``name`` attribute set
        to ``file_name``. None when ``get_file`` returned nothing.
    """
    file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet

    file = await self.get_file(
        file_id=file_id,
        file_size=file_size,
        in_memory=in_memory,
        progress=progress,
        progress_args=progress_args
    )

    if not file:
        return None

    if in_memory:
        file.name = file_name
        return file

    file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))

    try:
        # An empty directory means "current directory"; os.makedirs("") would raise.
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Push buffered bytes to disk first: when shutil.move has to fall back to
        # copy + delete (e.g. across filesystems) unflushed data would be lost.
        file.flush()
        shutil.move(file.name, file_path)
    finally:
        # Always release the handle, even if the move failed. After a successful
        # move the original path is gone, so closing may fail to delete it.
        try:
            file.close()
        except FileNotFoundError:
            pass

    return file_path
async def get_file( async def get_file(
self, self,
file_id: FileId, file_id: FileId,
file_size: int, file_size: int,
in_memory: bool,
progress: Callable, progress: Callable,
progress_args: tuple = () progress_args: tuple = ()
) -> str: ) -> Optional[BinaryIO]:
dc_id = file_id.dc_id dc_id = file_id.dc_id
async with self.media_sessions_lock: async with self.media_sessions_lock:
@ -838,7 +838,8 @@ class Client(Methods):
limit = 1024 * 1024 limit = 1024 * 1024
offset = 0 offset = 0
file_name = ""
file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb")
try: try:
r = await session.invoke( r = await session.invoke(
@ -851,13 +852,10 @@ class Client(Methods):
) )
if isinstance(r, raw.types.upload.File): if isinstance(r, raw.types.upload.File):
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
file_name = f.name
while True: while True:
chunk = r.bytes chunk = r.bytes
f.write(chunk) file.write(chunk)
offset += limit offset += limit
@ -903,9 +901,6 @@ class Client(Methods):
self.media_sessions[r.dc_id] = cdn_session self.media_sessions[r.dc_id] = cdn_session
try: try:
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
file_name = f.name
while True: while True:
r2 = await cdn_session.invoke( r2 = await cdn_session.invoke(
raw.functions.upload.GetCdnFile( raw.functions.upload.GetCdnFile(
@ -952,7 +947,7 @@ class Client(Methods):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest()) CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
f.write(decrypted_chunk) file.write(decrypted_chunk)
offset += limit offset += limit
@ -977,14 +972,11 @@ class Client(Methods):
if not isinstance(e, pyrogram.StopTransmission): if not isinstance(e, pyrogram.StopTransmission):
log.error(e, exc_info=True) log.error(e, exc_info=True)
try: file.close()
os.remove(file_name)
except OSError:
pass
return "" return None
else: else:
return file_name return file
def guess_mime_type(self, filename: str) -> Optional[str]: def guess_mime_type(self, filename: str) -> Optional[str]:
return self.mimetypes.guess_type(filename)[0] return self.mimetypes.guess_type(filename)[0]

View File

@ -18,9 +18,8 @@
import asyncio import asyncio
import os import os
import time
from datetime import datetime from datetime import datetime
from typing import Union, Optional, Callable from typing import Union, Optional, Callable, BinaryIO
import pyrogram import pyrogram
from pyrogram import types from pyrogram import types
@ -34,10 +33,11 @@ class DownloadMedia:
self: "pyrogram.Client", self: "pyrogram.Client",
message: Union["types.Message", str], message: Union["types.Message", str],
file_name: str = DEFAULT_DOWNLOAD_DIR, file_name: str = DEFAULT_DOWNLOAD_DIR,
in_memory: bool = False,
block: bool = True, block: bool = True,
progress: Callable = None, progress: Callable = None,
progress_args: tuple = () progress_args: tuple = ()
) -> Optional[str]: ) -> Optional[Union[str, BinaryIO]]:
"""Download the media from a message. """Download the media from a message.
Parameters: Parameters:
@ -51,6 +51,11 @@ class DownloadMedia:
You can also specify a path for downloading files in a custom location: paths that end with "/" You can also specify a path for downloading files in a custom location: paths that end with "/"
are considered directories. All non-existent folders will be created automatically. are considered directories. All non-existent folders will be created automatically.
in_memory (``bool``, *optional*):
Pass True to download the media in-memory.
A binary file-like object with its attribute ".name" set will be returned.
Defaults to False.
block (``bool``, *optional*): block (``bool``, *optional*):
Blocks the code execution until the file has been downloaded. Blocks the code execution until the file has been downloaded.
Defaults to True. Defaults to True.
@ -78,14 +83,17 @@ class DownloadMedia:
You can either keep ``*args`` or add every single extra argument in your function signature. You can either keep ``*args`` or add every single extra argument in your function signature.
Returns: Returns:
``str`` | ``None``: On success, the absolute path of the downloaded file is returned, otherwise, in case ``str`` | ``None`` | ``BinaryIO``: On success, the absolute path of the downloaded file is returned,
the download failed or was deliberately stopped with :meth:`~pyrogram.Client.stop_transmission`, None is otherwise, in case the download failed or was deliberately stopped with
returned. :meth:`~pyrogram.Client.stop_transmission`, None is returned.
If ``in_memory=True``, a binary file-like object with its ``name`` attribute set is returned on success instead of the file path.
Raises: Raises:
ValueError: if the message doesn't contain any downloadable media ValueError: if the message doesn't contain any downloadable media
Example: Example:
Download media to file
.. code-block:: python .. code-block:: python
# Download from Message # Download from Message
@ -99,6 +107,15 @@ class DownloadMedia:
print(f"{current * 100 / total:.1f}%") print(f"{current * 100 / total:.1f}%")
await app.download_media(message, progress=progress) await app.download_media(message, progress=progress)
Download media in-memory
.. code-block:: python
file = await app.download_media(message, in_memory=True)
file_name = file.name
file_bytes = bytes(file.getbuffer())
""" """
available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note", available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note",
"new_chat_photo") "new_chat_photo")
@ -125,7 +142,7 @@ class DownloadMedia:
media_file_name = getattr(media, "file_name", "") media_file_name = getattr(media, "file_name", "")
file_size = getattr(media, "file_size", 0) file_size = getattr(media, "file_size", 0)
mime_type = getattr(media, "mime_type", "") mime_type = getattr(media, "mime_type", "")
date = getattr(media, "date", 0) date = getattr(media, "date", None)
directory, file_name = os.path.split(file_name) directory, file_name = os.path.split(file_name)
file_name = file_name or media_file_name or "" file_name = file_name or media_file_name or ""
@ -153,12 +170,14 @@ class DownloadMedia:
file_name = "{}_{}_{}{}".format( file_name = "{}_{}_{}{}".format(
FileType(file_id_obj.file_type).name.lower(), FileType(file_id_obj.file_type).name.lower(),
datetime.fromtimestamp(date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"), (date or datetime.now()).strftime("%Y-%m-%d_%H-%M-%S"),
self.rnd_id(), self.rnd_id(),
extension extension
) )
downloader = self.handle_download((file_id_obj, directory, file_name, file_size, progress, progress_args)) downloader = self.handle_download(
(file_id_obj, directory, file_name, in_memory, file_size, progress, progress_args)
)
if block: if block:
return await downloader return await downloader

View File

@ -3329,6 +3329,7 @@ class Message(Object, Update):
async def download( async def download(
self, self,
file_name: str = "", file_name: str = "",
in_memory: bool = False,
block: bool = True, block: bool = True,
progress: Callable = None, progress: Callable = None,
progress_args: tuple = () progress_args: tuple = ()
@ -3353,6 +3354,11 @@ class Message(Object, Update):
You can also specify a path for downloading files in a custom location: paths that end with "/" You can also specify a path for downloading files in a custom location: paths that end with "/"
are considered directories. All non-existent folders will be created automatically. are considered directories. All non-existent folders will be created automatically.
in_memory (``bool``, *optional*):
Pass True to download the media in-memory.
A binary file-like object with its attribute ".name" set will be returned.
Defaults to False.
block (``bool``, *optional*): block (``bool``, *optional*):
Blocks the code execution until the file has been downloaded. Blocks the code execution until the file has been downloaded.
Defaults to True. Defaults to True.
@ -3389,6 +3395,7 @@ class Message(Object, Update):
return await self._client.download_media( return await self._client.download_media(
message=self, message=self,
file_name=file_name, file_name=file_name,
in_memory=in_memory,
block=block, block=block,
progress=progress, progress=progress,
progress_args=progress_args, progress_args=progress_args,