mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-16 12:51:18 +00:00
Add support for in-memory downloads
This commit is contained in:
parent
0d054fa9bc
commit
01ca652f8c
@ -29,10 +29,10 @@ import tempfile
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from hashlib import sha256
|
||||
from importlib import import_module
|
||||
from io import StringIO
|
||||
from io import StringIO, BytesIO
|
||||
from mimetypes import MimeTypes
|
||||
from pathlib import Path
|
||||
from typing import Union, List, Optional, Callable
|
||||
from typing import Union, List, Optional, Callable, BinaryIO
|
||||
|
||||
import pyrogram
|
||||
from pyrogram import __version__, __license__
|
||||
@ -482,34 +482,6 @@ class Client(Methods):
|
||||
|
||||
return is_min
|
||||
|
||||
async def handle_download(self, packet):
|
||||
temp_file_path = ""
|
||||
final_file_path = ""
|
||||
|
||||
try:
|
||||
file_id, directory, file_name, file_size, progress, progress_args = packet
|
||||
|
||||
temp_file_path = await self.get_file(
|
||||
file_id=file_id,
|
||||
file_size=file_size,
|
||||
progress=progress,
|
||||
progress_args=progress_args
|
||||
)
|
||||
|
||||
if temp_file_path:
|
||||
final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
shutil.move(temp_file_path, final_file_path)
|
||||
except Exception as e:
|
||||
log.error(e, exc_info=True)
|
||||
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
return final_file_path or None
|
||||
|
||||
async def handle_updates(self, updates):
|
||||
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
|
||||
is_min = (await self.fetch_peers(updates.users)) or (await self.fetch_peers(updates.chats))
|
||||
@ -747,13 +719,41 @@ class Client(Methods):
|
||||
else:
|
||||
log.warning(f'[{self.session_name}] No plugin loaded from "{root}"')
|
||||
|
||||
async def handle_download(self, packet):
|
||||
file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet
|
||||
|
||||
file = await self.get_file(
|
||||
file_id=file_id,
|
||||
file_size=file_size,
|
||||
in_memory=in_memory,
|
||||
progress=progress,
|
||||
progress_args=progress_args
|
||||
)
|
||||
|
||||
if file and not in_memory:
|
||||
file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
shutil.move(file.name, file_path)
|
||||
|
||||
try:
|
||||
file.close()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
return file_path
|
||||
|
||||
if file and in_memory:
|
||||
file.name = file_name
|
||||
return file
|
||||
|
||||
async def get_file(
|
||||
self,
|
||||
file_id: FileId,
|
||||
file_size: int,
|
||||
in_memory: bool,
|
||||
progress: Callable,
|
||||
progress_args: tuple = ()
|
||||
) -> str:
|
||||
) -> Optional[BinaryIO]:
|
||||
dc_id = file_id.dc_id
|
||||
|
||||
async with self.media_sessions_lock:
|
||||
@ -838,7 +838,8 @@ class Client(Methods):
|
||||
|
||||
limit = 1024 * 1024
|
||||
offset = 0
|
||||
file_name = ""
|
||||
|
||||
file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb")
|
||||
|
||||
try:
|
||||
r = await session.invoke(
|
||||
@ -851,13 +852,10 @@ class Client(Methods):
|
||||
)
|
||||
|
||||
if isinstance(r, raw.types.upload.File):
|
||||
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
|
||||
file_name = f.name
|
||||
|
||||
while True:
|
||||
chunk = r.bytes
|
||||
|
||||
f.write(chunk)
|
||||
file.write(chunk)
|
||||
|
||||
offset += limit
|
||||
|
||||
@ -903,9 +901,6 @@ class Client(Methods):
|
||||
self.media_sessions[r.dc_id] = cdn_session
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
|
||||
file_name = f.name
|
||||
|
||||
while True:
|
||||
r2 = await cdn_session.invoke(
|
||||
raw.functions.upload.GetCdnFile(
|
||||
@ -952,7 +947,7 @@ class Client(Methods):
|
||||
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
||||
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
|
||||
|
||||
f.write(decrypted_chunk)
|
||||
file.write(decrypted_chunk)
|
||||
|
||||
offset += limit
|
||||
|
||||
@ -977,14 +972,11 @@ class Client(Methods):
|
||||
if not isinstance(e, pyrogram.StopTransmission):
|
||||
log.error(e, exc_info=True)
|
||||
|
||||
try:
|
||||
os.remove(file_name)
|
||||
except OSError:
|
||||
pass
|
||||
file.close()
|
||||
|
||||
return ""
|
||||
return None
|
||||
else:
|
||||
return file_name
|
||||
return file
|
||||
|
||||
def guess_mime_type(self, filename: str) -> Optional[str]:
|
||||
return self.mimetypes.guess_type(filename)[0]
|
||||
|
@ -18,9 +18,8 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Union, Optional, Callable
|
||||
from typing import Union, Optional, Callable, BinaryIO
|
||||
|
||||
import pyrogram
|
||||
from pyrogram import types
|
||||
@ -34,10 +33,11 @@ class DownloadMedia:
|
||||
self: "pyrogram.Client",
|
||||
message: Union["types.Message", str],
|
||||
file_name: str = DEFAULT_DOWNLOAD_DIR,
|
||||
in_memory: bool = False,
|
||||
block: bool = True,
|
||||
progress: Callable = None,
|
||||
progress_args: tuple = ()
|
||||
) -> Optional[str]:
|
||||
) -> Optional[Union[str, BinaryIO]]:
|
||||
"""Download the media from a message.
|
||||
|
||||
Parameters:
|
||||
@ -51,6 +51,11 @@ class DownloadMedia:
|
||||
You can also specify a path for downloading files in a custom location: paths that end with "/"
|
||||
are considered directories. All non-existent folders will be created automatically.
|
||||
|
||||
in_memory (``bool``, *optional*):
|
||||
Pass True to download the media in-memory.
|
||||
A binary file-like object with its attribute ".name" set will be returned.
|
||||
Defaults to False.
|
||||
|
||||
block (``bool``, *optional*):
|
||||
Blocks the code execution until the file has been downloaded.
|
||||
Defaults to True.
|
||||
@ -78,14 +83,17 @@ class DownloadMedia:
|
||||
You can either keep ``*args`` or add every single extra argument in your function signature.
|
||||
|
||||
Returns:
|
||||
``str`` | ``None``: On success, the absolute path of the downloaded file is returned, otherwise, in case
|
||||
the download failed or was deliberately stopped with :meth:`~pyrogram.Client.stop_transmission`, None is
|
||||
returned.
|
||||
``str`` | ``None`` | ``BinaryIO``: On success, the absolute path of the downloaded file is returned,
|
||||
otherwise, in case the download failed or was deliberately stopped with
|
||||
:meth:`~pyrogram.Client.stop_transmission`, None is returned.
|
||||
Otherwise, in case ``in_memory=True``, a binary file-like object with its attribute ".name" set is returned.
|
||||
|
||||
Raises:
|
||||
ValueError: if the message doesn't contain any downloadable media
|
||||
|
||||
Example:
|
||||
Download media to file
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Download from Message
|
||||
@ -99,6 +107,15 @@ class DownloadMedia:
|
||||
print(f"{current * 100 / total:.1f}%")
|
||||
|
||||
await app.download_media(message, progress=progress)
|
||||
|
||||
Download media in-memory
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
file = await app.download_media(message, in_memory=True)
|
||||
|
||||
file_name = file.name
|
||||
file_bytes = bytes(file.getbuffer())
|
||||
"""
|
||||
available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note",
|
||||
"new_chat_photo")
|
||||
@ -125,7 +142,7 @@ class DownloadMedia:
|
||||
media_file_name = getattr(media, "file_name", "")
|
||||
file_size = getattr(media, "file_size", 0)
|
||||
mime_type = getattr(media, "mime_type", "")
|
||||
date = getattr(media, "date", 0)
|
||||
date = getattr(media, "date", None)
|
||||
|
||||
directory, file_name = os.path.split(file_name)
|
||||
file_name = file_name or media_file_name or ""
|
||||
@ -153,12 +170,14 @@ class DownloadMedia:
|
||||
|
||||
file_name = "{}_{}_{}{}".format(
|
||||
FileType(file_id_obj.file_type).name.lower(),
|
||||
datetime.fromtimestamp(date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
|
||||
(date or datetime.now()).strftime("%Y-%m-%d_%H-%M-%S"),
|
||||
self.rnd_id(),
|
||||
extension
|
||||
)
|
||||
|
||||
downloader = self.handle_download((file_id_obj, directory, file_name, file_size, progress, progress_args))
|
||||
downloader = self.handle_download(
|
||||
(file_id_obj, directory, file_name, in_memory, file_size, progress, progress_args)
|
||||
)
|
||||
|
||||
if block:
|
||||
return await downloader
|
||||
|
@ -3329,6 +3329,7 @@ class Message(Object, Update):
|
||||
async def download(
|
||||
self,
|
||||
file_name: str = "",
|
||||
in_memory: bool = False,
|
||||
block: bool = True,
|
||||
progress: Callable = None,
|
||||
progress_args: tuple = ()
|
||||
@ -3353,6 +3354,11 @@ class Message(Object, Update):
|
||||
You can also specify a path for downloading files in a custom location: paths that end with "/"
|
||||
are considered directories. All non-existent folders will be created automatically.
|
||||
|
||||
in_memory (``bool``, *optional*):
|
||||
Pass True to download the media in-memory.
|
||||
A binary file-like object with its attribute ".name" set will be returned.
|
||||
Defaults to False.
|
||||
|
||||
block (``bool``, *optional*):
|
||||
Blocks the code execution until the file has been downloaded.
|
||||
Defaults to True.
|
||||
@ -3389,6 +3395,7 @@ class Message(Object, Update):
|
||||
return await self._client.download_media(
|
||||
message=self,
|
||||
file_name=file_name,
|
||||
in_memory=in_memory,
|
||||
block=block,
|
||||
progress=progress,
|
||||
progress_args=progress_args,
|
||||
|
Loading…
Reference in New Issue
Block a user