diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py
index 0e186d85..b68bd891 100644
--- a/pyrogram/client/client.py
+++ b/pyrogram/client/client.py
@@ -15,7 +15,7 @@
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
-
+import io
import logging
import math
import os
@@ -1231,9 +1231,9 @@ class Client(Methods, BaseClient):
temp_file_path = ""
final_file_path = ""
-
+ path = [None]
try:
- data, directory, file_name, done, progress, progress_args, path = packet
+ data, done, progress, progress_args, out, path, to_file = packet
temp_file_path = self.get_file(
media_type=data.media_type,
@@ -1250,13 +1250,15 @@ class Client(Methods, BaseClient):
file_size=data.file_size,
is_big=data.is_big,
progress=progress,
- progress_args=progress_args
+ progress_args=progress_args,
+ out=out
)
-
- if temp_file_path:
- final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
- os.makedirs(directory, exist_ok=True)
- shutil.move(temp_file_path, final_file_path)
+ if to_file:
+ final_file_path = out.name
+ else:
+ final_file_path = ''
+ if to_file:
+ out.close()
except Exception as e:
log.error(e, exc_info=True)
@@ -1864,7 +1866,8 @@ class Client(Methods, BaseClient):
file_size: int,
is_big: bool,
progress: callable,
- progress_args: tuple = ()
+ progress_args: tuple = (),
+ out: io.IOBase = None
) -> str:
with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
@@ -1950,7 +1953,10 @@ class Client(Methods, BaseClient):
limit = 1024 * 1024
offset = 0
file_name = ""
-
+ if not out:
+ f = tempfile.NamedTemporaryFile("wb", delete=False)
+ else:
+ f = out
try:
r = session.send(
functions.upload.GetFile(
@@ -1961,36 +1967,37 @@ class Client(Methods, BaseClient):
)
if isinstance(r, types.upload.File):
- with tempfile.NamedTemporaryFile("wb", delete=False) as f:
+ if hasattr(f, "name"):
file_name = f.name
- while True:
- chunk = r.bytes
+ while True:
+ chunk = r.bytes
- if not chunk:
- break
+ if not chunk:
+ break
- f.write(chunk)
+ f.write(chunk)
- offset += limit
+ offset += limit
- if progress:
- progress(
- min(offset, file_size)
- if file_size != 0
- else offset,
- file_size,
- *progress_args
- )
+ if progress:
+ progress(
- r = session.send(
- functions.upload.GetFile(
- location=location,
- offset=offset,
- limit=limit
- )
+ min(offset, file_size)
+ if file_size != 0
+ else offset,
+ file_size,
+ *progress_args
)
+ r = session.send(
+ functions.upload.GetFile(
+ location=location,
+ offset=offset,
+ limit=limit
+ )
+ )
+
elif isinstance(r, types.upload.FileCdnRedirect):
with self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None)
@@ -2003,70 +2010,71 @@ class Client(Methods, BaseClient):
self.media_sessions[r.dc_id] = cdn_session
try:
- with tempfile.NamedTemporaryFile("wb", delete=False) as f:
+ if hasattr(f, "name"):
file_name = f.name
- while True:
- r2 = cdn_session.send(
- functions.upload.GetCdnFile(
- file_token=r.file_token,
- offset=offset,
- limit=limit
- )
+ while True:
+ r2 = cdn_session.send(
+ functions.upload.GetCdnFile(
+ file_token=r.file_token,
+ offset=offset,
+ limit=limit
)
+ )
- if isinstance(r2, types.upload.CdnFileReuploadNeeded):
- try:
- session.send(
- functions.upload.ReuploadCdnFile(
- file_token=r.file_token,
- request_token=r2.request_token
- )
+ if isinstance(r2, types.upload.CdnFileReuploadNeeded):
+ try:
+ session.send(
+ functions.upload.ReuploadCdnFile(
+ file_token=r.file_token,
+ request_token=r2.request_token
)
- except VolumeLocNotFound:
- break
- else:
- continue
-
- chunk = r2.bytes
-
- # https://core.telegram.org/cdn#decrypting-files
- decrypted_chunk = AES.ctr256_decrypt(
- chunk,
- r.encryption_key,
- bytearray(
- r.encryption_iv[:-4]
- + (offset // 16).to_bytes(4, "big")
)
- )
-
- hashes = session.send(
- functions.upload.GetCdnFileHashes(
- file_token=r.file_token,
- offset=offset
- )
- )
-
- # https://core.telegram.org/cdn#verifying-files
- for i, h in enumerate(hashes):
- cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
- assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
-
- f.write(decrypted_chunk)
-
- offset += limit
-
- if progress:
- progress(
- min(offset, file_size)
- if file_size != 0
- else offset,
- file_size,
- *progress_args
- )
-
- if len(chunk) < limit:
+ except VolumeLocNotFound:
break
+ else:
+ continue
+
+ chunk = r2.bytes
+
+ # https://core.telegram.org/cdn#decrypting-files
+ decrypted_chunk = AES.ctr256_decrypt(
+ chunk,
+ r.encryption_key,
+ bytearray(
+ r.encryption_iv[:-4]
+ + (offset // 16).to_bytes(4, "big")
+ )
+ )
+
+ hashes = session.send(
+ functions.upload.GetCdnFileHashes(
+ file_token=r.file_token,
+ offset=offset
+ )
+ )
+
+ # https://core.telegram.org/cdn#verifying-files
+ for i, h in enumerate(hashes):
+ cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
+ assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
+
+ f.write(decrypted_chunk)
+
+ offset += limit
+
+ if progress:
+ progress(
+
+ min(offset, file_size)
+ if file_size != 0
+ else offset,
+ file_size,
+ *progress_args
+ )
+
+ if len(chunk) < limit:
+ break
except Exception as e:
raise e
except Exception as e:
@@ -2074,7 +2082,8 @@ class Client(Methods, BaseClient):
log.error(e, exc_info=True)
try:
- os.remove(file_name)
+ if out:
+ os.remove(file_name)
except OSError:
pass
diff --git a/pyrogram/client/methods/messages/download_media.py b/pyrogram/client/methods/messages/download_media.py
index 22054397..2176e4aa 100644
--- a/pyrogram/client/methods/messages/download_media.py
+++ b/pyrogram/client/methods/messages/download_media.py
@@ -17,7 +17,9 @@
# along with Pyrogram. If not, see .
import binascii
+import io
import os
+import re
import struct
import time
from datetime import datetime
@@ -37,6 +39,7 @@ class DownloadMedia(BaseClient):
message: Union["pyrogram.Message", str],
file_ref: str = None,
file_name: str = DEFAULT_DOWNLOAD_DIR,
+ out: io.IOBase = None,
block: bool = True,
progress: callable = None,
progress_args: tuple = ()
@@ -58,6 +61,9 @@ class DownloadMedia(BaseClient):
You can also specify a path for downloading files in a custom location: paths that end with "/"
are considered directories. All non-existent folders will be created automatically.
+ out (``io.IOBase``, *optional*):
+ A custom *file-like object* to be used when downloading file. Overrides file_name
+
block (``bool``, *optional*):
Blocks the code execution until the file has been downloaded.
Defaults to True.
@@ -238,6 +244,13 @@ class DownloadMedia(BaseClient):
extension
)
+ if not out:
+ out = open(os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))), 'wb')
+ os.makedirs(directory, exist_ok=True)
+ to_file = True
+ else:
+ to_file = False
+ self.download_queue.put((data, done, progress, progress_args, out, path, to_file))
# Cast to string because Path objects aren't supported by Python 3.5
self.download_queue.put((data, str(directory), str(file_name), done, progress, progress_args, path))
diff --git a/pyrogram/client/types/messages_and_media/message.py b/pyrogram/client/types/messages_and_media/message.py
index 215f86d0..cff8c578 100644
--- a/pyrogram/client/types/messages_and_media/message.py
+++ b/pyrogram/client/types/messages_and_media/message.py
@@ -15,7 +15,7 @@
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
-
+import io
from functools import partial
from typing import List, Match, Union
@@ -2964,7 +2964,7 @@ class Message(Object, Update):
chat_id=message.chat.id,
message_id=message_id,
)
-
+
Example:
.. code-block:: python
@@ -2985,6 +2985,7 @@ class Message(Object, Update):
def download(
self,
file_name: str = "",
+ out: io.IOBase = None,
block: bool = True,
progress: callable = None,
progress_args: tuple = ()
@@ -3009,6 +3010,9 @@ class Message(Object, Update):
You can also specify a path for downloading files in a custom location: paths that end with "/"
are considered directories. All non-existent folders will be created automatically.
+ out (``io.IOBase``, *optional*):
+ A custom *file-like object* to be used when downloading file. Overrides file_name
+
block (``bool``, *optional*):
Blocks the code execution until the file has been downloaded.
Defaults to True.
@@ -3045,6 +3049,7 @@ class Message(Object, Update):
return self._client.download_media(
message=self,
file_name=file_name,
+ out=out,
block=block,
progress=progress,
progress_args=progress_args,
@@ -3074,7 +3079,7 @@ class Message(Object, Update):
Parameters:
option (``int``):
Index of the poll option you want to vote for (0 to 9).
-
+
Returns:
:obj:`Poll`: On success, the poll with the chosen option is returned.