mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-16 20:59:29 +00:00
Add support for in-memory downloads
This commit is contained in:
parent
0d054fa9bc
commit
01ca652f8c
@ -29,10 +29,10 @@ import tempfile
|
|||||||
from concurrent.futures.thread import ThreadPoolExecutor
|
from concurrent.futures.thread import ThreadPoolExecutor
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from io import StringIO
|
from io import StringIO, BytesIO
|
||||||
from mimetypes import MimeTypes
|
from mimetypes import MimeTypes
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, List, Optional, Callable
|
from typing import Union, List, Optional, Callable, BinaryIO
|
||||||
|
|
||||||
import pyrogram
|
import pyrogram
|
||||||
from pyrogram import __version__, __license__
|
from pyrogram import __version__, __license__
|
||||||
@ -482,34 +482,6 @@ class Client(Methods):
|
|||||||
|
|
||||||
return is_min
|
return is_min
|
||||||
|
|
||||||
async def handle_download(self, packet):
|
|
||||||
temp_file_path = ""
|
|
||||||
final_file_path = ""
|
|
||||||
|
|
||||||
try:
|
|
||||||
file_id, directory, file_name, file_size, progress, progress_args = packet
|
|
||||||
|
|
||||||
temp_file_path = await self.get_file(
|
|
||||||
file_id=file_id,
|
|
||||||
file_size=file_size,
|
|
||||||
progress=progress,
|
|
||||||
progress_args=progress_args
|
|
||||||
)
|
|
||||||
|
|
||||||
if temp_file_path:
|
|
||||||
final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
|
|
||||||
os.makedirs(directory, exist_ok=True)
|
|
||||||
shutil.move(temp_file_path, final_file_path)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(e, exc_info=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.remove(temp_file_path)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
return final_file_path or None
|
|
||||||
|
|
||||||
async def handle_updates(self, updates):
|
async def handle_updates(self, updates):
|
||||||
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
|
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
|
||||||
is_min = (await self.fetch_peers(updates.users)) or (await self.fetch_peers(updates.chats))
|
is_min = (await self.fetch_peers(updates.users)) or (await self.fetch_peers(updates.chats))
|
||||||
@ -747,13 +719,41 @@ class Client(Methods):
|
|||||||
else:
|
else:
|
||||||
log.warning(f'[{self.session_name}] No plugin loaded from "{root}"')
|
log.warning(f'[{self.session_name}] No plugin loaded from "{root}"')
|
||||||
|
|
||||||
|
async def handle_download(self, packet):
|
||||||
|
file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet
|
||||||
|
|
||||||
|
file = await self.get_file(
|
||||||
|
file_id=file_id,
|
||||||
|
file_size=file_size,
|
||||||
|
in_memory=in_memory,
|
||||||
|
progress=progress,
|
||||||
|
progress_args=progress_args
|
||||||
|
)
|
||||||
|
|
||||||
|
if file and not in_memory:
|
||||||
|
file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
shutil.move(file.name, file_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
file.close()
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
if file and in_memory:
|
||||||
|
file.name = file_name
|
||||||
|
return file
|
||||||
|
|
||||||
async def get_file(
|
async def get_file(
|
||||||
self,
|
self,
|
||||||
file_id: FileId,
|
file_id: FileId,
|
||||||
file_size: int,
|
file_size: int,
|
||||||
|
in_memory: bool,
|
||||||
progress: Callable,
|
progress: Callable,
|
||||||
progress_args: tuple = ()
|
progress_args: tuple = ()
|
||||||
) -> str:
|
) -> Optional[BinaryIO]:
|
||||||
dc_id = file_id.dc_id
|
dc_id = file_id.dc_id
|
||||||
|
|
||||||
async with self.media_sessions_lock:
|
async with self.media_sessions_lock:
|
||||||
@ -838,7 +838,8 @@ class Client(Methods):
|
|||||||
|
|
||||||
limit = 1024 * 1024
|
limit = 1024 * 1024
|
||||||
offset = 0
|
offset = 0
|
||||||
file_name = ""
|
|
||||||
|
file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
r = await session.invoke(
|
r = await session.invoke(
|
||||||
@ -851,13 +852,10 @@ class Client(Methods):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(r, raw.types.upload.File):
|
if isinstance(r, raw.types.upload.File):
|
||||||
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
|
|
||||||
file_name = f.name
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
chunk = r.bytes
|
chunk = r.bytes
|
||||||
|
|
||||||
f.write(chunk)
|
file.write(chunk)
|
||||||
|
|
||||||
offset += limit
|
offset += limit
|
||||||
|
|
||||||
@ -903,9 +901,6 @@ class Client(Methods):
|
|||||||
self.media_sessions[r.dc_id] = cdn_session
|
self.media_sessions[r.dc_id] = cdn_session
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
|
|
||||||
file_name = f.name
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
r2 = await cdn_session.invoke(
|
r2 = await cdn_session.invoke(
|
||||||
raw.functions.upload.GetCdnFile(
|
raw.functions.upload.GetCdnFile(
|
||||||
@ -952,7 +947,7 @@ class Client(Methods):
|
|||||||
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
||||||
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
|
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
|
||||||
|
|
||||||
f.write(decrypted_chunk)
|
file.write(decrypted_chunk)
|
||||||
|
|
||||||
offset += limit
|
offset += limit
|
||||||
|
|
||||||
@ -977,14 +972,11 @@ class Client(Methods):
|
|||||||
if not isinstance(e, pyrogram.StopTransmission):
|
if not isinstance(e, pyrogram.StopTransmission):
|
||||||
log.error(e, exc_info=True)
|
log.error(e, exc_info=True)
|
||||||
|
|
||||||
try:
|
file.close()
|
||||||
os.remove(file_name)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return ""
|
return None
|
||||||
else:
|
else:
|
||||||
return file_name
|
return file
|
||||||
|
|
||||||
def guess_mime_type(self, filename: str) -> Optional[str]:
|
def guess_mime_type(self, filename: str) -> Optional[str]:
|
||||||
return self.mimetypes.guess_type(filename)[0]
|
return self.mimetypes.guess_type(filename)[0]
|
||||||
|
@ -18,9 +18,8 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Union, Optional, Callable
|
from typing import Union, Optional, Callable, BinaryIO
|
||||||
|
|
||||||
import pyrogram
|
import pyrogram
|
||||||
from pyrogram import types
|
from pyrogram import types
|
||||||
@ -34,10 +33,11 @@ class DownloadMedia:
|
|||||||
self: "pyrogram.Client",
|
self: "pyrogram.Client",
|
||||||
message: Union["types.Message", str],
|
message: Union["types.Message", str],
|
||||||
file_name: str = DEFAULT_DOWNLOAD_DIR,
|
file_name: str = DEFAULT_DOWNLOAD_DIR,
|
||||||
|
in_memory: bool = False,
|
||||||
block: bool = True,
|
block: bool = True,
|
||||||
progress: Callable = None,
|
progress: Callable = None,
|
||||||
progress_args: tuple = ()
|
progress_args: tuple = ()
|
||||||
) -> Optional[str]:
|
) -> Optional[Union[str, BinaryIO]]:
|
||||||
"""Download the media from a message.
|
"""Download the media from a message.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
@ -51,6 +51,11 @@ class DownloadMedia:
|
|||||||
You can also specify a path for downloading files in a custom location: paths that end with "/"
|
You can also specify a path for downloading files in a custom location: paths that end with "/"
|
||||||
are considered directories. All non-existent folders will be created automatically.
|
are considered directories. All non-existent folders will be created automatically.
|
||||||
|
|
||||||
|
in_memory (``bool``, *optional*):
|
||||||
|
Pass True to download the media in-memory.
|
||||||
|
A binary file-like object with its attribute ".name" set will be returned.
|
||||||
|
Defaults to False.
|
||||||
|
|
||||||
block (``bool``, *optional*):
|
block (``bool``, *optional*):
|
||||||
Blocks the code execution until the file has been downloaded.
|
Blocks the code execution until the file has been downloaded.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
@ -78,14 +83,17 @@ class DownloadMedia:
|
|||||||
You can either keep ``*args`` or add every single extra argument in your function signature.
|
You can either keep ``*args`` or add every single extra argument in your function signature.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
``str`` | ``None``: On success, the absolute path of the downloaded file is returned, otherwise, in case
|
``str`` | ``None`` | ``BinaryIO``: On success, the absolute path of the downloaded file is returned,
|
||||||
the download failed or was deliberately stopped with :meth:`~pyrogram.Client.stop_transmission`, None is
|
otherwise, in case the download failed or was deliberately stopped with
|
||||||
returned.
|
:meth:`~pyrogram.Client.stop_transmission`, None is returned.
|
||||||
|
Otherwise, in case ``in_memory=True``, a binary file-like object with its attribute ".name" set is returned.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the message doesn't contain any downloadable media
|
ValueError: if the message doesn't contain any downloadable media
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
Download media to file
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
# Download from Message
|
# Download from Message
|
||||||
@ -99,6 +107,15 @@ class DownloadMedia:
|
|||||||
print(f"{current * 100 / total:.1f}%")
|
print(f"{current * 100 / total:.1f}%")
|
||||||
|
|
||||||
await app.download_media(message, progress=progress)
|
await app.download_media(message, progress=progress)
|
||||||
|
|
||||||
|
Download media in-memory
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
file = await app.download_media(message, in_memory=True)
|
||||||
|
|
||||||
|
file_name = file.name
|
||||||
|
file_bytes = bytes(file.getbuffer())
|
||||||
"""
|
"""
|
||||||
available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note",
|
available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note",
|
||||||
"new_chat_photo")
|
"new_chat_photo")
|
||||||
@ -125,7 +142,7 @@ class DownloadMedia:
|
|||||||
media_file_name = getattr(media, "file_name", "")
|
media_file_name = getattr(media, "file_name", "")
|
||||||
file_size = getattr(media, "file_size", 0)
|
file_size = getattr(media, "file_size", 0)
|
||||||
mime_type = getattr(media, "mime_type", "")
|
mime_type = getattr(media, "mime_type", "")
|
||||||
date = getattr(media, "date", 0)
|
date = getattr(media, "date", None)
|
||||||
|
|
||||||
directory, file_name = os.path.split(file_name)
|
directory, file_name = os.path.split(file_name)
|
||||||
file_name = file_name or media_file_name or ""
|
file_name = file_name or media_file_name or ""
|
||||||
@ -153,12 +170,14 @@ class DownloadMedia:
|
|||||||
|
|
||||||
file_name = "{}_{}_{}{}".format(
|
file_name = "{}_{}_{}{}".format(
|
||||||
FileType(file_id_obj.file_type).name.lower(),
|
FileType(file_id_obj.file_type).name.lower(),
|
||||||
datetime.fromtimestamp(date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
|
(date or datetime.now()).strftime("%Y-%m-%d_%H-%M-%S"),
|
||||||
self.rnd_id(),
|
self.rnd_id(),
|
||||||
extension
|
extension
|
||||||
)
|
)
|
||||||
|
|
||||||
downloader = self.handle_download((file_id_obj, directory, file_name, file_size, progress, progress_args))
|
downloader = self.handle_download(
|
||||||
|
(file_id_obj, directory, file_name, in_memory, file_size, progress, progress_args)
|
||||||
|
)
|
||||||
|
|
||||||
if block:
|
if block:
|
||||||
return await downloader
|
return await downloader
|
||||||
|
@ -3329,6 +3329,7 @@ class Message(Object, Update):
|
|||||||
async def download(
|
async def download(
|
||||||
self,
|
self,
|
||||||
file_name: str = "",
|
file_name: str = "",
|
||||||
|
in_memory: bool = False,
|
||||||
block: bool = True,
|
block: bool = True,
|
||||||
progress: Callable = None,
|
progress: Callable = None,
|
||||||
progress_args: tuple = ()
|
progress_args: tuple = ()
|
||||||
@ -3353,6 +3354,11 @@ class Message(Object, Update):
|
|||||||
You can also specify a path for downloading files in a custom location: paths that end with "/"
|
You can also specify a path for downloading files in a custom location: paths that end with "/"
|
||||||
are considered directories. All non-existent folders will be created automatically.
|
are considered directories. All non-existent folders will be created automatically.
|
||||||
|
|
||||||
|
in_memory (``bool``, *optional*):
|
||||||
|
Pass True to download the media in-memory.
|
||||||
|
A binary file-like object with its attribute ".name" set will be returned.
|
||||||
|
Defaults to False.
|
||||||
|
|
||||||
block (``bool``, *optional*):
|
block (``bool``, *optional*):
|
||||||
Blocks the code execution until the file has been downloaded.
|
Blocks the code execution until the file has been downloaded.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
@ -3389,6 +3395,7 @@ class Message(Object, Update):
|
|||||||
return await self._client.download_media(
|
return await self._client.download_media(
|
||||||
message=self,
|
message=self,
|
||||||
file_name=file_name,
|
file_name=file_name,
|
||||||
|
in_memory=in_memory,
|
||||||
block=block,
|
block=block,
|
||||||
progress=progress,
|
progress=progress,
|
||||||
progress_args=progress_args,
|
progress_args=progress_args,
|
||||||
|
Loading…
Reference in New Issue
Block a user