Add support for downloading files to file pointer, fix for https://github.com/pyrogram/pyrogram/issues/284
This commit is contained in:
parent
55d0b93cf0
commit
1e8c9812a1
@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -1231,9 +1231,9 @@ class Client(Methods, BaseClient):
|
|||||||
|
|
||||||
temp_file_path = ""
|
temp_file_path = ""
|
||||||
final_file_path = ""
|
final_file_path = ""
|
||||||
|
path = [None]
|
||||||
try:
|
try:
|
||||||
data, directory, file_name, done, progress, progress_args, path = packet
|
data, done, progress, progress_args, out, path, to_file = packet
|
||||||
|
|
||||||
temp_file_path = self.get_file(
|
temp_file_path = self.get_file(
|
||||||
media_type=data.media_type,
|
media_type=data.media_type,
|
||||||
@ -1250,13 +1250,15 @@ class Client(Methods, BaseClient):
|
|||||||
file_size=data.file_size,
|
file_size=data.file_size,
|
||||||
is_big=data.is_big,
|
is_big=data.is_big,
|
||||||
progress=progress,
|
progress=progress,
|
||||||
progress_args=progress_args
|
progress_args=progress_args,
|
||||||
|
out=out
|
||||||
)
|
)
|
||||||
|
if to_file:
|
||||||
if temp_file_path:
|
final_file_path = out.name
|
||||||
final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
|
else:
|
||||||
os.makedirs(directory, exist_ok=True)
|
final_file_path = ''
|
||||||
shutil.move(temp_file_path, final_file_path)
|
if to_file:
|
||||||
|
out.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(e, exc_info=True)
|
log.error(e, exc_info=True)
|
||||||
|
|
||||||
@ -1864,7 +1866,8 @@ class Client(Methods, BaseClient):
|
|||||||
file_size: int,
|
file_size: int,
|
||||||
is_big: bool,
|
is_big: bool,
|
||||||
progress: callable,
|
progress: callable,
|
||||||
progress_args: tuple = ()
|
progress_args: tuple = (),
|
||||||
|
out: io.IOBase = None
|
||||||
) -> str:
|
) -> str:
|
||||||
with self.media_sessions_lock:
|
with self.media_sessions_lock:
|
||||||
session = self.media_sessions.get(dc_id, None)
|
session = self.media_sessions.get(dc_id, None)
|
||||||
@ -1950,7 +1953,10 @@ class Client(Methods, BaseClient):
|
|||||||
limit = 1024 * 1024
|
limit = 1024 * 1024
|
||||||
offset = 0
|
offset = 0
|
||||||
file_name = ""
|
file_name = ""
|
||||||
|
if not out:
|
||||||
|
f = tempfile.NamedTemporaryFile("wb", delete=False)
|
||||||
|
else:
|
||||||
|
f = out
|
||||||
try:
|
try:
|
||||||
r = session.send(
|
r = session.send(
|
||||||
functions.upload.GetFile(
|
functions.upload.GetFile(
|
||||||
@ -1961,36 +1967,37 @@ class Client(Methods, BaseClient):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(r, types.upload.File):
|
if isinstance(r, types.upload.File):
|
||||||
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
|
if hasattr(f, "name"):
|
||||||
file_name = f.name
|
file_name = f.name
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
chunk = r.bytes
|
chunk = r.bytes
|
||||||
|
|
||||||
if not chunk:
|
if not chunk:
|
||||||
break
|
break
|
||||||
|
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
|
|
||||||
offset += limit
|
offset += limit
|
||||||
|
|
||||||
if progress:
|
if progress:
|
||||||
progress(
|
progress(
|
||||||
min(offset, file_size)
|
|
||||||
if file_size != 0
|
|
||||||
else offset,
|
|
||||||
file_size,
|
|
||||||
*progress_args
|
|
||||||
)
|
|
||||||
|
|
||||||
r = session.send(
|
min(offset, file_size)
|
||||||
functions.upload.GetFile(
|
if file_size != 0
|
||||||
location=location,
|
else offset,
|
||||||
offset=offset,
|
file_size,
|
||||||
limit=limit
|
*progress_args
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
r = session.send(
|
||||||
|
functions.upload.GetFile(
|
||||||
|
location=location,
|
||||||
|
offset=offset,
|
||||||
|
limit=limit
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
elif isinstance(r, types.upload.FileCdnRedirect):
|
elif isinstance(r, types.upload.FileCdnRedirect):
|
||||||
with self.media_sessions_lock:
|
with self.media_sessions_lock:
|
||||||
cdn_session = self.media_sessions.get(r.dc_id, None)
|
cdn_session = self.media_sessions.get(r.dc_id, None)
|
||||||
@ -2003,70 +2010,71 @@ class Client(Methods, BaseClient):
|
|||||||
self.media_sessions[r.dc_id] = cdn_session
|
self.media_sessions[r.dc_id] = cdn_session
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
|
if hasattr(f, "name"):
|
||||||
file_name = f.name
|
file_name = f.name
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
r2 = cdn_session.send(
|
r2 = cdn_session.send(
|
||||||
functions.upload.GetCdnFile(
|
functions.upload.GetCdnFile(
|
||||||
file_token=r.file_token,
|
file_token=r.file_token,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
|
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
|
||||||
try:
|
try:
|
||||||
session.send(
|
session.send(
|
||||||
functions.upload.ReuploadCdnFile(
|
functions.upload.ReuploadCdnFile(
|
||||||
file_token=r.file_token,
|
file_token=r.file_token,
|
||||||
request_token=r2.request_token
|
request_token=r2.request_token
|
||||||
)
|
|
||||||
)
|
)
|
||||||
except VolumeLocNotFound:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
chunk = r2.bytes
|
|
||||||
|
|
||||||
# https://core.telegram.org/cdn#decrypting-files
|
|
||||||
decrypted_chunk = AES.ctr256_decrypt(
|
|
||||||
chunk,
|
|
||||||
r.encryption_key,
|
|
||||||
bytearray(
|
|
||||||
r.encryption_iv[:-4]
|
|
||||||
+ (offset // 16).to_bytes(4, "big")
|
|
||||||
)
|
)
|
||||||
)
|
except VolumeLocNotFound:
|
||||||
|
|
||||||
hashes = session.send(
|
|
||||||
functions.upload.GetCdnFileHashes(
|
|
||||||
file_token=r.file_token,
|
|
||||||
offset=offset
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://core.telegram.org/cdn#verifying-files
|
|
||||||
for i, h in enumerate(hashes):
|
|
||||||
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
|
||||||
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
|
|
||||||
|
|
||||||
f.write(decrypted_chunk)
|
|
||||||
|
|
||||||
offset += limit
|
|
||||||
|
|
||||||
if progress:
|
|
||||||
progress(
|
|
||||||
min(offset, file_size)
|
|
||||||
if file_size != 0
|
|
||||||
else offset,
|
|
||||||
file_size,
|
|
||||||
*progress_args
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(chunk) < limit:
|
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = r2.bytes
|
||||||
|
|
||||||
|
# https://core.telegram.org/cdn#decrypting-files
|
||||||
|
decrypted_chunk = AES.ctr256_decrypt(
|
||||||
|
chunk,
|
||||||
|
r.encryption_key,
|
||||||
|
bytearray(
|
||||||
|
r.encryption_iv[:-4]
|
||||||
|
+ (offset // 16).to_bytes(4, "big")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
hashes = session.send(
|
||||||
|
functions.upload.GetCdnFileHashes(
|
||||||
|
file_token=r.file_token,
|
||||||
|
offset=offset
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://core.telegram.org/cdn#verifying-files
|
||||||
|
for i, h in enumerate(hashes):
|
||||||
|
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
||||||
|
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
|
||||||
|
|
||||||
|
f.write(decrypted_chunk)
|
||||||
|
|
||||||
|
offset += limit
|
||||||
|
|
||||||
|
if progress:
|
||||||
|
progress(
|
||||||
|
|
||||||
|
min(offset, file_size)
|
||||||
|
if file_size != 0
|
||||||
|
else offset,
|
||||||
|
file_size,
|
||||||
|
*progress_args
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(chunk) < limit:
|
||||||
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -2074,7 +2082,8 @@ class Client(Methods, BaseClient):
|
|||||||
log.error(e, exc_info=True)
|
log.error(e, exc_info=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.remove(file_name)
|
if out:
|
||||||
|
os.remove(file_name)
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -17,7 +17,9 @@
|
|||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import binascii
|
import binascii
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -37,6 +39,7 @@ class DownloadMedia(BaseClient):
|
|||||||
message: Union["pyrogram.Message", str],
|
message: Union["pyrogram.Message", str],
|
||||||
file_ref: str = None,
|
file_ref: str = None,
|
||||||
file_name: str = DEFAULT_DOWNLOAD_DIR,
|
file_name: str = DEFAULT_DOWNLOAD_DIR,
|
||||||
|
out: io.IOBase = None,
|
||||||
block: bool = True,
|
block: bool = True,
|
||||||
progress: callable = None,
|
progress: callable = None,
|
||||||
progress_args: tuple = ()
|
progress_args: tuple = ()
|
||||||
@ -58,6 +61,9 @@ class DownloadMedia(BaseClient):
|
|||||||
You can also specify a path for downloading files in a custom location: paths that end with "/"
|
You can also specify a path for downloading files in a custom location: paths that end with "/"
|
||||||
are considered directories. All non-existent folders will be created automatically.
|
are considered directories. All non-existent folders will be created automatically.
|
||||||
|
|
||||||
|
out (``io.IOBase``, *optional*):
|
||||||
|
A custom *file-like object* to be used when downloading file. Overrides file_name
|
||||||
|
|
||||||
block (``bool``, *optional*):
|
block (``bool``, *optional*):
|
||||||
Blocks the code execution until the file has been downloaded.
|
Blocks the code execution until the file has been downloaded.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
@ -238,6 +244,13 @@ class DownloadMedia(BaseClient):
|
|||||||
extension
|
extension
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not out:
|
||||||
|
out = open(os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))), 'wb')
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
to_file = True
|
||||||
|
else:
|
||||||
|
to_file = False
|
||||||
|
self.download_queue.put((data, done, progress, progress_args, out, path, to_file))
|
||||||
# Cast to string because Path objects aren't supported by Python 3.5
|
# Cast to string because Path objects aren't supported by Python 3.5
|
||||||
self.download_queue.put((data, str(directory), str(file_name), done, progress, progress_args, path))
|
self.download_queue.put((data, str(directory), str(file_name), done, progress, progress_args, path))
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
import io
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List, Match, Union
|
from typing import List, Match, Union
|
||||||
|
|
||||||
@ -2985,6 +2985,7 @@ class Message(Object, Update):
|
|||||||
def download(
|
def download(
|
||||||
self,
|
self,
|
||||||
file_name: str = "",
|
file_name: str = "",
|
||||||
|
out: io.IOBase = None,
|
||||||
block: bool = True,
|
block: bool = True,
|
||||||
progress: callable = None,
|
progress: callable = None,
|
||||||
progress_args: tuple = ()
|
progress_args: tuple = ()
|
||||||
@ -3009,6 +3010,9 @@ class Message(Object, Update):
|
|||||||
You can also specify a path for downloading files in a custom location: paths that end with "/"
|
You can also specify a path for downloading files in a custom location: paths that end with "/"
|
||||||
are considered directories. All non-existent folders will be created automatically.
|
are considered directories. All non-existent folders will be created automatically.
|
||||||
|
|
||||||
|
out (``io.IOBase``, *optional*):
|
||||||
|
A custom *file-like object* to be used when downloading file. Overrides file_name
|
||||||
|
|
||||||
block (``bool``, *optional*):
|
block (``bool``, *optional*):
|
||||||
Blocks the code execution until the file has been downloaded.
|
Blocks the code execution until the file has been downloaded.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
@ -3045,6 +3049,7 @@ class Message(Object, Update):
|
|||||||
return self._client.download_media(
|
return self._client.download_media(
|
||||||
message=self,
|
message=self,
|
||||||
file_name=file_name,
|
file_name=file_name,
|
||||||
|
out=out,
|
||||||
block=block,
|
block=block,
|
||||||
progress=progress,
|
progress=progress,
|
||||||
progress_args=progress_args,
|
progress_args=progress_args,
|
||||||
|
Loading…
Reference in New Issue
Block a user