Add support for progress callback when downloading media

This commit is contained in:
Dan 2018-02-24 17:16:25 +01:00
parent 2e4802fbda
commit ed4ff07742

View File

@ -300,10 +300,10 @@ class Client:
if media is None:
break
media, file_name, done = media
tmp_file_name = ""
try:
media, file_name, done, progress, path = media
tmp_file_name = None
if isinstance(media, types.MessageMediaDocument):
document = media.document
@ -331,15 +331,10 @@ class Client:
dc_id=document.dc_id,
id=document.id,
access_hash=document.access_hash,
version=document.version
version=document.version,
size=document.size,
progress=progress
)
try:
os.remove("./downloads/{}".format(file_name))
except FileNotFoundError:
pass
os.renames("./{}".format(tmp_file_name), "./downloads/{}".format(file_name))
elif isinstance(media, types.MessageMediaPhoto):
photo = media.photo
@ -355,23 +350,32 @@ class Client:
dc_id=photo_loc.dc_id,
volume_id=photo_loc.volume_id,
local_id=photo_loc.local_id,
secret=photo_loc.secret
secret=photo_loc.secret,
size=photo.sizes[-1].size,
progress=progress
)
try:
os.remove("downloads/{}".format(file_name))
except FileNotFoundError:
pass
if file_name is not None:
path[0] = "downloads/{}".format(file_name)
try:
os.remove("downloads/{}".format(file_name))
except OSError:
pass
finally:
try:
os.renames("{}".format(tmp_file_name), "downloads/{}".format(file_name))
except OSError:
pass
except Exception as e:
log.error(e, exc_info=True)
finally:
print(done)
done.set()
try:
os.remove("{}".format(tmp_file_name))
except FileNotFoundError:
except OSError:
pass
log.debug("{} stopped".format(name))
@ -1861,7 +1865,9 @@ class Client:
volume_id: int = None,
local_id: int = None,
secret: int = None,
version: int = 0) -> str:
version: int = 0,
size: int = None,
progress: callable = None) -> str:
if dc_id != self.dc_id:
exported_auth = self.send(
functions.auth.ExportAuthorization(
@ -1936,6 +1942,9 @@ class Client:
offset += limit
if progress:
progress(offset, size)
r = session.send(
functions.upload.GetFile(
location=location,
@ -2007,10 +2016,13 @@ class Client:
f.flush()
os.fsync(f.fileno())
offset += limit
if progress:
progress(min(offset, size), size)
if len(chunk) < limit:
break
offset += limit
except Exception as e:
log.error(e)
finally:
@ -2371,14 +2383,36 @@ class Client:
)
)
def download_media(self, message: types.Message, file_name: str = None, block: bool = True):
done = Event()
media = message.media if isinstance(message, types.Message) else message
def download_media(self,
message: types.Message,
file_name: str = None,
block: bool = True,
progress: callable = None):
"""Use this method to download the media from a Message.
self.download_queue.put((media, file_name, done))
Files are saved in the *downloads* folder.
if block:
done.wait()
Args:
message (:obj:`Message <pyrogram.api.types.Message>`):
The Message containing the media.
file_name (:obj:`str`):
Specify a file_name to be used
"""
if isinstance(message, types.Message):
done = Event()
media = message.media
path = [None]
if media is not None:
self.download_queue.put((media, file_name, done, progress, path))
else:
return
if block:
done.wait()
return path[0]
def add_contacts(self, contacts: list):
"""Use this method to add contacts to your Telegram address book.
@ -2408,8 +2442,8 @@ class Client:
Args:
ids (:obj:`list`):
A list of unique identifiers for the target users. Can be an ID (int), a username (string)
or phone number (string).
A list of unique identifiers for the target users.
Can be an ID (int), a username (string) or phone number (string).
Returns:
True on success.