Add support for in-memory downloads

This commit is contained in:
Dan 2022-04-24 11:56:07 +02:00
parent 0d054fa9bc
commit 01ca652f8c
3 changed files with 160 additions and 142 deletions

View File

@ -29,10 +29,10 @@ import tempfile
from concurrent.futures.thread import ThreadPoolExecutor
from hashlib import sha256
from importlib import import_module
from io import StringIO
from io import StringIO, BytesIO
from mimetypes import MimeTypes
from pathlib import Path
from typing import Union, List, Optional, Callable
from typing import Union, List, Optional, Callable, BinaryIO
import pyrogram
from pyrogram import __version__, __license__
@ -482,34 +482,6 @@ class Client(Methods):
return is_min
async def handle_download(self, packet):
temp_file_path = ""
final_file_path = ""
try:
file_id, directory, file_name, file_size, progress, progress_args = packet
temp_file_path = await self.get_file(
file_id=file_id,
file_size=file_size,
progress=progress,
progress_args=progress_args
)
if temp_file_path:
final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
os.makedirs(directory, exist_ok=True)
shutil.move(temp_file_path, final_file_path)
except Exception as e:
log.error(e, exc_info=True)
try:
os.remove(temp_file_path)
except OSError:
pass
else:
return final_file_path or None
async def handle_updates(self, updates):
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
is_min = (await self.fetch_peers(updates.users)) or (await self.fetch_peers(updates.chats))
@ -747,13 +719,41 @@ class Client(Methods):
else:
log.warning(f'[{self.session_name}] No plugin loaded from "{root}"')
async def handle_download(self, packet):
file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet
file = await self.get_file(
file_id=file_id,
file_size=file_size,
in_memory=in_memory,
progress=progress,
progress_args=progress_args
)
if file and not in_memory:
file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
os.makedirs(directory, exist_ok=True)
shutil.move(file.name, file_path)
try:
file.close()
except FileNotFoundError:
pass
return file_path
if file and in_memory:
file.name = file_name
return file
async def get_file(
self,
file_id: FileId,
file_size: int,
in_memory: bool,
progress: Callable,
progress_args: tuple = ()
) -> str:
) -> Optional[BinaryIO]:
dc_id = file_id.dc_id
async with self.media_sessions_lock:
@ -838,7 +838,8 @@ class Client(Methods):
limit = 1024 * 1024
offset = 0
file_name = ""
file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb")
try:
r = await session.invoke(
@ -851,43 +852,40 @@ class Client(Methods):
)
if isinstance(r, raw.types.upload.File):
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
file_name = f.name
while True:
chunk = r.bytes
while True:
chunk = r.bytes
file.write(chunk)
f.write(chunk)
offset += limit
offset += limit
if progress:
func = functools.partial(
progress,
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)
if inspect.iscoroutinefunction(progress):
await func()
else:
await self.loop.run_in_executor(self.executor, func)
if len(chunk) < limit:
break
r = await session.invoke(
raw.functions.upload.GetFile(
location=location,
offset=offset,
limit=limit
),
sleep_threshold=30
if progress:
func = functools.partial(
progress,
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)
if inspect.iscoroutinefunction(progress):
await func()
else:
await self.loop.run_in_executor(self.executor, func)
if len(chunk) < limit:
break
r = await session.invoke(
raw.functions.upload.GetFile(
location=location,
offset=offset,
limit=limit
),
sleep_threshold=30
)
elif isinstance(r, raw.types.upload.FileCdnRedirect):
async with self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None)
@ -903,88 +901,82 @@ class Client(Methods):
self.media_sessions[r.dc_id] = cdn_session
try:
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
file_name = f.name
while True:
r2 = await cdn_session.invoke(
raw.functions.upload.GetCdnFile(
file_token=r.file_token,
offset=offset,
limit=limit
)
while True:
r2 = await cdn_session.invoke(
raw.functions.upload.GetCdnFile(
file_token=r.file_token,
offset=offset,
limit=limit
)
)
if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
try:
await session.invoke(
raw.functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
try:
await session.invoke(
raw.functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
except VolumeLocNotFound:
break
else:
continue
chunk = r2.bytes
# https://core.telegram.org/cdn#decrypting-files
decrypted_chunk = aes.ctr256_decrypt(
chunk,
r.encryption_key,
bytearray(
r.encryption_iv[:-4]
+ (offset // 16).to_bytes(4, "big")
)
)
hashes = await session.invoke(
raw.functions.upload.GetCdnFileHashes(
file_token=r.file_token,
offset=offset
)
)
# https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
f.write(decrypted_chunk)
offset += limit
if progress:
func = functools.partial(
progress,
min(offset, file_size) if file_size != 0 else offset,
file_size,
*progress_args
)
if inspect.iscoroutinefunction(progress):
await func()
else:
await self.loop.run_in_executor(self.executor, func)
if len(chunk) < limit:
except VolumeLocNotFound:
break
else:
continue
chunk = r2.bytes
# https://core.telegram.org/cdn#decrypting-files
decrypted_chunk = aes.ctr256_decrypt(
chunk,
r.encryption_key,
bytearray(
r.encryption_iv[:-4]
+ (offset // 16).to_bytes(4, "big")
)
)
hashes = await session.invoke(
raw.functions.upload.GetCdnFileHashes(
file_token=r.file_token,
offset=offset
)
)
# https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
file.write(decrypted_chunk)
offset += limit
if progress:
func = functools.partial(
progress,
min(offset, file_size) if file_size != 0 else offset,
file_size,
*progress_args
)
if inspect.iscoroutinefunction(progress):
await func()
else:
await self.loop.run_in_executor(self.executor, func)
if len(chunk) < limit:
break
except Exception as e:
raise e
except Exception as e:
if not isinstance(e, pyrogram.StopTransmission):
log.error(e, exc_info=True)
try:
os.remove(file_name)
except OSError:
pass
file.close()
return ""
return None
else:
return file_name
return file
def guess_mime_type(self, filename: str) -> Optional[str]:
return self.mimetypes.guess_type(filename)[0]

View File

@ -18,9 +18,8 @@
import asyncio
import os
import time
from datetime import datetime
from typing import Union, Optional, Callable
from typing import Union, Optional, Callable, BinaryIO
import pyrogram
from pyrogram import types
@ -34,10 +33,11 @@ class DownloadMedia:
self: "pyrogram.Client",
message: Union["types.Message", str],
file_name: str = DEFAULT_DOWNLOAD_DIR,
in_memory: bool = False,
block: bool = True,
progress: Callable = None,
progress_args: tuple = ()
) -> Optional[str]:
) -> Optional[Union[str, BinaryIO]]:
"""Download the media from a message.
Parameters:
@ -51,6 +51,11 @@ class DownloadMedia:
You can also specify a path for downloading files in a custom location: paths that end with "/"
are considered directories. All non-existent folders will be created automatically.
in_memory (``bool``, *optional*):
Pass True to download the media in-memory.
A binary file-like object with its attribute ".name" set will be returned.
Defaults to False.
block (``bool``, *optional*):
Blocks the code execution until the file has been downloaded.
Defaults to True.
@ -78,14 +83,17 @@ class DownloadMedia:
You can either keep ``*args`` or add every single extra argument in your function signature.
Returns:
``str`` | ``None``: On success, the absolute path of the downloaded file is returned, otherwise, in case
the download failed or was deliberately stopped with :meth:`~pyrogram.Client.stop_transmission`, None is
returned.
``str`` | ``None`` | ``BinaryIO``: On success, the absolute path of the downloaded file is returned,
otherwise, in case the download failed or was deliberately stopped with
:meth:`~pyrogram.Client.stop_transmission`, None is returned.
Otherwise, in case ``in_memory=True``, a binary file-like object with its attribute ".name" set is returned.
Raises:
ValueError: if the message doesn't contain any downloadable media
Example:
Download media to file
.. code-block:: python
# Download from Message
@ -99,6 +107,15 @@ class DownloadMedia:
print(f"{current * 100 / total:.1f}%")
await app.download_media(message, progress=progress)
Download media in-memory
.. code-block:: python
file = await app.download_media(message, in_memory=True)
file_name = file.name
file_bytes = bytes(file.getbuffer())
"""
available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note",
"new_chat_photo")
@ -125,7 +142,7 @@ class DownloadMedia:
media_file_name = getattr(media, "file_name", "")
file_size = getattr(media, "file_size", 0)
mime_type = getattr(media, "mime_type", "")
date = getattr(media, "date", 0)
date = getattr(media, "date", None)
directory, file_name = os.path.split(file_name)
file_name = file_name or media_file_name or ""
@ -153,12 +170,14 @@ class DownloadMedia:
file_name = "{}_{}_{}{}".format(
FileType(file_id_obj.file_type).name.lower(),
datetime.fromtimestamp(date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
(date or datetime.now()).strftime("%Y-%m-%d_%H-%M-%S"),
self.rnd_id(),
extension
)
downloader = self.handle_download((file_id_obj, directory, file_name, file_size, progress, progress_args))
downloader = self.handle_download(
(file_id_obj, directory, file_name, in_memory, file_size, progress, progress_args)
)
if block:
return await downloader

View File

@ -3329,6 +3329,7 @@ class Message(Object, Update):
async def download(
self,
file_name: str = "",
in_memory: bool = False,
block: bool = True,
progress: Callable = None,
progress_args: tuple = ()
@ -3353,6 +3354,11 @@ class Message(Object, Update):
You can also specify a path for downloading files in a custom location: paths that end with "/"
are considered directories. All non-existent folders will be created automatically.
in_memory (``bool``, *optional*):
Pass True to download the media in-memory.
A binary file-like object with its attribute ".name" set will be returned.
Defaults to False.
block (``bool``, *optional*):
Blocks the code execution until the file has been downloaded.
Defaults to True.
@ -3389,6 +3395,7 @@ class Message(Object, Update):
return await self._client.download_media(
message=self,
file_name=file_name,
in_memory=in_memory,
block=block,
progress=progress,
progress_args=progress_args,