diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 0e186d85..b68bd891 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -15,7 +15,7 @@ # # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . - +import io import logging import math import os @@ -1231,9 +1231,9 @@ class Client(Methods, BaseClient): temp_file_path = "" final_file_path = "" - + path = [None] try: - data, directory, file_name, done, progress, progress_args, path = packet + data, done, progress, progress_args, out, path, to_file = packet temp_file_path = self.get_file( media_type=data.media_type, @@ -1250,13 +1250,15 @@ class Client(Methods, BaseClient): file_size=data.file_size, is_big=data.is_big, progress=progress, - progress_args=progress_args + progress_args=progress_args, + out=out ) - - if temp_file_path: - final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))) - os.makedirs(directory, exist_ok=True) - shutil.move(temp_file_path, final_file_path) + if to_file: + final_file_path = out.name + else: + final_file_path = '' + if to_file: + out.close() except Exception as e: log.error(e, exc_info=True) @@ -1864,7 +1866,8 @@ class Client(Methods, BaseClient): file_size: int, is_big: bool, progress: callable, - progress_args: tuple = () + progress_args: tuple = (), + out: io.IOBase = None ) -> str: with self.media_sessions_lock: session = self.media_sessions.get(dc_id, None) @@ -1950,7 +1953,10 @@ class Client(Methods, BaseClient): limit = 1024 * 1024 offset = 0 file_name = "" - + if not out: + f = tempfile.NamedTemporaryFile("wb", delete=False) + else: + f = out try: r = session.send( functions.upload.GetFile( @@ -1961,36 +1967,37 @@ class Client(Methods, BaseClient): ) if isinstance(r, types.upload.File): - with tempfile.NamedTemporaryFile("wb", delete=False) as f: + if hasattr(f, "name"): file_name = f.name - while True: - chunk = r.bytes + while True: + chunk = r.bytes - if not chunk: - break + if not chunk: + break - f.write(chunk) + f.write(chunk) - offset += limit + offset += limit - if progress: - progress( - min(offset, file_size) - if file_size != 0 - else offset, - file_size, - *progress_args - ) + if progress: + progress( - r = session.send( - functions.upload.GetFile( - location=location, - offset=offset, - limit=limit - ) + min(offset, file_size) + if file_size != 0 + else offset, + file_size, + *progress_args ) + r = session.send( + functions.upload.GetFile( + location=location, + offset=offset, + limit=limit + ) + ) + elif isinstance(r, types.upload.FileCdnRedirect): with self.media_sessions_lock: cdn_session = self.media_sessions.get(r.dc_id, None) @@ -2003,70 +2010,71 @@ class Client(Methods, BaseClient): self.media_sessions[r.dc_id] = cdn_session try: - with tempfile.NamedTemporaryFile("wb", delete=False) as f: + if hasattr(f, "name"): file_name = f.name - while True: - r2 = cdn_session.send( - functions.upload.GetCdnFile( - file_token=r.file_token, - offset=offset, - limit=limit - ) + while True: + r2 = cdn_session.send( + functions.upload.GetCdnFile( + file_token=r.file_token, + offset=offset, + limit=limit ) + ) - if isinstance(r2, types.upload.CdnFileReuploadNeeded): - try: - session.send( - functions.upload.ReuploadCdnFile( - file_token=r.file_token, - request_token=r2.request_token - ) + if isinstance(r2, types.upload.CdnFileReuploadNeeded): + try: + session.send( + functions.upload.ReuploadCdnFile( + file_token=r.file_token, + request_token=r2.request_token ) - except VolumeLocNotFound: - break - else: - continue - - chunk = r2.bytes - - # https://core.telegram.org/cdn#decrypting-files - decrypted_chunk = AES.ctr256_decrypt( - chunk, - r.encryption_key, - bytearray( - r.encryption_iv[:-4] - + (offset // 16).to_bytes(4, "big") ) - ) - - hashes = session.send( - functions.upload.GetCdnFileHashes( - file_token=r.file_token, - offset=offset - ) - ) - - # https://core.telegram.org/cdn#verifying-files - for i, h in enumerate(hashes): - cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] - assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i) - - f.write(decrypted_chunk) - - offset += limit - - if progress: - progress( - min(offset, file_size) - if file_size != 0 - else offset, - file_size, - *progress_args - ) - - if len(chunk) < limit: + except VolumeLocNotFound: break + else: + continue + + chunk = r2.bytes + + # https://core.telegram.org/cdn#decrypting-files + decrypted_chunk = AES.ctr256_decrypt( + chunk, + r.encryption_key, + bytearray( + r.encryption_iv[:-4] + + (offset // 16).to_bytes(4, "big") + ) + ) + + hashes = session.send( + functions.upload.GetCdnFileHashes( + file_token=r.file_token, + offset=offset + ) + ) + + # https://core.telegram.org/cdn#verifying-files + for i, h in enumerate(hashes): + cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] + assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i) + + f.write(decrypted_chunk) + + offset += limit + + if progress: + progress( + + min(offset, file_size) + if file_size != 0 + else offset, + file_size, + *progress_args + ) + + if len(chunk) < limit: + break except Exception as e: raise e except Exception as e: @@ -2074,7 +2082,8 @@ class Client(Methods, BaseClient): log.error(e, exc_info=True) try: - os.remove(file_name) + if out: + os.remove(file_name) except OSError: pass diff --git a/pyrogram/client/methods/messages/download_media.py b/pyrogram/client/methods/messages/download_media.py index 22054397..2176e4aa 100644 --- a/pyrogram/client/methods/messages/download_media.py +++ b/pyrogram/client/methods/messages/download_media.py @@ -17,7 +17,9 @@ # along with Pyrogram. If not, see . import binascii +import io import os +import re import struct import time from datetime import datetime @@ -37,6 +39,7 @@ class DownloadMedia(BaseClient): message: Union["pyrogram.Message", str], file_ref: str = None, file_name: str = DEFAULT_DOWNLOAD_DIR, + out: io.IOBase = None, block: bool = True, progress: callable = None, progress_args: tuple = () @@ -58,6 +61,9 @@ class DownloadMedia(BaseClient): You can also specify a path for downloading files in a custom location: paths that end with "/" are considered directories. All non-existent folders will be created automatically. + out (``io.IOBase``, *optional*): + A custom *file-like object* to be used when downloading file. Overrides file_name + block (``bool``, *optional*): Blocks the code execution until the file has been downloaded. Defaults to True. @@ -238,6 +244,13 @@ class DownloadMedia(BaseClient): extension ) + if not out: + out = open(os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))), 'wb') + os.makedirs(directory, exist_ok=True) + to_file = True + else: + to_file = False + self.download_queue.put((data, done, progress, progress_args, out, path, to_file)) # Cast to string because Path objects aren't supported by Python 3.5 self.download_queue.put((data, str(directory), str(file_name), done, progress, progress_args, path)) diff --git a/pyrogram/client/types/messages_and_media/message.py b/pyrogram/client/types/messages_and_media/message.py index 215f86d0..cff8c578 100644 --- a/pyrogram/client/types/messages_and_media/message.py +++ b/pyrogram/client/types/messages_and_media/message.py @@ -15,7 +15,7 @@ # # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . - +import io from functools import partial from typing import List, Match, Union @@ -2964,7 +2964,7 @@ class Message(Object, Update): chat_id=message.chat.id, message_id=message_id, ) - + Example: .. code-block:: python @@ -2985,6 +2985,7 @@ class Message(Object, Update): def download( self, file_name: str = "", + out: io.IOBase = None, block: bool = True, progress: callable = None, progress_args: tuple = () @@ -3009,6 +3010,9 @@ class Message(Object, Update): You can also specify a path for downloading files in a custom location: paths that end with "/" are considered directories. All non-existent folders will be created automatically. + out (``io.IOBase``, *optional*): + A custom *file-like object* to be used when downloading file. Overrides file_name + block (``bool``, *optional*): Blocks the code execution until the file has been downloaded. Defaults to True. @@ -3045,6 +3049,7 @@ class Message(Object, Update): return self._client.download_media( message=self, file_name=file_name, + out=out, block=block, progress=progress, progress_args=progress_args, @@ -3074,7 +3079,7 @@ class Message(Object, Update): Parameters: option (``int``): Index of the poll option you want to vote for (0 to 9). - + Returns: :obj:`Poll`: On success, the poll with the chosen option is returned.