From 956e5c1a4f6a7657b325bf873be4666aad1ee25f Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Wed, 4 May 2022 09:04:25 +0200 Subject: [PATCH] Clean up on download's stop_transmission and return None --- pyrogram/__init__.py | 2 +- pyrogram/client.py | 37 +++++++++++++++++++++---------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/pyrogram/__init__.py b/pyrogram/__init__.py index 164d8c60..7dccf104 100644 --- a/pyrogram/__init__.py +++ b/pyrogram/__init__.py @@ -23,7 +23,7 @@ __copyright__ = "Copyright (C) 2017-present Dan " from concurrent.futures.thread import ThreadPoolExecutor -class StopTransmission(StopAsyncIteration): +class StopTransmission(Exception): pass diff --git a/pyrogram/client.py b/pyrogram/client.py index c12de075..c1d04046 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -727,23 +727,27 @@ class Client(Methods): async def handle_download(self, packet): file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet - file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb", delete=False) - async for chunk in self.get_file(file_id, file_size, 0, 0, progress, progress_args): - file.write(chunk) + try: + async for chunk in self.get_file(file_id, file_size, 0, 0, progress, progress_args): + file.write(chunk) + except pyrogram.StopTransmission: + if not in_memory: + file.close() + os.remove(file.name) - if file and not in_memory: - file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))) - os.makedirs(directory, exist_ok=True) - file.close() - shutil.move(file.name, file_path) - - return file_path - - if file and in_memory: - file.name = file_name - return file + return None + else: + if in_memory: + file.name = file_name + return file + else: + file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))) + os.makedirs(directory, exist_ok=True) + file.close() + shutil.move(file.name, file_path) + return file_path async def get_file( self, @@ -970,9 +974,10 @@ class Client(Methods): break except Exception as e: raise e + except pyrogram.StopTransmission: + raise except Exception as e: - if not isinstance(e, pyrogram.StopTransmission): - log.error(e, exc_info=True) + log.error(e, exc_info=True) def guess_mime_type(self, filename: str) -> Optional[str]: return self.mimetypes.guess_type(filename)[0]