diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 993b9bf9..6dabd1f4 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -276,35 +276,36 @@ class Client: log.debug("{} started".format(name)) while True: - message = self.download_queue.get() + media = self.download_queue.get() - if message is None: + if media is None: break - message, done = message + media, file_name, done = media try: - if isinstance(message.media, types.MessageMediaDocument): - document = message.media.document + if isinstance(media, types.MessageMediaDocument): + document = media.document if isinstance(document, types.Document): - file_name = "doc_{}{}".format( - datetime.fromtimestamp(document.date).strftime("%Y-%m-%d_%H-%M-%S"), - mimetypes.guess_extension(document.mime_type) or ".unknown" - ) + if not file_name: + file_name = "doc_{}{}".format( + datetime.fromtimestamp(document.date).strftime("%Y-%m-%d_%H-%M-%S"), + mimetypes.guess_extension(document.mime_type) or ".unknown" + ) - for i in document.attributes: - if isinstance(i, types.DocumentAttributeFilename): - file_name = i.file_name - break - elif isinstance(i, types.DocumentAttributeSticker): - file_name = file_name.replace("doc", "sticker") - elif isinstance(i, types.DocumentAttributeAudio): - file_name = file_name.replace("doc", "audio") - elif isinstance(i, types.DocumentAttributeVideo): - file_name = file_name.replace("doc", "video") - elif isinstance(i, types.DocumentAttributeAnimated): - file_name = file_name.replace("doc", "gif") + for i in document.attributes: + if isinstance(i, types.DocumentAttributeFilename): + file_name = i.file_name + break + elif isinstance(i, types.DocumentAttributeSticker): + file_name = file_name.replace("doc", "sticker") + elif isinstance(i, types.DocumentAttributeAudio): + file_name = file_name.replace("doc", "audio") + elif isinstance(i, types.DocumentAttributeVideo): + file_name = file_name.replace("doc", "video") + elif isinstance(i, types.DocumentAttributeAnimated): + file_name = file_name.replace("doc", "gif") tmp_file_name = self.get_file( dc_id=document.dc_id, @@ -313,18 +314,12 @@ class Client: version=document.version ) - i = 1 - while True: - try: - os.renames("./{}".format(tmp_file_name), "./downloads/{}".format( - ".".join(file_name.split(".")[:-1]) - + (" ({}).".format(i) if i > 1 else ".") - + file_name.split(".")[-1] - )) - except FileExistsError: - i += 1 - else: - break + try: + os.remove("./downloads/{}".format(file_name)) + except FileNotFoundError: + pass + + os.renames("./{}".format(tmp_file_name), "./downloads/{}".format(file_name)) done.set() except Exception as e: @@ -1940,6 +1935,9 @@ class Client: f.flush() os.fsync(f.fileno()) + if len(chunk) < limit: + break + offset += limit except Exception as e: log.error(e) @@ -2301,7 +2299,10 @@ class Client: ) ) - def download_media(self, message: types.Message): + def download_media(self, message: types.Message, file_name: str = None): done = Event() - self.download_queue.put((message, done)) + media = message.media if isinstance(message, types.Message) else message + + self.download_queue.put((media, file_name, done)) + done.wait()