diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 97c16196..2082029c 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -300,10 +300,10 @@ class Client: if media is None: break - media, file_name, done = media - tmp_file_name = "" - try: + media, file_name, done, progress, path = media + tmp_file_name = None + if isinstance(media, types.MessageMediaDocument): document = media.document @@ -331,15 +331,10 @@ class Client: dc_id=document.dc_id, id=document.id, access_hash=document.access_hash, - version=document.version + version=document.version, + size=document.size, + progress=progress ) - - try: - os.remove("./downloads/{}".format(file_name)) - except FileNotFoundError: - pass - - os.renames("./{}".format(tmp_file_name), "./downloads/{}".format(file_name)) elif isinstance(media, types.MessageMediaPhoto): photo = media.photo @@ -355,23 +350,32 @@ class Client: dc_id=photo_loc.dc_id, volume_id=photo_loc.volume_id, local_id=photo_loc.local_id, - secret=photo_loc.secret + secret=photo_loc.secret, + size=photo.sizes[-1].size, + progress=progress ) - try: - os.remove("downloads/{}".format(file_name)) - except FileNotFoundError: - pass + if file_name is not None: + path[0] = "downloads/{}".format(file_name) + try: + os.remove("downloads/{}".format(file_name)) + except OSError: + pass + finally: + try: os.renames("{}".format(tmp_file_name), "downloads/{}".format(file_name)) + except OSError: + pass except Exception as e: log.error(e, exc_info=True) finally: + print(done) done.set() try: os.remove("{}".format(tmp_file_name)) - except FileNotFoundError: + except OSError: pass log.debug("{} stopped".format(name)) @@ -1861,7 +1865,9 @@ class Client: volume_id: int = None, local_id: int = None, secret: int = None, - version: int = 0) -> str: + version: int = 0, + size: int = None, + progress: callable = None) -> str: if dc_id != self.dc_id: exported_auth = self.send( functions.auth.ExportAuthorization( @@ -1936,6 +1942,9 @@ class Client: offset += limit + if progress: + progress(offset, size) + r = session.send( functions.upload.GetFile( location=location, @@ -2007,10 +2016,13 @@ class Client: f.flush() os.fsync(f.fileno()) + offset += limit + + if progress: + progress(min(offset, size), size) + if len(chunk) < limit: break - - offset += limit except Exception as e: log.error(e) finally: @@ -2371,14 +2383,36 @@ class Client: ) ) - def download_media(self, message: types.Message, file_name: str = None, block: bool = True): - done = Event() - media = message.media if isinstance(message, types.Message) else message + def download_media(self, + message: types.Message, + file_name: str = None, + block: bool = True, + progress: callable = None): + """Use this method to download the media from a Message. - self.download_queue.put((media, file_name, done)) + Files are saved in the *downloads* folder. - if block: - done.wait() + Args: + message (:obj:`Message `): + The Message containing the media. + + file_name (:obj:`str`): + Specify a file_name to be used + """ + if isinstance(message, types.Message): + done = Event() + media = message.media + path = [None] + + if media is not None: + self.download_queue.put((media, file_name, done, progress, path)) + else: + return + + if block: + done.wait() + + return path[0] def add_contacts(self, contacts: list): """Use this method to add contacts to your Telegram address book. @@ -2408,8 +2442,8 @@ class Client: Args: ids (:obj:`list`): - A list of unique identifiers for the target users. Can be an ID (int), a username (string) - or phone number (string). + A list of unique identifiers for the target users. + Can be an ID (int), a username (string) or phone number (string). Returns: True on success.